├── LICENSE ├── README.md ├── create_dataset.py ├── util.py ├── requirements.txt ├── da_algo.py ├── train_model.py ├── model.py ├── ot_util.py ├── dataset.py └── experiments.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yifei He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Generative Gradual Domain Adaptation with Optimal Transport (GOAT) 2 | 3 | This is the official implementation for the algorithm **G**radual D**O**main **A**daptation with Optimal **T**ransport (GOAT) in the paper ["Gradual Domain Adaptation: Theory and Algorithms."](https://arxiv.org/abs/2310.13852) at JMLR. 
The algorithm design is motivated by our previous work, ["Understanding gradual domain adaptation: Improved analysis, optimal path and beyond"](https://arxiv.org/abs/2204.08200), published in ICML 2022. 4 | 5 | # Install the repo 6 | ``` 7 | git clone https://github.com/yifei-he/GOAT.git 8 | cd GOAT 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | # Prepare Data 13 | 14 | The covertype dataset can be downloaded from: https://archive.ics.uci.edu/dataset/31/covertype. 15 | 16 | The portraits dataset can be downloaded from [here](https://www.dropbox.com/s/ubjjoo0b2wz4vgz/faces_aligned_small_mirrored_co_aligned_cropped_cleaned.tar.gz?dl=0). We follow the same data preprocessing procedure as https://github.com/p-lambda/gradual_domain_adaptation. Namely, after downloading, extract the tar file, and copy the "M" and "F" folders into a folder called dataset_32x32 inside the current folder. Then run "python create_dataset.py". 17 | 18 | # Run Experiment 19 | To run experiments, use the following syntax. 20 | ``` 21 | python experiments.py --dataset color_mnist --gt-domains 1 --generated-domains 2 22 | ``` 23 | Here, `dataset` can be selected from `[mnist, portraits, covtype, color_mnist]`; `gt-domains` and `generated-domains` are the number of given ground-truth intermediate domains (only available for the two MNIST datasets) and the number of domains generated by GOAT, respectively; both default to 0.
24 | 25 | # Citation 26 | 27 | ``` 28 | @article{JMLR:v25:23-1180, 29 | author = {Yifei He and Haoxiang Wang and Bo Li and Han Zhao}, 30 | title = {Gradual Domain Adaptation: Theory and Algorithms}, 31 | journal = {Journal of Machine Learning Research}, 32 | year = {2024}, 33 | volume = {25}, 34 | number = {361}, 35 | pages = {1--40}, 36 | url = {http://jmlr.org/papers/v25/23-1180.html} 37 | } 38 | ``` 39 | ``` 40 | @inproceedings{wang2022understanding, 41 | title={Understanding gradual domain adaptation: Improved analysis, optimal path and beyond}, 42 | author={Wang, Haoxiang and Li, Bo and Zhao, Han}, 43 | booktitle={International Conference on Machine Learning}, 44 | pages={22784--22801}, 45 | year={2022}, 46 | organization={PMLR} 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /create_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | from shutil import copyfile 4 | import numpy as np 5 | import numpy as np 6 | import scipy.io 7 | import pickle 8 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 9 | 10 | 11 | image_options = { 12 | 'batch_size': 100, 13 | 'class_mode': 'binary', 14 | 'color_mode': 'grayscale', 15 | } 16 | 17 | 18 | def save_data(data_dir='dataset_32x32', save_file='dataset_32x32.mat', target_size=(32, 32)): 19 | Xs, Ys = [], [] 20 | datagen = ImageDataGenerator(rescale=1./255) 21 | data_generator = datagen.flow_from_directory( 22 | data_dir, shuffle=False, target_size=target_size, **image_options) 23 | while True: 24 | next_x, next_y = data_generator.next() 25 | Xs.append(next_x) 26 | Ys.append(next_y) 27 | if data_generator.batch_index == 0: 28 | break 29 | Xs = np.concatenate(Xs) 30 | Ys = np.concatenate(Ys) 31 | filenames = [f[2:] for f in data_generator.filenames] 32 | assert(len(set(filenames)) == len(filenames)) 33 | filenames_idx = list(zip(filenames, range(len(filenames)))) 34 | 
filenames_idx = [(f, i) for f, i in zip(filenames, range(len(filenames)))] 35 | # if f[5:8] == 'Cal' or f[5:8] == 'cal'] 36 | indices = [i for f, i in sorted(filenames_idx)] 37 | genders = np.array([f[:1] for f in data_generator.filenames])[indices] 38 | binary_genders = (genders == 'F') 39 | pickle.dump(binary_genders, open('portraits_gender_stats', "wb")) 40 | print("computed gender stats") 41 | # gender_stats = utils.rolling_average(binary_genders, 500) 42 | # print(filenames) 43 | # sort_indices = np.argsort(filenames) 44 | # We need to sort only by year, and not have correlation with state. 45 | # print state stats? print gender stats? print school stats? 46 | # E.g. if this changes a lot by year, then we might want to do some grouping. 47 | # Maybe print out number per year, and then we can decide on a grouping? Or algorithmically decide? 48 | Xs = Xs[indices] 49 | Ys = Ys[indices] 50 | scipy.io.savemat('./' + save_file, mdict={'Xs': Xs, 'Ys': Ys}) 51 | 52 | # Resize images. 53 | def resize(path, size=64): 54 | dirs = os.listdir(path) 55 | for item in dirs: 56 | if os.path.isfile(path+item): 57 | im = Image.open(path+item) 58 | f, e = os.path.splitext(path+item) 59 | imResize = im.resize((size,size), Image.ANTIALIAS) 60 | imResize.save(f + '.png', 'PNG') 61 | 62 | for folder in ['./dataset_32x32/M/', './dataset_32x32/F/']: 63 | resize(folder, size=32) 64 | 65 | save_data(data_dir='dataset_32x32', save_file='dataset_32x32.mat', target_size=(32,32)) -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from imghdr import tests 2 | from random import Random 3 | import torch 4 | from torchvision import datasets, transforms 5 | from torch.utils.data import DataLoader, random_split, ConcatDataset, Subset 6 | from dataset import * 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | from PIL import Image 12 | from 
torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, ToPILImage, Pad, RandomRotation 13 | try: 14 | from torchvision.transforms import InterpolationMode 15 | BICUBIC = InterpolationMode.BICUBIC 16 | except ImportError: 17 | BICUBIC = Image.BICUBIC 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | def get_angles(step, target): 23 | angles = [0] 24 | while angles[-1] < target: 25 | angles.append(angles[-1] + step) 26 | 27 | return angles 28 | 29 | 30 | # obtain the combined dataset with all domains 31 | def get_rotated_dataset(raw_set, train, angles): 32 | total_set = [raw_set] 33 | for a in angles: 34 | total_set.append(get_single_rotate(train, a)) 35 | 36 | return ConcatDataset(total_set) 37 | 38 | 39 | def _convert_image_to_rgb(image): 40 | return image.convert("RGB") 41 | 42 | 43 | # obtain a single domain with a certain rotation angle 44 | def get_single_rotate(train, angle, dataset="mnist", encoder=None): 45 | 46 | transform = Compose([ToTensor(), RandomRotation((angle, angle))]) 47 | 48 | if dataset == "mnist": 49 | # uncomment the following line if MNIST is not downloaded 50 | # dataset = datasets.MNIST(root="/data/mnist/", train=train, download=True, transform=transform) 51 | dataset = datasets.MNIST(root="/data/common", train=train, download=False, transform=transform) 52 | 53 | if encoder is not None: 54 | dataset = get_encoded_dataset(encoder, dataset) 55 | 56 | return dataset 57 | 58 | 59 | def get_loaders(raw_trainset, raw_testset, batch_size): 60 | trainset = raw_trainset 61 | testset = raw_testset 62 | 63 | train_size = int(len(trainset) * 0.8) 64 | val_size = len(trainset) - train_size 65 | trains, valid = random_split(trainset, [train_size, val_size]) 66 | trainloader = DataLoader(trains, batch_size=batch_size, shuffle=True, num_workers=2) 67 | valloader = DataLoader(valid, batch_size=batch_size, shuffle=False, num_workers=2) 68 | testloader = DataLoader(testset, 
batch_size=batch_size, shuffle=False, num_workers=2) 69 | 70 | return trainloader, valloader, testloader 71 | 72 | 73 | def get_encoded_dataset(encoder, dataset): 74 | loader = DataLoader(dataset, batch_size=128, shuffle=True) 75 | 76 | latent, labels = [], [] 77 | with torch.no_grad(): 78 | for _, (data, label) in enumerate(loader): 79 | data = data.to(device) 80 | latent.append(encoder(data).cpu()) 81 | labels.append(label) 82 | 83 | latent = torch.cat(latent) 84 | labels = torch.cat(labels) 85 | 86 | encoded_dataset = EncodeDataset(latent.float().cpu().detach(), labels) 87 | 88 | return encoded_dataset 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | apturl==0.5.2 3 | astunparse==1.6.3 4 | bcrypt==3.2.0 5 | blessings==1.7 6 | blinker==1.4 7 | Brlapi==0.8.3 8 | cachetools==5.3.1 9 | certifi==2020.6.20 10 | chardet==4.0.0 11 | click==8.0.3 12 | colorama==0.4.4 13 | command-not-found==0.3 14 | contourpy==1.1.1 15 | cryptography==3.4.8 16 | cupshelpers==1.0 17 | cycler==0.12.1 18 | dbus-python==1.2.18 19 | defer==1.0.6 20 | distro==1.7.0 21 | distro-info===1.1build1 22 | docutils==0.17.1 23 | duplicity==0.8.21 24 | fasteners==0.14.1 25 | filelock==3.9.0 26 | flatbuffers==23.5.26 27 | fonttools==4.43.1 28 | fsspec==2023.4.0 29 | future==0.18.2 30 | gast==0.5.4 31 | google-auth==2.23.3 32 | google-auth-oauthlib==1.0.0 33 | google-pasta==0.2.0 34 | gpg===1.16.0-unknown 35 | grpcio==1.59.0 36 | h5py==3.10.0 37 | httplib2==0.20.2 38 | idna==3.3 39 | importlib-metadata==4.6.4 40 | jeepney==0.7.1 41 | Jinja2==3.1.2 42 | joblib==1.3.2 43 | keras==2.14.0 44 | keyring==23.5.0 45 | kiwisolver==1.4.5 46 | language-selector==0.1 47 | launchpadlib==1.10.16 48 | lazr.restfulclient==0.14.4 49 | lazr.uri==1.0.6 50 | libclang==16.0.6 51 | -e 
git+https://github.com/median-research-group/LibMTL@f10f7c9ffb72138a4ffae150330fb653da3b7456#egg=LibMTL 52 | lockfile==0.12.2 53 | louis==3.20.0 54 | macaroonbakery==1.3.1 55 | Mako==1.1.3 56 | Markdown==3.3.6 57 | MarkupSafe==2.1.3 58 | matplotlib==3.8.0 59 | ml-dtypes==0.2.0 60 | monotonic==1.6 61 | more-itertools==8.10.0 62 | mpmath==1.2.1 63 | netifaces==0.11.0 64 | networkx==3.0rc1 65 | numpy==1.26.1 66 | oauthlib==3.2.0 67 | olefile==0.46 68 | opt-einsum==3.3.0 69 | packaging==23.2 70 | pandas==2.1.1 71 | paramiko==2.9.3 72 | pexpect==4.8.0 73 | Pillow==9.0.1 74 | POT==0.9.1 75 | protobuf==4.24.4 76 | psutil==5.9.0 77 | ptyprocess==0.7.0 78 | pyasn1==0.5.0 79 | pyasn1-modules==0.3.0 80 | pycairo==1.20.1 81 | pycups==2.0.1 82 | Pygments==2.11.2 83 | PyGObject==3.42.1 84 | PyJWT==2.3.0 85 | pymacaroons==0.13.0 86 | PyNaCl==1.5.0 87 | pyparsing==2.4.7 88 | pyRFC3339==1.1 89 | python-apt==2.4.0+ubuntu1 90 | python-dateutil==2.8.2 91 | python-debian==0.1.43+ubuntu1.1 92 | pytorch-triton==2.1.0+6e4932cda8 93 | pytz==2022.1 94 | pyxdg==0.27 95 | PyYAML==5.4.1 96 | reportlab==3.6.8 97 | requests==2.25.1 98 | requests-oauthlib==1.3.1 99 | roman==3.3 100 | rsa==4.9 101 | scikit-learn==1.3.1 102 | scipy==1.11.3 103 | SecretStorage==3.3.1 104 | six==1.16.0 105 | sklearn==0.0.post10 106 | ssh-import-id==5.11 107 | sympy==1.11.1 108 | systemd-python==234 109 | tensorboard==2.14.1 110 | tensorboard-data-server==0.7.1 111 | tensorflow==2.14.0 112 | tensorflow-estimator==2.14.0 113 | tensorflow-io-gcs-filesystem==0.34.0 114 | termcolor==2.3.0 115 | threadpoolctl==3.2.0 116 | torch==2.2.0.dev20230922+cu121 117 | torchaudio==2.2.0.dev20230922+cu121 118 | torchvision==0.17.0.dev20230922+cu121 119 | trash-cli==0.17.1.14 120 | typing_extensions==4.4.0 121 | tzdata==2023.3 122 | ubuntu-advantage-tools==8001 123 | ubuntu-drivers-common==0.0.0 124 | ufw==0.36.1 125 | urllib3==1.26.5 126 | usb-creator==0.3.7 127 | wadllib==1.3.6 128 | Werkzeug==3.0.0 129 | wrapt==1.14.1 130 | xdg==5 
131 | xkit==0.0.0 132 | zipp==1.0.0 133 | -------------------------------------------------------------------------------- /da_algo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, random_split, Subset 3 | import torch.optim as optim 4 | from train_model import * 5 | from util import * 6 | from dataset import * 7 | from ot_util import * 8 | from model import * 9 | import copy 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | def get_pseudo_labels(dataloader, model, confidence_q=0.1): 14 | logits = [] 15 | model.eval() 16 | with torch.no_grad(): 17 | for x in dataloader: 18 | if len(x) == 3: 19 | data, _, _ = x 20 | else: 21 | data, _ = x 22 | data = data.to(device) 23 | logits.append(model(data)) 24 | 25 | logits = torch.cat(logits) 26 | confidence = torch.max(logits, dim=1)[0] - torch.min(logits, dim=1)[0] 27 | alpha = torch.quantile(confidence, confidence_q) 28 | indices = torch.where(confidence >= alpha)[0].to("cpu") 29 | labels = torch.argmax(logits, axis=1) #[indices] 30 | 31 | return labels.cpu().detach().type(torch.int64), list(indices.detach().numpy()) 32 | 33 | 34 | def self_train(args, source_model, datasets, epochs=10): 35 | steps = len(datasets) 36 | teacher = source_model 37 | targetset = datasets[-1] 38 | 39 | targetloader = DataLoader(targetset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 40 | print("------------Direct adapt performance----------") 41 | direct_acc = test(targetloader, teacher) 42 | 43 | # start self-training on intermediate domains 44 | for i in range(steps): 45 | print(f"--------Training on the {i}th domain--------") 46 | trainset = datasets[i] 47 | ogloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 48 | 49 | test(targetloader, teacher) 50 | train_labs, train_idx = get_pseudo_labels(ogloader, teacher) 51 | 52 | if 
torch.is_tensor(trainset.data): 53 | data = trainset.data.cpu().detach().numpy() 54 | else: 55 | data = trainset.data 56 | trainset = EncodeDataset(data, train_labs, trainset.transform) 57 | 58 | # filter out the least 10% confident data 59 | filter_trainset = Subset(trainset, train_idx) 60 | print("Trainset size: " + str(len(filter_trainset))) 61 | 62 | trainloader = DataLoader(filter_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 63 | 64 | # initialize and train student model 65 | student = copy.deepcopy(teacher) 66 | optimizer = optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4) 67 | 68 | for i in range(1, epochs+1): 69 | train(i, trainloader, student, optimizer) 70 | if i % 5 == 0: 71 | test(targetloader, student) 72 | print("------------Performance on the current domain----------") 73 | test(trainloader, student) 74 | 75 | # test on the target domain 76 | print("------------Performance on the target domain----------") 77 | st_acc = test(targetloader, student) 78 | 79 | teacher = copy.deepcopy(student) 80 | 81 | return direct_acc, st_acc 82 | 83 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | 6 | 7 | # return reconstruction error + KL divergence losses 8 | def loss_function(recon_x, x, mu, log_var): 9 | BCE = F.binary_cross_entropy(recon_x, x, reduction='sum') 10 | KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) 11 | return BCE + KLD 12 | 13 | 14 | def calculate_modal_val_accuracy(model, valloader): 15 | model.eval() 16 | correct = 0. 17 | total = 0. 
18 | 19 | with torch.no_grad(): 20 | for x in valloader: 21 | if len(x) == 3: 22 | images, labels, weight = x 23 | else: 24 | images, labels = x 25 | 26 | images, labels = images.to(device), labels.to(device) 27 | outputs = model(images) 28 | predicted = outputs.argmax(dim=1) 29 | total += labels.size(0) 30 | correct += (predicted == labels).sum() 31 | 32 | return 100 * correct / total 33 | 34 | 35 | def train(epoch, train_loader, model, optimizer, lr_scheduler=None, vae=False, verbose=True): 36 | model.train() 37 | train_loss = 0 38 | for _, x in enumerate(train_loader): 39 | if len(x) == 2: 40 | data, labels = x 41 | elif len(x) == 3: 42 | data, labels, weight = x 43 | weight = weight.to(device) 44 | 45 | data = data.to(device) 46 | labels = labels.to(device) 47 | optimizer.zero_grad() 48 | 49 | if vae: 50 | recon_batch, mu, log_var = model(data) 51 | loss = loss_function(recon_batch, data, mu, log_var) 52 | else: 53 | output = model(data) 54 | if len(x) == 2: 55 | loss = F.cross_entropy(output, labels) 56 | elif len(x) == 3: 57 | criterion = nn.CrossEntropyLoss(reduction='none') 58 | loss = criterion(output, labels) 59 | loss = (loss * weight).mean() 60 | 61 | loss.backward() 62 | train_loss += loss.item() 63 | optimizer.step() 64 | 65 | if lr_scheduler is not None: 66 | lr_scheduler.step() 67 | 68 | if verbose: 69 | print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, train_loss / len(train_loader.dataset))) 70 | 71 | 72 | def test(val_loader, model, vae=False, verbose=True): 73 | model.eval() 74 | test_loss = 0 75 | correct = 0. 76 | total = 0. 
77 | 78 | with torch.no_grad(): 79 | for x in val_loader: 80 | if len(x) == 2: 81 | data, labels = x 82 | elif len(x) == 3: 83 | data, labels, weight = x 84 | weight = weight.to(device) 85 | data = data.to(device) 86 | labels = labels.to(device) 87 | 88 | if vae: 89 | recon, mu, log_var = model(data) 90 | test_loss += loss_function(recon, data, mu, log_var).item() 91 | else: 92 | output = model(data) 93 | if len(x) == 2: 94 | criterion = nn.CrossEntropyLoss() 95 | test_loss += criterion(output, labels).item() 96 | elif len(x) == 3: 97 | criterion = nn.CrossEntropyLoss(reduction='none') 98 | loss = criterion(output, labels) 99 | test_loss += (loss * weight).mean().item() 100 | 101 | predicted = output.argmax(dim=1) 102 | total += labels.size(0) 103 | correct += (predicted == labels).sum() 104 | 105 | test_loss /= len(val_loader.dataset) 106 | val_accuracy = 100 * correct / total 107 | val_accuracy = val_accuracy.item() 108 | if verbose: 109 | print('====> Test loss: {:.8f}'.format(test_loss)) 110 | if not vae: 111 | print('====> Test Accuracy %.4f' % (val_accuracy)) 112 | 113 | return val_accuracy 114 | 115 | 116 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Encoder(nn.Module): 4 | def __init__(self, backbone): 5 | super(Encoder, self).__init__() 6 | 7 | # self.backbone = backbone 8 | self.backbone = nn.Sequential( 9 | backbone, 10 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), 11 | nn.Flatten() 12 | ) 13 | 14 | def forward(self, x): 15 | x = self.backbone(x) 16 | 17 | return x 18 | 19 | 20 | class Classifier(nn.Module): 21 | def __init__(self, backbone, hdim=512, n_class=10, reg=True): 22 | super(Classifier, self).__init__() 23 | 24 | # self.backbone = backbone 25 | self.backbone = nn.Sequential( 26 | backbone, 27 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), 28 | nn.Flatten() 29 | ) 30 | self.predict = 
nn.Sequential( 31 | nn.Linear(backbone.out_features, hdim), 32 | nn.BatchNorm1d(hdim), 33 | nn.ReLU(), 34 | nn.Linear(hdim, n_class) 35 | ) 36 | 37 | def forward(self, x): 38 | x = self.backbone(x) 39 | x = self.predict(x) 40 | 41 | return x 42 | 43 | def get_parameters(self, base_lr=1.0): 44 | """A parameter list which decides optimization hyper-parameters, 45 | such as the relative learning rate of each layer 46 | """ 47 | params = [ 48 | {"params": self.backbone.parameters(), "lr": 0.1 * base_lr}, 49 | {"params": self.predict.parameters(), "lr": 1.0 * base_lr} 50 | ] 51 | 52 | return params 53 | 54 | 55 | class ENCODER(nn.Module): 56 | def __init__(self, rgb=False, resnet=False): 57 | super(ENCODER, self).__init__() 58 | 59 | if rgb: 60 | self.encode = nn.Sequential( 61 | nn.Conv2d(3, 32, 3, padding="same"), 62 | nn.ReLU(), 63 | nn.Conv2d(32, 32, 3, padding="same"), 64 | nn.ReLU(), 65 | nn.Conv2d(32, 32, 3, padding="same"), 66 | nn.ReLU(), 67 | ) 68 | else: 69 | self.encode = nn.Sequential( 70 | nn.Conv2d(1, 32, 3, padding="same"), 71 | nn.ReLU(), 72 | nn.Conv2d(32, 32, 3, padding="same"), 73 | nn.ReLU(), 74 | nn.Conv2d(32, 32, 3, padding="same"), 75 | nn.ReLU(), 76 | ) 77 | 78 | 79 | def forward(self, x): 80 | x = self.encode(x) 81 | return x 82 | 83 | 84 | class MLP(nn.Module): 85 | def __init__(self, mode, n_class, hidden=1024): 86 | super(MLP, self).__init__() 87 | 88 | if mode == "mnist": 89 | dim = 25088 90 | elif mode == "portraits": 91 | dim = 32768 92 | else: 93 | dim = 2048 94 | 95 | if mode == "covtype": 96 | hidden = 256 97 | self.mlp = nn.Sequential( 98 | nn.Linear(54, hidden), 99 | nn.ReLU(), 100 | nn.Linear(hidden, hidden), 101 | nn.ReLU(), 102 | nn.Linear(hidden, hidden), 103 | nn.ReLU(), 104 | nn.Dropout(0.5), 105 | nn.BatchNorm1d(hidden), 106 | nn.Linear(hidden, n_class) 107 | ) 108 | else: 109 | hidden = 128 110 | self.mlp = nn.Sequential( 111 | # nn.BatchNorm2d(32), 112 | nn.Flatten(), 113 | # nn.Linear(dim, n_class), 114 | nn.Linear(dim, 
hidden), 115 | nn.ReLU(), 116 | nn.Linear(hidden, hidden), 117 | nn.ReLU(), 118 | nn.Dropout(0.5), 119 | nn.BatchNorm1d(hidden), 120 | nn.Linear(hidden, n_class) 121 | ) 122 | 123 | def forward(self, x): 124 | return self.mlp(x) 125 | 126 | 127 | class Classifier(nn.Module): 128 | def __init__(self, encoder, mlp): 129 | super(Classifier, self).__init__() 130 | 131 | self.encoder = encoder 132 | self.mlp = mlp 133 | 134 | def forward(self, x): 135 | x = self.encoder(x) 136 | return self.mlp(x) 137 | 138 | 139 | class MLP_Encoder(nn.Module): 140 | def __init__(self, hidden=256): 141 | super(MLP_Encoder, self).__init__() 142 | 143 | self.encode = nn.Sequential( 144 | ) 145 | 146 | def forward(self, x): 147 | return self.encode(x) -------------------------------------------------------------------------------- /ot_util.py: -------------------------------------------------------------------------------- 1 | import ot 2 | import torch 3 | from util import * 4 | import numpy as np 5 | import time 6 | import torch.nn as nn 7 | 8 | 9 | def get_transported_labels(plan, ys, logit=False): 10 | # plan /= np.sum(plan, 0, keepdims=True) 11 | ysTemp = ot.utils.label_normalization(np.copy(ys)) 12 | classes = np.unique(ysTemp) 13 | n = len(classes) 14 | D1 = np.zeros((n, len(ysTemp))) 15 | 16 | # perform label propagation 17 | transp = plan 18 | 19 | # set nans to 0 20 | transp[~ np.isfinite(transp)] = 0 21 | 22 | for c in classes: 23 | D1[int(c), ysTemp == c] = 1 24 | 25 | # compute propagated labels 26 | transp_ys = np.dot(D1, transp).T 27 | 28 | if logit: 29 | return transp_ys 30 | 31 | transp_ys = np.argmax(transp_ys, axis=1) 32 | 33 | return transp_ys 34 | 35 | 36 | def get_conf_idx(logits, confidence_q=0.2): 37 | confidence = np.amax(logits, axis=1) - np.amin(logits, axis=1) 38 | alpha = np.quantile(confidence, confidence_q) 39 | indices = np.argwhere(confidence >= alpha)[:, 0] 40 | labels = np.argmax(logits, axis=1) 41 | 42 | return labels, indices 43 | 44 | 45 | def 
get_OT_plan(X_S, X_T, solver='sinkhorn', weights_S=None, weights_T=None, Y_S=None, numItermax=1e7, 46 | entropy_coef=1, entry_cutoff=0): 47 | 48 | # X_S, X_T = X_S[:50000], X_T[:50000] 49 | X_S, X_T = X_S, X_T 50 | n, m = len(X_S), len(X_T) 51 | a = np.ones(n) / n if weights_S is None else weights_S 52 | b = np.ones(m) / m if weights_T is None else weights_T 53 | print(f'{n} source data, {m} target data. ') 54 | dist_mat = ot.dist(X_S, X_T).detach().numpy() 55 | t = time.time() 56 | if solver == 'emd': 57 | plan = ot.emd(a, b, dist_mat, numItermax=int(numItermax)) 58 | elif solver == 'sinkhorn': 59 | plan = ot.sinkhorn(a, b, dist_mat, reg=entropy_coef, numItermax=int(numItermax), stopThr=10e-7) 60 | elif solver == 'lpl1': 61 | plan = ot.sinkhorn_lpl1_mm(a, b, Y_S, dist_mat, reg=entropy_coef, numItermax=int(numItermax), stopInnerThr=10e-9) 62 | 63 | if entry_cutoff > 0: 64 | avg_val = 1 / (n * m) 65 | print(f'Zero out entries with value < {entry_cutoff}*{avg_val}') 66 | plan[plan < avg_val * entry_cutoff] = 0 67 | 68 | elapsed = round(time.time() - t, 2) 69 | print(f"Time for OT calculation: {elapsed}s") 70 | # plan /= np.sum(plan, 0, keepdims=True) 71 | # plan[~ np.isfinite(plan)] = 0 72 | plan = plan * n 73 | 74 | return plan 75 | 76 | 77 | def pushforward(X_S, X_T, plan, t): 78 | print(f'Pushforward to t={t}') 79 | assert 0 <= t <= 1 80 | nonzero_indices = np.argwhere(plan > 0) 81 | weights = plan[plan > 0] 82 | assert len(nonzero_indices) == len(weights) 83 | x_t= (1-t)*X_S[nonzero_indices[:,0]] + t*X_T[nonzero_indices[:,1]] 84 | 85 | return x_t, weights 86 | 87 | 88 | def generate_domains(n_inter, dataset_s, dataset_t, plan=None, entry_cutoff=0, conf=0): 89 | print("------------Generate Intermediate domains----------") 90 | all_domains = [] 91 | 92 | xs, xt = dataset_s.data, dataset_t.data 93 | ys = dataset_s.targets 94 | 95 | if plan is None: 96 | if len(xs.shape) > 2: 97 | xs_flat, xt_flat = nn.Flatten()(xs), nn.Flatten()(xt) 98 | plan = get_OT_plan(xs_flat, 
xt_flat, solver='emd', entry_cutoff=entry_cutoff) 99 | else: 100 | plan = get_OT_plan(xs, xt, solver='emd', entry_cutoff=entry_cutoff) 101 | 102 | logits_t = get_transported_labels(plan, ys, logit=True) 103 | yt_hat, conf_idx = get_conf_idx(logits_t, confidence_q=conf) 104 | xt = xt[conf_idx] 105 | plan = plan[:, conf_idx] 106 | yt_hat = yt_hat[conf_idx] 107 | 108 | print(f"Remaining data after confidence filter: {len(conf_idx)}") 109 | 110 | for i in range(1, n_inter+1): 111 | x, weights = pushforward(xs, xt, plan, i / (n_inter+1)) 112 | if isinstance(x, np.ndarray): 113 | all_domains.append(DomainDataset(torch.from_numpy(x).float(), weights)) 114 | else: 115 | all_domains.append(DomainDataset(x, weights)) 116 | all_domains.append(dataset_t) 117 | 118 | print(f"Total data for each intermediate domain: {len(x)}") 119 | 120 | return all_domains 121 | 122 | 123 | def ot_ablation(size, mode): 124 | ns, nt = size, size 125 | plan = np.zeros((ns, nt)) 126 | ran = np.arange(ns*nt) 127 | np.random.shuffle(ran) 128 | idx = ran[:size] 129 | 130 | for i in idx: 131 | row = i // nt 132 | col = i-i//nt * nt 133 | if mode == "random": 134 | plan[row, col] = np.random.uniform() 135 | elif mode == "uniform": 136 | plan[row, col] = 1 137 | 138 | plan /= np.sum(plan, 1, keepdims=True) 139 | plan[~ np.isfinite(plan)] = 0 140 | 141 | return plan 142 | 143 | 144 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import scipy.io 4 | import numpy as np 5 | import sklearn.preprocessing 6 | from scipy import ndimage 7 | import pandas as pd 8 | from tensorflow.keras.datasets import mnist 9 | 10 | 11 | class DomainDataset(Dataset): 12 | def __init__(self, x, weight, transform=None): 13 | self.data = x.cpu().detach() 14 | self.targets = -1 * torch.ones(len(self.data)) 15 | self.weight = weight 16 | 
self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.data) 20 | 21 | def __getitem__(self, idx): 22 | if self.transform is not None: 23 | return self.transform(self.data[idx]), self.targets[idx], self.weight[idx] 24 | return self.data[idx], self.targets[idx], self.weight[idx] 25 | 26 | 27 | class EncodeDataset(Dataset): 28 | def __init__(self, x, y, transform=None): 29 | self.data = x 30 | self.targets = y 31 | self.transform = transform 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def __getitem__(self, idx): 37 | if self.transform is not None: 38 | return self.transform(self.data[idx]).float(), self.targets[idx] 39 | return self.data[idx], self.targets[idx] 40 | 41 | 42 | """ 43 | Make portraits dataset 44 | """ 45 | def shuffle(xs, ys): 46 | indices = list(range(len(xs))) 47 | np.random.shuffle(indices) 48 | return xs[indices], ys[indices] 49 | 50 | 51 | def split_sizes(array, sizes): 52 | indices = np.cumsum(sizes) 53 | return np.split(array, indices) 54 | 55 | 56 | def load_portraits_data(load_file='dataset_32x32.mat'): 57 | data = scipy.io.loadmat('./' + load_file) 58 | return data['Xs'], data['Ys'][0] 59 | 60 | def make_portraits_data(n_src_tr, n_src_val, n_inter, n_target_unsup, n_trg_val, n_trg_tst, 61 | load_file='dataset_32x32.mat'): 62 | xs, ys = load_portraits_data(load_file) 63 | src_end = n_src_tr + n_src_val 64 | inter_end = src_end + n_inter 65 | trg_end = inter_end + n_trg_val + n_trg_tst 66 | src_x, src_y = shuffle(xs[:src_end], ys[:src_end]) 67 | trg_x, trg_y = shuffle(xs[inter_end:trg_end], ys[inter_end:trg_end]) 68 | [src_tr_x, src_val_x] = split_sizes(src_x, [n_src_tr]) 69 | [src_tr_y, src_val_y] = split_sizes(src_y, [n_src_tr]) 70 | [trg_val_x, trg_test_x] = split_sizes(trg_x, [n_trg_val]) 71 | [trg_val_y, trg_test_y] = split_sizes(trg_y, [n_trg_val]) 72 | inter_x, inter_y = xs[src_end:inter_end], ys[src_end:inter_end] 73 | dir_inter_x, dir_inter_y = inter_x[-n_target_unsup:], 
inter_y[-n_target_unsup:] 74 | return (src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y, 75 | dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x, trg_test_y) 76 | 77 | 78 | """ 79 | make covertype dataset 80 | """ 81 | def make_data(n_src_tr, n_src_val, n_inter, n_target_unsup, n_trg_val, n_trg_tst, xs, ys): 82 | src_end = n_src_tr + n_src_val 83 | inter_end = src_end + n_inter 84 | trg_end = inter_end + n_trg_val + n_trg_tst 85 | src_x, src_y = shuffle(xs[:src_end], ys[:src_end]) 86 | trg_x, trg_y = shuffle(xs[inter_end:trg_end], ys[inter_end:trg_end]) 87 | [src_tr_x, src_val_x] = split_sizes(src_x, [n_src_tr]) 88 | [src_tr_y, src_val_y] = split_sizes(src_y, [n_src_tr]) 89 | [trg_val_x, trg_test_x] = split_sizes(trg_x, [n_trg_val]) 90 | [trg_val_y, trg_test_y] = split_sizes(trg_y, [n_trg_val]) 91 | inter_x, inter_y = xs[src_end:inter_end], ys[src_end:inter_end] 92 | dir_inter_x, dir_inter_y = inter_x[-n_target_unsup:], inter_y[-n_target_unsup:] 93 | return (src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y, 94 | dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x, trg_test_y) 95 | 96 | 97 | def load_covtype_data(load_file, normalize=True): 98 | df = pd.read_csv(load_file, header=None) 99 | data = df.to_numpy() 100 | xs = data[:, :54] 101 | if normalize: 102 | xs = (xs - np.mean(xs, axis=0)) / np.std(xs, axis=0) 103 | ys = data[:, 54] - 1 104 | 105 | # Keep the first 2 types of crops, these comprise majority of the dataset. 106 | keep = (ys <= 1) 107 | print(len(xs)) 108 | xs = xs[keep] 109 | ys = ys[keep] 110 | print(len(xs)) 111 | 112 | # Sort by (horizontal) distance to water body. 
113 | dist_to_water = xs[:, 3] 114 | indices = np.argsort(dist_to_water, axis=0) 115 | xs = xs[indices] 116 | ys = ys[indices] 117 | return xs, ys 118 | 119 | def make_cov_data(n_src_tr, n_src_val, n_inter, n_target_unsup, n_trg_val, n_trg_tst, 120 | load_file="covtype.data", normalize=True): 121 | xs, ys = load_covtype_data(load_file) 122 | return make_data(n_src_tr, n_src_val, n_inter, n_target_unsup, n_trg_val, n_trg_tst, xs, ys) 123 | 124 | 125 | def cov_data_func(): 126 | return make_cov_data(40000, 10000, 400000, 50000, 25000, 20000) 127 | 128 | def cov_data_small_func(): 129 | return make_cov_data(10000, 40000, 400000, 50000, 25000, 20000) 130 | 131 | def cov_data_func_no_normalize(): 132 | return make_cov_data(40000, 10000, 400000, 50000, 25000, 20000, normalize=False) 133 | 134 | """ 135 | Make Color-shift MNIST dataset 136 | """ 137 | def shift_color_images(xs, shift): 138 | return xs + shift 139 | 140 | def get_preprocessed_mnist(): 141 | (train_x, train_y), (test_x, test_y) = mnist.load_data() 142 | train_x, test_x = train_x / 255.0, test_x / 255.0 143 | train_x, train_y = shuffle(train_x, train_y) 144 | train_x = np.expand_dims(np.array(train_x), axis=-1) 145 | test_x = np.expand_dims(np.array(test_x), axis=-1) 146 | return (train_x, train_y), (test_x, test_y) 147 | 148 | def ColorShiftMNIST(shift=10): 149 | (train_x, train_y), (test_x, test_y) = get_preprocessed_mnist() 150 | src_train_end, src_val_end, inter_end, target_end = 5000, 6000, 48000, 50000 151 | src_tr_x, src_tr_y = train_x[:src_train_end], train_y[:src_train_end] 152 | src_val_x, src_val_y = train_x[src_train_end:src_val_end], train_y[src_train_end:src_val_end] 153 | dir_inter_x, dir_inter_y = train_x[src_val_end:inter_end], train_y[src_val_end:inter_end] 154 | trg_val_x, trg_val_y = train_x[inter_end:target_end], train_y[inter_end:target_end] 155 | trg_test_x, trg_test_y = test_x, test_y 156 | trg_val_x, trg_test_x = shift_color_images(trg_val_x, shift), shift_color_images(trg_test_x, 
shift) 157 | return (src_tr_x, src_tr_y, src_val_x, src_val_y, dir_inter_x, dir_inter_y, 158 | dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x, trg_test_y) 159 | 160 | 161 | def transform_inter_data(dir_inter_x, dir_inter_y, source_scale, target_scale, transform_func=shift_color_images, interval=2000, n_domains=20, n_classes=10, class_balanced=False, reverse_point=None): 162 | all_domain_x = [] 163 | all_domain_y = [] 164 | path_length = target_scale - source_scale 165 | if reverse_point is not None: 166 | assert reverse_point >= source_scale and reverse_point <= target_scale 167 | path_length += reverse_point * 2 168 | scales = source_scale + np.flip(np.linspace(path_length,0,n_domains)) 169 | for domain_idx in range(n_domains): 170 | domain_scale = source_scale + path_length / n_domains * (domain_idx + 1) 171 | if class_balanced: 172 | domain_data_idxes = [] 173 | n_domain_class_data = int(interval / n_classes) 174 | for label in range(n_classes): 175 | class_idxes = np.where(dir_inter_y == label)[0] 176 | domain_data_idxes.append(np.random.choice(class_idxes, n_domain_class_data, replace=False)) 177 | domain_data_idxes = np.concatenate(domain_data_idxes, axis=0) 178 | else: 179 | domain_data_idxes = np.random.choice(dir_inter_x.shape[0], interval, replace=False) 180 | domain_x = dir_inter_x[domain_data_idxes] 181 | domain_y = dir_inter_y[domain_data_idxes] 182 | domain_x = transform_func(domain_x, domain_scale) 183 | all_domain_x.append(domain_x) 184 | all_domain_y.append(domain_y) 185 | all_domain_x = np.concatenate(all_domain_x, axis=0) 186 | all_domain_y = np.concatenate(all_domain_y, axis=0) 187 | return all_domain_x, all_domain_y 188 | -------------------------------------------------------------------------------- /experiments.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import * 3 | import torch.optim as optim 4 | from train_model import * 5 | from util import * 6 | from 
ot_util import ot_ablation 7 | from da_algo import * 8 | from ot_util import generate_domains 9 | from dataset import * 10 | import copy 11 | import argparse 12 | import random 13 | import torch.backends.cudnn as cudnn 14 | import time 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | def get_source_model(args, trainset, testset, n_class, mode, encoder=None, epochs=50, verbose=True): 20 | 21 | print("Start training source model") 22 | model = Classifier(encoder, MLP(mode=mode, n_class=n_class, hidden=1024)).to(device) 23 | 24 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 25 | trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 26 | testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 27 | 28 | for epoch in range(1, epochs+1): 29 | train(epoch, trainloader, model, optimizer, verbose=verbose) 30 | if epoch % 5 == 0: 31 | test(testloader, model, verbose=verbose) 32 | 33 | return model 34 | 35 | 36 | def run_goat(model_copy, source_model, src_trainset, tgt_trainset, all_sets, generated_domains, epochs=10): 37 | 38 | # get the performance of direct adaptation from the source to target, st involves self-training on target 39 | direct_acc, st_acc = self_train(args, model_copy, [tgt_trainset], epochs=epochs) 40 | # get the performance of GST from the source to target, st involves self-training on target 41 | direct_acc_all, st_acc_all = self_train(args, source_model, all_sets, epochs=epochs) 42 | 43 | # encode the source and target domains 44 | e_src_trainset, e_tgt_trainset = get_encoded_dataset(source_model.encoder, src_trainset), get_encoded_dataset(source_model.encoder, tgt_trainset) 45 | 46 | # encode the intermediate ground-truth domains 47 | intersets = all_sets[:-1] 48 | encoded_intersets = [e_src_trainset] 49 | for i in intersets: 50 | 
encoded_intersets.append(get_encoded_dataset(source_model.encoder, i)) 51 | encoded_intersets.append(e_tgt_trainset) 52 | 53 | # generate intermediate domains 54 | generated_acc = 0 55 | if generated_domains > 0: 56 | all_domains = [] 57 | for i in range(len(encoded_intersets)-1): 58 | all_domains += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1]) 59 | 60 | _, generated_acc = self_train(args, source_model.mlp, all_domains, epochs=epochs) 61 | 62 | return direct_acc, st_acc, direct_acc_all, st_acc_all, generated_acc 63 | 64 | 65 | def run_mnist_experiment(target, gt_domains, generated_domains): 66 | 67 | t = time.time() 68 | 69 | src_trainset, tgt_trainset = get_single_rotate(False, 0), get_single_rotate(False, target) 70 | 71 | encoder = ENCODER().to(device) 72 | source_model = get_source_model(args, src_trainset, src_trainset, 10, "mnist", encoder=encoder, epochs=5) 73 | model_copy = copy.deepcopy(source_model) 74 | 75 | all_sets = [] 76 | for i in range(1, gt_domains+1): 77 | all_sets.append(get_single_rotate(False, i*target//(gt_domains+1))) 78 | print(i*target//(gt_domains+1)) 79 | all_sets.append(tgt_trainset) 80 | 81 | direct_acc, st_acc, direct_acc_all, st_acc_all, generated_acc = run_goat(model_copy, source_model, src_trainset, tgt_trainset, all_sets, generated_domains, epochs=5) 82 | 83 | elapsed = round(time.time() - t, 2) 84 | print(elapsed) 85 | with open(f"logs/mnist_{target}_{gt_domains}_layer.txt", "a") as f: 86 | f.write(f"seed{args.seed}with{gt_domains}gt{generated_domains}generated,{round(direct_acc, 2)},{round(st_acc, 2)},{round(direct_acc_all, 2)},{round(st_acc_all, 2)},{round(generated_acc, 2)}\n") 87 | 88 | 89 | def run_mnist_ablation(target, gt_domains, generated_domains): 90 | 91 | encoder = ENCODER().to(device) 92 | src_trainset, tgt_trainset = get_single_rotate(False, 0), get_single_rotate(False, target) 93 | source_model = get_source_model(args, src_trainset, src_trainset, 10, "mnist", encoder=encoder, 
epochs=20) 94 | model_copy = copy.deepcopy(source_model) 95 | 96 | all_sets = [] 97 | for i in range(1, gt_domains+1): 98 | all_sets.append(get_single_rotate(False, i*target//(gt_domains+1))) 99 | print(i*target//(gt_domains+1)) 100 | all_sets.append(tgt_trainset) 101 | 102 | direct_acc, st_acc = self_train(args, model_copy, [tgt_trainset], epochs=10) 103 | direct_acc_all, st_acc_all = self_train(args, source_model, all_sets, epochs=10) 104 | model_copy1 = copy.deepcopy(source_model) 105 | model_copy2 = copy.deepcopy(source_model) 106 | model_copy3 = copy.deepcopy(source_model) 107 | model_copy4 = copy.deepcopy(source_model) 108 | 109 | e_src_trainset, e_tgt_trainset = get_encoded_dataset(source_model.encoder, src_trainset), get_encoded_dataset(source_model.encoder, tgt_trainset) 110 | intersets = all_sets[:-1] 111 | encoded_intersets = [e_src_trainset] 112 | for i in intersets: 113 | encoded_intersets.append(get_encoded_dataset(source_model.encoder, i)) 114 | encoded_intersets.append(e_tgt_trainset) 115 | 116 | # random plan 117 | all_domains1 = [] 118 | for i in range(len(encoded_intersets)-1): 119 | plan = ot_ablation(len(src_trainset), "random") 120 | all_domains1 += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1], plan=plan) 121 | _, generated_acc1 = self_train(args, model_copy1.mlp, all_domains1, epochs=10) 122 | 123 | # uniform plan 124 | all_domains4 = [] 125 | for i in range(len(encoded_intersets)-1): 126 | plan = ot_ablation(len(src_trainset), "uniform") 127 | all_domains4 += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1], plan=plan) 128 | _, generated_acc4 = self_train(args, model_copy4.mlp, all_domains4, epochs=10) 129 | 130 | # OT plan 131 | all_domains2 = [] 132 | for i in range(len(encoded_intersets)-1): 133 | all_domains2 += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1]) 134 | _, generated_acc2 = self_train(args, model_copy2.mlp, all_domains2, 
epochs=10) 135 | 136 | # ground-truth plan 137 | all_domains3 = [] 138 | for i in range(len(encoded_intersets)-1): 139 | plan = np.identity(len(src_trainset)) 140 | all_domains3 += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1]) 141 | _, generated_acc3 = self_train(args, model_copy3.mlp, all_domains3, epochs=10) 142 | 143 | with open(f"logs/mnist_{target}_{generated_domains}_ablation.txt", "a") as f: 144 | f.write(f"seed{args.seed}generated{generated_domains},{round(direct_acc, 2)},{round(st_acc, 2)},{round(st_acc_all, 2)},{round(generated_acc1, 2)},{round(generated_acc4.item(), 2)},{round(generated_acc2, 2)},{round(generated_acc3, 2)}\n") 145 | 146 | 147 | def run_portraits_experiment(gt_domains, generated_domains): 148 | t = time.time() 149 | 150 | (src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y, dir_inter_x, dir_inter_y, 151 | trg_val_x, trg_val_y, trg_test_x, trg_test_y) = make_portraits_data(1000, 1000, 14000, 2000, 1000, 1000) 152 | tr_x, tr_y = np.concatenate([src_tr_x, src_val_x]), np.concatenate([src_tr_y, src_val_y]) 153 | ts_x, ts_y = np.concatenate([trg_val_x, trg_test_x]), np.concatenate([trg_val_y, trg_test_y]) 154 | 155 | encoder = ENCODER().to(device) 156 | transforms = ToTensor() 157 | 158 | src_trainset = EncodeDataset(tr_x, tr_y.astype(int), transforms) 159 | tgt_trainset = EncodeDataset(ts_x, ts_y.astype(int), transforms) 160 | source_model = get_source_model(args, src_trainset, src_trainset, 2, mode="portraits", encoder=encoder, epochs=20) 161 | model_copy = copy.deepcopy(source_model) 162 | 163 | def get_domains(n_domains): 164 | domain_set = [] 165 | n2idx = {0:[], 1:[3], 2:[2,4], 3:[1,3,5], 4:[0,2,4,6], 7:[0,1,2,3,4,5,6]} 166 | domain_idx = n2idx[n_domains] 167 | for i in domain_idx: 168 | start, end = i*2000, (i+1)*2000 169 | domain_set.append(EncodeDataset(inter_x[start:end], inter_y[start:end].astype(int), transforms)) 170 | return domain_set 171 | 172 | all_sets = get_domains(gt_domains) 173 
| all_sets.append(tgt_trainset) 174 | 175 | direct_acc, st_acc, direct_acc_all, st_acc_all, generated_acc = run_goat(model_copy, source_model, src_trainset, tgt_trainset, all_sets, generated_domains, epochs=5) 176 | 177 | elapsed = round(time.time() - t, 2) 178 | with open(f"logs/portraits_exp_time.txt", "a") as f: 179 | f.write(f"seed{args.seed}with{gt_domains}gt{generated_domains}generated,{round(direct_acc, 2)},{round(st_acc, 2)},{round(direct_acc_all, 2)},{round(st_acc_all, 2)},{round(generated_acc, 2)}\n") 180 | 181 | 182 | def run_covtype_experiment(gt_domains, generated_domains): 183 | data = make_cov_data(40000, 10000, 400000, 50000, 25000, 20000) 184 | (src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y, dir_inter_x, dir_inter_y, 185 | trg_val_x, trg_val_y, trg_test_x, trg_test_y) = data 186 | 187 | src_trainset = EncodeDataset(torch.from_numpy(src_val_x).float(), src_val_y.astype(int)) 188 | tgt_trainset = EncodeDataset(torch.from_numpy(trg_test_x).float(), torch.tensor(trg_test_y.astype(int))) 189 | 190 | encoder = MLP_Encoder().to(device) 191 | source_model = get_source_model(args, src_trainset, src_trainset, 2, mode="covtype", encoder=encoder, epochs=5) 192 | model_copy = copy.deepcopy(source_model) 193 | 194 | def get_domains(n_domains): 195 | domain_set = [] 196 | n2idx = {0:[], 1:[6], 2:[3,7], 3:[2,5,8], 4:[2,4,6,8], 5:[1,3,5,7,9], 10: range(10), 200: range(200)} 197 | domain_idx = n2idx[n_domains] 198 | # domain_idx = range(n_domains) 199 | for i in domain_idx: 200 | # start, end = i*2000, (i+1)*2000 201 | # start, end = i*10000, (i+1)*10000 202 | start, end = i*40000, i*40000 + 2000 203 | domain_set.append(EncodeDataset(torch.from_numpy(inter_x[start:end]).float(), inter_y[start:end].astype(int))) 204 | return domain_set 205 | 206 | all_sets = get_domains(gt_domains) 207 | all_sets.append(tgt_trainset) 208 | 209 | direct_acc, st_acc, direct_acc_all, st_acc_all, generated_acc = run_goat(model_copy, source_model, src_trainset, tgt_trainset, 
all_sets, generated_domains, epochs=5) 210 | 211 | with open(f"logs/covtype_exp_{args.log_file}.txt", "a") as f: 212 | f.write(f"seed{args.seed}with{gt_domains}gt{generated_domains}generated,{round(direct_acc, 2)},{round(st_acc, 2)},{round(st_acc_all, 2)},{round(generated_acc, 2)}\n") 213 | 214 | 215 | def run_color_mnist_experiment(gt_domains, generated_domains): 216 | shift = 1 217 | total_domains = 20 218 | 219 | src_tr_x, src_tr_y, src_val_x, src_val_y, dir_inter_x, dir_inter_y, dir_inter_x, dir_inter_y, trg_val_x, trg_val_y, trg_test_x, trg_test_y = ColorShiftMNIST(shift=shift) 220 | inter_x, inter_y = transform_inter_data(dir_inter_x, dir_inter_y, 0, shift, interval=len(dir_inter_x)//total_domains, n_domains=total_domains) 221 | 222 | src_x, src_y = np.concatenate([src_tr_x, src_val_x]), np.concatenate([src_tr_y, src_val_y]) 223 | tgt_x, tgt_y = np.concatenate([trg_val_x, trg_test_x]), np.concatenate([trg_val_y, trg_test_y]) 224 | src_trainset, tgt_trainset = EncodeDataset(src_x, src_y.astype(int), ToTensor()), EncodeDataset(trg_val_x, trg_val_y.astype(int), ToTensor()) 225 | 226 | encoder = ENCODER().to(device) 227 | source_model = get_source_model(args, src_trainset, src_trainset, 10, "mnist", encoder=encoder, epochs=20) 228 | model_copy = copy.deepcopy(source_model) 229 | 230 | def get_domains(n_domains): 231 | domain_set = [] 232 | 233 | domain_idx = [] 234 | if n_domains == total_domains: 235 | domain_idx = range(n_domains) 236 | else: 237 | for i in range(1, n_domains+1): 238 | domain_idx.append(total_domains // (n_domains+1) * i) 239 | 240 | interval = 42000 // total_domains 241 | for i in domain_idx: 242 | start, end = i*interval, (i+1)*interval 243 | domain_set.append(EncodeDataset(inter_x[start:end], inter_y[start:end].astype(int), ToTensor())) 244 | return domain_set 245 | 246 | all_sets = get_domains(gt_domains) 247 | all_sets.append(tgt_trainset) 248 | 249 | direct_acc, st_acc, direct_acc_all, st_acc_all, generated_acc = run_goat(model_copy, 
source_model, src_trainset, tgt_trainset, all_sets, generated_domains, epochs=10) 250 | 251 | with open(f"logs/color{args.log_file}.txt", "a") as f: 252 | f.write(f"seed{args.seed}with{gt_domains}gt{generated_domains}generated,{round(direct_acc, 2)},{round(st_acc, 2)},{round(direct_acc_all, 2)},{round(st_acc_all, 2)},{round(generated_acc, 2)}\n") 253 | 254 | 255 | def main(args): 256 | 257 | print(args) 258 | 259 | if args.dataset == "mnist": 260 | if args.mnist_mode == "normal": 261 | run_mnist_experiment(args.rotation_angle, args.gt_domains, args.generated_domains) 262 | else: 263 | run_mnist_ablation(args.rotation_angle, args.gt_domains, args.generated_domains) 264 | else: 265 | eval(f"run_{args.dataset}_experiment({args.gt_domains}, {args.generated_domains})") 266 | 267 | 268 | if __name__ == '__main__': 269 | 270 | parser = argparse.ArgumentParser(description="GOAT experiments") 271 | parser.add_argument("--dataset", choices=["mnist", "portraits", "covtype", "color_mnist"]) 272 | parser.add_argument("--gt-domains", default=0, type=int) 273 | parser.add_argument("--generated-domains", default=0, type=int) 274 | parser.add_argument("--seed", default=0, type=int) 275 | parser.add_argument("--mnist-mode", default="normal", choices=["normal", "ablation"]) 276 | parser.add_argument("--rotation-angle", default=45, type=int) 277 | parser.add_argument("--batch-size", default=128, type=int) 278 | parser.add_argument("--lr", default=1e-4, type=float) 279 | parser.add_argument("--num-workers", default=2, type=int) 280 | parser.add_argument("--log-file", default="") 281 | args = parser.parse_args() 282 | 283 | main(args) --------------------------------------------------------------------------------