├── README.md ├── fig ├── ICLR2023.png └── toy.png ├── requirements.txt └── toy ├── configs ├── dg15.py └── dg60.py ├── data ├── toy_d15_spiral_tight_boundary.pkl └── toy_d60_spiral.pkl ├── data_loader ├── data_loader.py └── utils.py ├── learn_graph.py ├── main.py ├── model ├── model.py └── modules.py ├── run_tro.sh └── utils ├── centrality.py ├── graph_utils.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Topology-aware Robust Optimization for Out-of-Distribution Generalization (TRO) 2 | 3 | This repository holds the PyTorch implementation of [Topology-aware Robust Optimization for Out-of-Distribution Generalization](https://openreview.net/forum?id=ylMq8MBnAp) by Fengchun Qiao and Xi Peng. 4 | If you find our code useful in your research, please consider citing: 5 | 6 | ``` 7 | @inproceedings{qiao2023tro, 8 | title={Topology-aware Robust Optimization for Out-of-Distribution Generalization}, 9 | author={Fengchun Qiao and Xi Peng}, 10 | booktitle={International Conference on Learning Representations (ICLR)}, 11 | year={2023} 12 | } 13 | ``` 14 | 15 | 16 | ## Introduction 17 | 18 | We study the problem of out-of-distribution (OOD) generalization. 19 | As generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial in developing strong OOD resilience. 20 | To this end, we propose topology-aware robust optimization (TRO) that seamlessly integrates distributional topology in a principled optimization framework. 21 | 22 | 23 | ## Quick start 24 | This repository reproduces our results on DG-15/60, TPT-48, and DomainBed. It is built upon Python 3 and PyTorch v1.9.0 on Ubuntu 18.04. 25 | Please install all required packages by running: 26 | 27 | ``` 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Results on DG-15/60 32 | 33 | Illustration of data groups in (a) DG-15 and (b) DG-60 datasets: 34 | 35 | ![toy](fig/toy.png) 36 | 37 | To reproduce the results of TRO with physical topology on DG-15, please run: 38 | 39 | ``` 40 | python main.py --dataset toy_d15 --learn 0 --model TRO 41 | ``` 42 | 43 | To reproduce the results of TRO with data-driven topology on DG-15, please run: 44 | 45 | ``` 46 | python main.py --dataset toy_d15 --learn 1 --model TRO 47 | ``` 48 | 49 | To reproduce the results on DG-60, please change `--dataset` to `toy_d60`. 50 | You can also reproduce the results of other baselines by changing `--model`. 51 | 52 | ## TODO 53 | 54 | - [ ] TPT-48 55 | - [ ] DomainBed 56 | 57 | 58 | 59 | 60 | 61 | ## Acknowledgement 62 | 63 | Part of our code is borrowed from the following repositories. 64 | 65 | - [GRDA](https://github.com/Wang-ML-Lab/GRDA) 66 | - [DiffusionEMD](https://github.com/KrishnaswamyLab/DiffusionEMD) 67 | - [DomainBed](https://github.com/facebookresearch/DomainBed) 68 | 69 | We thank the authors for releasing their code. Please also consider citing their work. 
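## How the topology enters the optimization (sketch)

At each training step, TRO keeps a weight `q_m` for every source group, increases it when that group's loss is high, and regularizes it toward a prior derived from the domain graph (the normalized betweenness centrality computed in `toy/utils/centrality.py`); the weights are then projected back onto the probability simplex and used to mix the group losses. The snippet below is a minimal NumPy sketch of this update, mirroring the `TRO` class in `toy/model/model.py`; the prior, losses, and hyper-parameters are illustrative values only, not numbers from the paper or the configs.

```python
import numpy as np

def projection_simplex(v, z=1):
    # Euclidean projection onto {w : w >= 0, sum(w) = z}, as in toy/utils/utils.py
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(v.shape[0]) + 1
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)

# Illustrative values for three source groups
prior = np.array([0.1, 0.2, 0.7])    # topology prior, e.g. normalized centrality
losses = np.array([0.9, 0.4, 0.6])   # per-group losses at the current step
eta, lmbda = 0.1, 1.0                # step size (groupdro_eta) and regularizer (lmbda)

q = np.ones_like(prior) / prior.size          # start from uniform group weights
q = q + eta * (losses - lmbda * (q - prior))  # loss pushes q up, prior pulls it back
q = projection_simplex(q)                     # keep q on the probability simplex

weighted_loss = float(np.dot(losses, q))      # the loss that is backpropagated
print(q, weighted_loss)
```

Groups with persistently high loss receive more weight, but the `lmbda` term keeps the weights anchored to topologically central groups instead of letting the single worst group dominate as in plain GroupDRO.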
70 | -------------------------------------------------------------------------------- /fig/ICLR2023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-real/TRO/c2fa464c604cdef9fef256d598fef6a7e24e852d/fig/ICLR2023.png -------------------------------------------------------------------------------- /fig/toy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-real/TRO/c2fa464c604cdef9fef256d598fef6a7e24e852d/fig/toy.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | networkx==3.0 3 | matplotlib 4 | visdom==0.2.4 5 | easydict 6 | DiffusionEMD 7 | torch==1.9.0 8 | torchvision==0.10.0 -------------------------------------------------------------------------------- /toy/configs/dg15.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | opt = EasyDict() 3 | 4 | # Report acc on all domains 5 | opt.test_on_all_dmn = True 6 | 7 | opt.use_visdom = False 8 | opt.visdom_port = 2000 9 | 10 | opt.device = "cuda" 11 | opt.seed = 233 12 | 13 | # Learning 14 | opt.lr_e = 3e-5 15 | opt.groupdro_eta = 1. # DRO's eta hyper-parameter 16 | opt.lmbda = 100. # regularizer 17 | opt.beta1 = 0.9 18 | opt.no_bn = True 19 | opt.threshold = 80 20 | opt.num_epoch = 20 21 | 22 | # model size configs, used for E, F 23 | opt.nx = 2 # dimension of the input data 24 | opt.nh = 512 # dimension of hidden # 512 25 | opt.nc = 2 # number of label class 26 | 27 | opt.test_interval = 20 28 | opt.save_interval = 100 -------------------------------------------------------------------------------- /toy/configs/dg60.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | opt = EasyDict() 3 | 4 | # Report acc on all domains 5 | opt.test_on_all_dmn = True 6 | 7 | opt.use_visdom = False 8 | opt.visdom_port = 2000 9 | 10 | opt.device = "cuda" 11 | opt.seed = 233 12 | 13 | # Learning 14 | opt.lr_e = 3e-5 15 | opt.groupdro_eta = 0.1 # DRO's eta hyper-parameter 16 | opt.lmbda = 0.001 # regularizer 17 | opt.beta1 = 0.9 18 | opt.no_bn = True 19 | opt.threshold = 40 20 | opt.num_epoch = 300 21 | 22 | # model size configs, used for E, F 23 | opt.nx = 2 # dimension of the input data 24 | opt.nh = 512 # dimension of hidden # 512 25 | opt.nc = 2 # number of label class 26 | 27 | opt.test_interval = 20 28 | opt.save_interval = 100 -------------------------------------------------------------------------------- /toy/data/toy_d15_spiral_tight_boundary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-real/TRO/c2fa464c604cdef9fef256d598fef6a7e24e852d/toy/data/toy_d15_spiral_tight_boundary.pkl -------------------------------------------------------------------------------- /toy/data/toy_d60_spiral.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-real/TRO/c2fa464c604cdef9fef256d598fef6a7e24e852d/toy/data/toy_d60_spiral.pkl -------------------------------------------------------------------------------- /toy/data_loader/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from torch.utils.data 
import Dataset 4 | 5 | def read_pickle(name): 6 | with open(name, 'rb') as f: 7 | data = pickle.load(f) 8 | return data 9 | 10 | # Toy Dataset 11 | class ToyDataset(Dataset): 12 | def __init__(self, pkl, domain_id): 13 | idx = pkl["domain"] == domain_id 14 | self.data = pkl["data"][idx].astype(np.float32) 15 | self.label = pkl["label"][idx].astype(np.int64) 16 | self.domain = domain_id 17 | 18 | def __getitem__(self, idx): 19 | return self.data[idx], self.label[idx], self.domain 20 | 21 | def __len__(self): 22 | return len(self.data) 23 | 24 | class SeqToyDataset(Dataset): 25 | def __init__(self, datasets, size=3*200): 26 | self.datasets = datasets 27 | self.size = size 28 | print( 29 | "SeqDataset Size {} SubDataset Size {}".format( 30 | size, [len(ds) for ds in datasets] 31 | ) 32 | ) 33 | 34 | def __len__(self): 35 | return self.size 36 | 37 | def __getitem__(self, i): 38 | return [ds[i] for ds in self.datasets] -------------------------------------------------------------------------------- /toy/data_loader/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | # import matplotlib.pyplot as plt 4 | 5 | def get_date_list(begin_date,end_date): 6 | date_list = [x.strftime('%Y-%m-%d') for x in list(pd.date_range(start=begin_date, end=end_date))] 7 | return date_list -------------------------------------------------------------------------------- /toy/learn_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import pickle 5 | import argparse 6 | import os 7 | 8 | from torch.utils.data import DataLoader 9 | from data_loader.data_loader import ToyDataset, SeqToyDataset 10 | 11 | import torch.backends.cudnn as cudnn 12 | 13 | from model.model import ERM as Model 14 | 15 | from utils.centrality import get_centrality 16 | from utils.utils import * 17 | from utils.graph_utils import * 18 | 19 | parser = argparse.ArgumentParser(description='Topology-Aware Robust Optimization for OOD Generalization') 20 | parser.add_argument('--dataset', default='toy_d15', type=str, help='toy_d15, toy_d60') 21 | parser.add_argument('--epochs', default=20, type=int, help='number of total epochs to run ERM') 22 | parser.add_argument('--batch_size', default=10, type=int, help='mini-batch size (default: 10)') 23 | parser.add_argument('--lr_e', default=3e-5, type=float, help='initial learning rate') 24 | parser.add_argument('--gpu_id', default=0, type=int, help='gpu id') 25 | 26 | cudnn.benchmark = True 27 | 28 | args = parser.parse_args() 29 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 on stackoverflow 30 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 31 | 32 | if args.dataset == "toy_d15": 33 | from configs.dg15 import opt 34 | elif args.dataset == "toy_d60": 35 | from configs.dg60 import opt 36 | else: 37 | raise NotImplementedError("Dataset not implemented. 
Please try toy_d15 or toy_d60!") 38 | 39 | np.random.seed(opt.seed) 40 | random.seed(opt.seed) 41 | torch.manual_seed(opt.seed) 42 | 43 | opt.model = "ERM" 44 | opt.dataset = args.dataset 45 | opt.outf = None 46 | 47 | print("dataset: {}".format(opt.dataset)) 48 | 49 | # Hyper-params 50 | opt.batch_size = args.batch_size 51 | opt.lr_e = args.lr_e 52 | 53 | # Dataset 54 | if args.dataset == "toy_d15": 55 | opt.num_domain = 15 56 | # the specific source and target domain: 57 | opt.src_domain = [0, 12, 3, 4, 14, 8] 58 | elif args.dataset == "toy_d60": 59 | opt.num_domain = 60 60 | # the specific source and target domain: 61 | opt.src_domain = list(range(6)) 62 | else: 63 | raise NotImplementedError("Dataset not implemented. Please try toy_d15 or toy_d60!") 64 | 65 | opt.src_domain = np.array(opt.src_domain) 66 | opt.num_source = opt.src_domain.shape[0] 67 | opt.num_target = opt.num_domain - opt.num_source 68 | 69 | if args.dataset == "toy_d15": 70 | data_source = os.path.join("data", "toy_d15_spiral_tight_boundary.pkl") 71 | elif args.dataset == "toy_d60": 72 | data_source = os.path.join("data", "toy_d60_spiral.pkl") 73 | with open(data_source, "rb") as data_file: 74 | data_pkl = pickle.load(data_file) 75 | print(f"Data: {data_pkl['data'].shape}\nLabel: {data_pkl['label'].shape}") 76 | 77 | opt.A = data_pkl["A"] 78 | data = data_pkl["data"] 79 | datasets = [ToyDataset(data_pkl, i) for i in range(opt.num_domain)] # sub dataset for each domain 80 | dataset = SeqToyDataset(datasets, size=len(datasets[0])) # mix sub dataset to a large one 81 | dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=opt.batch_size) 82 | 83 | model = Model(opt).to(opt.device) 84 | 85 | # Train ERM model (super quick) 86 | for epoch in range(opt.num_epoch): 87 | model.learn(epoch, dataloader) 88 | 89 | # Get embeddings from ERM model 90 | # DG-15/60 91 | train_x_seq = data_pkl["data"] 92 | train_x_seq_t = to_tensor(train_x_seq) 93 | train_x_seq_t = train_x_seq_t[None, :].to(torch.float) 94 | # Feature 95 | x_seq_feat = to_np(model.test_x_seq_feat(train_x_seq_t)) 96 | feat_size = x_seq_feat.shape[-1] 97 | x_seq_feat = x_seq_feat.reshape((opt.num_domain, -1, feat_size))[opt.src_domain] 98 | print("x_seq_feat", x_seq_feat.shape) 99 | 100 | n_distributions = x_seq_feat.shape[0] 101 | n_points_per_distribution = x_seq_feat.shape[1] 102 | 103 | # Use diffusion EMD to get the matrix 104 | dis_matrix = emd(x_seq_feat, feat_size, n_distributions, n_points_per_distribution) 105 | 106 | print("distance matrix", dis_matrix) 107 | print(dis_matrix.shape) 108 | 109 | # Get percentile 110 | dis_matrix_flat = dis_matrix.reshape(-1) 111 | thres = np.percentile(dis_matrix_flat, opt.threshold) 112 | # Unweighted graph 113 | A = np.zeros_like(dis_matrix) 114 | rule = (dis_matrix < thres) & (dis_matrix != 0) 115 | A[rule] = 1 116 | 117 | centrality = get_centrality(A) 118 | centrality = np.around(centrality, 3) 119 | print(list(centrality)) -------------------------------------------------------------------------------- /toy/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import pickle 5 | import argparse 6 | import os 7 | 8 | from torch.utils.data import DataLoader 9 | import torch.backends.cudnn as cudnn 10 | 11 | from data_loader.data_loader import ToyDataset, SeqToyDataset 12 | from utils.utils import * 13 | 14 | parser = argparse.ArgumentParser(description='Topology-aware Robust Optimization for OOD Generalization') 15 
| parser.add_argument('--dataset', default='toy_d15', type=str, help='toy_d15, toy_d60, weather') 16 | parser.add_argument('--model', default='ERM', type=str, help='ERM, IRM, DRO, TRO') 17 | parser.add_argument('--batch_size', default=10, type=int, help='mini-batch size (default: 10)') 18 | parser.add_argument('--gpu_id', default=0, type=int, help='gpu id') 19 | 20 | # Partial 21 | parser.add_argument('--learn', default=1, type=int, help='data graph (1) or physical (0)') 22 | parser.add_argument('--partial', default=0, type=int, help='source (1) or source + target (0), only for physical graph') 23 | 24 | # IRM 25 | parser.add_argument('--irm_penal', default=1e-1, type=float, help='irm penalty coefficient') 26 | 27 | cudnn.benchmark = True 28 | 29 | args = parser.parse_args() 30 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 on stackoverflow 31 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 32 | 33 | if args.dataset == "toy_d15": 34 | from configs.dg15 import opt 35 | elif args.dataset == "toy_d60": 36 | from configs.dg60 import opt 37 | 38 | # random seed 39 | np.random.seed(opt.seed) 40 | random.seed(opt.seed) 41 | torch.manual_seed(opt.seed) 42 | 43 | opt.model = args.model 44 | opt.dataset = args.dataset 45 | 46 | print("model: {}".format(opt.model)) 47 | print("dataset: {}".format(opt.dataset)) 48 | 49 | if opt.model == "ERM": 50 | from model.model import ERM as Model 51 | elif opt.model == "DRO": 52 | from model.model import DRO as Model 53 | elif opt.model == "IRM": 54 | from model.model import IRM as Model 55 | opt.irm_penal = args.irm_penal 56 | elif opt.model == "TRO": 57 | from model.model import TRO as Model 58 | 59 | # Important params 60 | opt.batch_size = args.batch_size 61 | 62 | opt.learn = args.learn 63 | opt.partial = args.partial 64 | 65 | if args.dataset == "toy_d15": 66 | opt.num_domain = 15 67 | # the specific source and target domain: 68 | opt.src_domain = [0, 12, 3, 4, 14, 8] #Corresponds to 1-6 in Figure 2 (a) 69 | elif args.dataset == "toy_d60": 70 | opt.num_domain = 60 71 | # the specific source and target domain: 72 | opt.src_domain = list(range(6)) 73 | 74 | opt.num_source = len(opt.src_domain) 75 | opt.num_target = opt.num_domain - opt.num_source 76 | 77 | if args.dataset == "toy_d15": 78 | data_source = os.path.join("data", "toy_d15_spiral_tight_boundary.pkl") 79 | elif args.dataset == "toy_d60": 80 | data_source = os.path.join("data", "toy_d60_spiral.pkl") 81 | else: 82 | raise NotImplementedError("Dataset not implemented. 
Please try toy_d15 or toy_d60!") 83 | 84 | with open(data_source, "rb") as data_file: 85 | data_pkl = pickle.load(data_file) 86 | print(f"Data: {data_pkl['data'].shape}\nLabel: {data_pkl['label'].shape}") 87 | 88 | # set up experiment directory 89 | opt.outf = setup_experiment(args, opt) 90 | 91 | # build dataset 92 | opt.A = data_pkl["A"] # physical graph's adjacent matrix 93 | data = data_pkl["data"] 94 | 95 | # dataloader 96 | datasets = [ToyDataset(data_pkl, i) for i in range(opt.num_domain)] # sub dataset for each domain 97 | dataset = SeqToyDataset(datasets, size=len(datasets[0])) # mix sub dataset to a large one 98 | dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=opt.batch_size) 99 | 100 | model = Model(opt).to(opt.device) 101 | 102 | # train 103 | for epoch in range(opt.num_epoch): 104 | model.learn(epoch, dataloader) 105 | 106 | if (epoch + 1) % opt.save_interval == 0 or (epoch + 1) == opt.num_epoch: 107 | model.save(model.model_path) 108 | if (epoch + 1) % opt.test_interval == 0 or (epoch + 1) == opt.num_epoch: 109 | model.test(epoch, dataloader) -------------------------------------------------------------------------------- /toy/model/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | 7 | import numpy as np 8 | 9 | from utils.centrality import get_centrality, get_centrality_all 10 | from utils.utils import * 11 | from model.modules import ( 12 | FeatureNet, 13 | PredNet 14 | ) 15 | 16 | from visdom import Visdom 17 | import torch.autograd as autograd 18 | 19 | # the base model 20 | class BaseModel(nn.Module): 21 | def __init__(self, opt): 22 | super(BaseModel, self).__init__() 23 | # set output format 24 | np.set_printoptions(suppress=True, precision=6) 25 | 26 | self.model_name = opt.model 27 | self.opt = opt 28 | self.device = opt.device 29 | self.batch_size = opt.batch_size 30 | self.A = opt.A 31 | self.dataset = opt.dataset 32 | 33 | self.device = opt.device 34 | if "DRO" in self.model_name or "TRO" in self.model_name: 35 | self.groupdro_eta = opt.groupdro_eta 36 | if "TRO" in self.model_name: 37 | self.lmbda = opt.lmbda 38 | 39 | # visualization 40 | self.use_visdom = opt.use_visdom 41 | if opt.use_visdom: 42 | self.env = Visdom(port=opt.visdom_port) 43 | self.test_pane = dict() 44 | 45 | self.src_domain = opt.src_domain 46 | self.criterion = nn.NLLLoss().cuda() 47 | self.netE = FeatureNet(opt).to(opt.device) # encoder 48 | self.netF = PredNet(opt).to(opt.device) # predictor 49 | 50 | self.__init_weight__() 51 | 52 | EF_parameters = list(self.netE.parameters()) + list(self.netF.parameters()) 53 | 54 | self.optimizer_EF = optim.Adam( 55 | EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999) 56 | ) 57 | 58 | self.lr_scheduler_EF = lr_scheduler.ExponentialLR( 59 | optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100) 60 | ) 61 | 62 | self.lr_schedulers = [self.lr_scheduler_EF] 63 | self.loss_names = ["loss"] 64 | 65 | self.num_domain = opt.num_domain 66 | if self.opt.test_on_all_dmn: 67 | self.test_dmn_num = self.num_domain 68 | else: 69 | self.test_dmn_num = self.opt.tgt_dmn_num 70 | 71 | self.outf = opt.outf 72 | if opt.outf: 73 | self.train_log = os.path.join(opt.outf, "loss.log") 74 | self.model_path = os.path.join(opt.outf, "model.pth") 75 | 76 | if not os.path.exists(self.opt.outf): 77 | os.mkdir(self.opt.outf) 78 | with open(self.train_log, "w") as f: 79 | f.write("log 
start!\n") 80 | 81 | src_domains = [str(i) for i in self.src_domain] 82 | if self.outf: 83 | self.__log_write__("src domains: " + " ".join(src_domains)) 84 | 85 | mask_list = np.zeros(opt.num_domain) 86 | mask_list[opt.src_domain] = 1 87 | 88 | self.domain_mask = torch.IntTensor(mask_list).to(opt.device) 89 | 90 | 91 | def learn(self, epoch, dataloader): 92 | self.train() 93 | 94 | self.epoch = epoch 95 | loss_values = {loss: 0 for loss in self.loss_names} 96 | 97 | count = 0 98 | for data in dataloader: 99 | count += 1 100 | self.__set_input__(data) 101 | self.__train_forward__() 102 | new_loss_values = self.__optimize__() 103 | 104 | # for the loss visualization 105 | for key, loss in new_loss_values.items(): 106 | loss_values[key] += loss 107 | 108 | for key, _ in new_loss_values.items(): 109 | loss_values[key] /= count 110 | 111 | if self.use_visdom: 112 | self.__vis_loss__(loss_values) 113 | 114 | if (self.epoch + 1) % 10 == 0: 115 | print("epoch {}: {}".format(self.epoch, loss_values)) 116 | 117 | # learning rate decay 118 | for lr_scheduler in self.lr_schedulers: 119 | lr_scheduler.step() 120 | 121 | def test(self, epoch, dataloader): 122 | self.eval() # validation mode 123 | 124 | acc_curve = [] 125 | l_x = [] 126 | l_domain = [] 127 | l_label = [] 128 | l_encode = [] 129 | 130 | for data in dataloader: 131 | self.__set_input__(data) 132 | 133 | # forward 134 | with torch.no_grad(): 135 | self.__test_forward__() 136 | 137 | acc_curve.append( 138 | self.g_seq.eq(self.y_seq) 139 | .to(torch.float) 140 | .mean(-1, keepdim=True) 141 | ) 142 | l_x.append(to_np(self.x_seq)) 143 | l_domain.append(to_np(self.domain_seq)) 144 | l_encode.append(to_np(self.e_seq)) 145 | l_label.append(to_np(self.g_seq)) 146 | 147 | x_all = np.concatenate(l_x, axis=1) 148 | e_all = np.concatenate(l_encode, axis=1) 149 | domain_all = np.concatenate(l_domain, axis=1) 150 | label_all = np.concatenate(l_label, axis=1) 151 | 152 | d_all = dict() 153 | 154 | d_all["data"] = flat(x_all) 155 | d_all["domain"] = flat(domain_all) 156 | d_all["label"] = flat(label_all) 157 | d_all["encoding"] = flat(e_all) 158 | 159 | acc = to_np(torch.cat(acc_curve, 1).mean(-1)) 160 | 161 | test_acc = ( 162 | (acc.sum() - acc[self.opt.src_domain].sum()) 163 | / (self.opt.num_target) 164 | * 100 165 | ) 166 | 167 | acc_msg = "[{}] Acc: all avg {:.1f}, test avg {:.2f}".format(epoch, acc.mean() * 100, test_acc) 168 | each_domain_acc = [str(i) for i in np.around(acc * 100, decimals=1)] 169 | each_domain_acc_msg = "[" + ", ".join(each_domain_acc) + "]" 170 | 171 | if self.outf: 172 | self.__log_write__(acc_msg) 173 | self.__log_write__(each_domain_acc_msg) 174 | 175 | if self.use_visdom: 176 | self.__vis_test_error__(test_acc, "test acc") 177 | d_all["acc_msg"] = acc_msg 178 | 179 | def __vis_test_error__(self, loss, title): 180 | if self.epoch == self.opt.test_interval - 1: 181 | # initialize 182 | self.test_pane[title] = self.env.line( 183 | X=np.array([self.epoch]), 184 | Y=np.array([loss]), 185 | opts=dict(title=title), 186 | ) 187 | else: 188 | self.env.line( 189 | X=np.array([self.epoch]), 190 | Y=np.array([loss]), 191 | win=self.test_pane[title], 192 | update="append", 193 | ) 194 | 195 | def save(self, name): 196 | torch.save(self.state_dict(), name) 197 | 198 | def load(self, name): 199 | self.load_state_dict(torch.load(name)) # load the parameters saved by save() 200 | 201 | def __set_input__(self, data, train=True): 202 | """ 203 | :param 204 | x_seq: Number of domain x Batch size x Data dim 205 | y_seq: Number of domain x Batch size x Predict Data dim 206 | one_hot_seq: Number of domain
x Batch size x Number of vertices (domains) 207 | domain_seq: Number of domain x Batch size x domain dim (1) 208 | """ 209 | x_seq, y_seq, domain_seq = ( 210 | [d[0][None, :, :] for d in data], 211 | [d[1][None, :] for d in data], 212 | [d[2][None, :] for d in data], 213 | ) 214 | self.x_seq = torch.cat(x_seq, 0).to(self.device) 215 | self.y_seq = torch.cat(y_seq, 0).to(self.device) 216 | 217 | self.domain_seq = torch.cat(domain_seq, 0).to(self.device) 218 | self.tmp_batch_size = self.x_seq.shape[1] 219 | one_hot_seq = [ 220 | torch.nn.functional.one_hot(d[2], self.num_domain) 221 | for d in data 222 | ] 223 | 224 | if train: 225 | self.one_hot_seq = ( 226 | torch.cat(one_hot_seq, 0) 227 | .reshape(self.num_domain, self.tmp_batch_size, -1) 228 | .to(self.device) 229 | ) 230 | else: 231 | self.one_hot_seq = ( 232 | torch.cat(one_hot_seq, 0) 233 | .reshape(self.test_dmn_num, self.tmp_batch_size, -1) 234 | .to(self.device) 235 | ) 236 | 237 | def __train_forward__(self): 238 | self.e_seq = self.netE(self.x_seq) # encoder of the data 239 | self.f_seq = self.netF(self.e_seq) # prediction 240 | 241 | def __test_forward__(self): 242 | self.e_seq = self.netE(self.x_seq) # encoder of the data 243 | self.f_seq = self.netF(self.e_seq) 244 | if "toy" in self.dataset: 245 | self.g_seq = torch.argmax(self.f_seq.detach(), dim=2) # class of the prediction 246 | 247 | def test_x_seq(self, x_seq): 248 | e_seq = self.netE(x_seq) # encoder of the data 249 | f_seq = self.netF(e_seq) 250 | if "toy" in self.dataset: 251 | g_seq = torch.argmax(f_seq.detach(), dim=2) # class of the prediction 252 | return g_seq 253 | 254 | def test_x_seq_feat(self, x_seq): 255 | e_seq = self.netE(x_seq) # encoder of the data 256 | _, x_seq_feat = self.netF(e_seq, return_feature=True) 257 | return x_seq_feat 258 | 259 | def __optimize__(self): 260 | loss_value = dict() 261 | 262 | self.loss_E_pred = self.__loss_EF__() 263 | 264 | self.optimizer_EF.zero_grad() 265 | self.loss_E_pred.backward(retain_graph=True) 266 | self.optimizer_EF.step() 267 | 268 | loss_value["loss"] = self.loss_E_pred.item() 269 | return loss_value 270 | 271 | def __loss_EF__(self): 272 | pass 273 | 274 | def __log_write__(self, loss_msg): 275 | print(loss_msg) 276 | with open(self.train_log, "a") as f: 277 | f.write(loss_msg + "\n") 278 | 279 | def __vis_loss__(self, loss_values): 280 | if self.epoch == 0: 281 | self.panes = { 282 | loss_name: self.env.line( 283 | X=np.array([self.epoch]), 284 | Y=np.array([loss_values[loss_name]]), 285 | opts=dict(title="loss for {} on epochs".format(loss_name)), 286 | ) 287 | for loss_name in self.loss_names 288 | } 289 | else: 290 | for loss_name in self.loss_names: 291 | self.env.line( 292 | X=np.array([self.epoch]), 293 | Y=np.array([loss_values[loss_name]]), 294 | win=self.panes[loss_name], 295 | update="append", 296 | ) 297 | 298 | def __init_weight__(self, net=None): 299 | if net is None: 300 | net = self 301 | for m in net.modules(): 302 | if isinstance(m, nn.Linear): 303 | nn.init.normal_(m.weight, mean=0, std=0.01) 304 | nn.init.constant_(m.bias, val=0) 305 | 306 | class ERM(BaseModel): 307 | """ 308 | ERM Model 309 | """ 310 | def __init__(self, opt): 311 | super(ERM, self).__init__(opt) 312 | 313 | def __loss_EF__(self): 314 | y_seq_source = self.y_seq[self.domain_mask == 1] 315 | f_seq_source = self.f_seq[self.domain_mask == 1] 316 | loss_E_pred = self.criterion(flat(f_seq_source), flat(y_seq_source)) 317 | return loss_E_pred 318 | 319 | class DRO(BaseModel): 320 | """ 321 | DRO Model 322 | """ 323 | def 
__init__(self, opt): 324 | super(DRO, self).__init__(opt) 325 | 326 | # q 327 | self.register_buffer("q", torch.Tensor()) 328 | 329 | def __loss_EF__(self): 330 | 331 | y_seq_source = self.y_seq[self.domain_mask == 1] 332 | f_seq_source = self.f_seq[self.domain_mask == 1] 333 | 334 | if not len(self.q): 335 | self.q = torch.ones(len(self.src_domain)).to(self.device) 336 | 337 | losses = torch.zeros(len(self.src_domain)).to(self.device) 338 | 339 | for m in range(len(self.src_domain)): 340 | losses[m] = self.criterion(f_seq_source[m], y_seq_source[m]) 341 | self.q[m] *= (self.groupdro_eta * losses[m].data).exp() 342 | 343 | self.q /= self.q.sum() 344 | 345 | loss_E_pred = torch.dot(losses, self.q) 346 | 347 | return loss_E_pred 348 | 349 | 350 | class IRM(BaseModel): 351 | """ 352 | IRM Model 353 | """ 354 | def __init__(self, opt): 355 | super(IRM, self).__init__(opt) 356 | 357 | # Doesn't use it for now (for penalty and lr decay) 358 | self.register_buffer('update_count', torch.tensor([0])) 359 | self.irm_penal = opt.irm_penal 360 | 361 | def _irm_penalty(self, logits, y): 362 | scale = torch.tensor(1.).to("cuda").requires_grad_() 363 | loss_1 = self.criterion(logits[::2] * scale, y[::2]) 364 | loss_2 = self.criterion(logits[1::2] * scale, y[1::2]) 365 | grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0] 366 | grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0] 367 | result = torch.sum(grad_1 * grad_2) 368 | return result 369 | 370 | def __loss_EF__(self): 371 | 372 | y_seq_source = self.y_seq[self.domain_mask == 1] 373 | f_seq_source = self.f_seq[self.domain_mask == 1] 374 | 375 | nll = 0. 376 | penalty = 0. 377 | 378 | for m in range(len(self.src_domain)): 379 | nll += self.criterion(f_seq_source[m], y_seq_source[m]) 380 | penalty += self._irm_penalty(f_seq_source[m], y_seq_source[m]) 381 | 382 | nll /= len(self.src_domain) 383 | penalty /= len(self.src_domain) 384 | loss_E_pred = nll + (self.irm_penal * penalty) 385 | 386 | return loss_E_pred 387 | 388 | class TRO(BaseModel): 389 | """ 390 | TRO Model 391 | """ 392 | def __init__(self, opt): 393 | super(TRO, self).__init__(opt) 394 | self.register_buffer("q", torch.Tensor()) 395 | 396 | # Graph centrality 397 | if opt.learn == 0: # physical graph is provided 398 | if opt.partial == 0: # use both source + target graph 399 | central = get_centrality_all(self.A, opt.src_domain, opt.num_domain) 400 | self.prior = central[opt.src_domain] 401 | else: # only use source 402 | central = get_centrality(self.A) 403 | self.prior = central 404 | 405 | self.prior /= self.prior.sum() 406 | self.prior = np.around(self.prior, 3) 407 | 408 | else: # data graph is used 409 | # DG-15 410 | if "15" in opt.dataset: 411 | # replace self.prior with values generated from learn_graph.py 412 | self.prior = [0.0, 0.0, 0.0, 0.0, 0.0, 0.4] 413 | elif "60" in opt.dataset: 414 | self.prior = [0.0, 0.0, 0.0, 0.0, 0.0, 0.2] 415 | 416 | self.prior = np.asarray(self.prior) 417 | self.prior /= self.prior.sum() 418 | 419 | self.prior = torch.from_numpy(self.prior).to(self.device) 420 | 421 | def __loss_EF__(self): 422 | 423 | y_seq_source = self.y_seq[self.domain_mask == 1] 424 | f_seq_source = self.f_seq[self.domain_mask == 1] 425 | 426 | if not len(self.q): 427 | # Learnable q 428 | self.q = torch.ones(len(self.src_domain)).to(self.device) 429 | self.q /= self.q.sum() # Key! 
430 | 431 | losses = torch.zeros(len(self.src_domain)).to(self.device) 432 | 433 | for m in range(len(self.src_domain)): 434 | losses[m] = self.criterion(f_seq_source[m], y_seq_source[m]) 435 | self.q[m] += self.groupdro_eta * (losses[m] - self.lmbda * (self.q[m] - self.prior[m])) 436 | 437 | self.q = to_tensor(projection_simplex(to_np(self.q)), self.device) 438 | 439 | loss_E_pred = torch.dot(losses, self.q.to(torch.float)) 440 | 441 | return loss_E_pred -------------------------------------------------------------------------------- /toy/model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Identity(nn.Module): 6 | def __init__(self): 7 | super(Identity, self).__init__() 8 | 9 | def forward(self, x): 10 | return x 11 | 12 | class FeatureNet(nn.Module): 13 | def __init__(self, opt): 14 | super(FeatureNet, self).__init__() 15 | nx, nh = opt.nx, opt.nh 16 | 17 | self.fc1 = nn.Linear(nx, nh) 18 | self.fc2 = nn.Linear(nh, nh) 19 | self.fc3 = nn.Linear(nh, nh) 20 | self.fc4 = nn.Linear(nh, nh) 21 | self.fc_final = nn.Linear(nh, nh) 22 | 23 | def forward(self, x): 24 | re = x.dim() == 3 25 | if re: 26 | T, B, _ = x.shape 27 | x = x.reshape(T * B, -1) 28 | 29 | x = F.relu(self.fc1(x)) 30 | 31 | if re: 32 | return x.reshape(T, B, -1) 33 | else: 34 | return x 35 | 36 | 37 | class PredNet(nn.Module): 38 | def __init__(self, opt): 39 | super(PredNet, self).__init__() 40 | nh, nc = opt.nh, opt.nc 41 | self.fc3 = nn.Linear(nh, nh) 42 | self.bn3 = nn.BatchNorm1d(nh) 43 | self.fc4 = nn.Linear(nh, nh) 44 | self.bn4 = nn.BatchNorm1d(nh) 45 | self.fc_final = nn.Linear(nh, nc) 46 | if opt.no_bn: 47 | self.bn3 = Identity() 48 | self.bn4 = Identity() 49 | 50 | def forward(self, x, return_feature=False): 51 | re = x.dim() == 3 52 | if re: 53 | T, B, _ = x.shape 54 | x = x.reshape(T * B, -1) 55 | 56 | x = F.relu(self.bn3(self.fc3(x))) 57 | x_feat = F.relu(self.bn4(self.fc4(x))) 58 | 59 | # Classification 60 | x = self.fc_final(x_feat) 61 | x_softmax = F.softmax(x, dim=1) 62 | x = torch.log(x_softmax + 1e-4) 63 | 64 | if re: 65 | x = x.reshape(T, B, -1) 66 | x_softmax = x_softmax.reshape(T, B, -1) 67 | 68 | if return_feature: 69 | return x, x_feat 70 | else: 71 | return x -------------------------------------------------------------------------------- /toy/run_tro.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # DG-15 Physical Graph 4 | python3 main.py --dataset toy_d15 --learn 0 --model TRO 5 | 6 | # DG-15 Data Graph 7 | # Learn the graph and obtain the centrality 8 | python3 learn_graph.py --dataset toy_d15 9 | 10 | python3 main.py --dataset toy_d15 --learn 1 --model TRO 11 | 12 | # DG-60 Physical Graph 13 | python3 main.py --dataset toy_d60 --learn 0 --model TRO 14 | 15 | # DG-60 Data Graph 16 | # Learn the graph and obtain the centrality 17 | python3 learn_graph.py --dataset toy_d60 18 | 19 | python3 main.py $PREFIX --dataset toy_d60 --learn 1 --model TRO -------------------------------------------------------------------------------- /toy/utils/centrality.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | 4 | def get_centrality(A): 5 | G = nx.from_numpy_array(A) 6 | centrality = nx.betweenness_centrality(G, normalized=True) 7 | centrality = np.asarray(list(centrality.values()), dtype=float) 8 | return np.around(centrality, 3) 
9 | 10 | def get_centrality_all(A, src_domain, num_domain): 11 | # when we have the physical graph 12 | G = nx.from_numpy_array(A) 13 | all_domains = list(range(num_domain)) 14 | tgt_domain = list(set(all_domains) - set(src_domain)) 15 | centrality = nx.betweenness_centrality_subset(G, src_domain, tgt_domain, normalized=True) 16 | centrality = np.asarray(list(centrality.values()), dtype=float) 17 | return np.around(centrality, 3) 18 | -------------------------------------------------------------------------------- /toy/utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import pairwise_distances 3 | from DiffusionEMD import DiffusionTree 4 | import graphtools 5 | 6 | # Use diffusion EMD to get the matrix 7 | def emd(x_seq_feat, feat_size, n_distributions, n_points_per_distribution): 8 | x_seq_feat = x_seq_feat.reshape(-1, feat_size) 9 | x_seq_feat, indices = np.unique(x_seq_feat, axis=0, return_index=True) 10 | minx = np.min(x_seq_feat, axis=0) 11 | maxx = np.max(x_seq_feat, axis=0) + 1e-10 12 | # Normalize 13 | std_X = (x_seq_feat - minx) / (maxx - minx) 14 | dc = DiffusionTree(max_scale=10, delta=1e-10, min_basis=20) 15 | data_graph = graphtools.Graph(std_X, use_pygsp=True, n_pca=100) 16 | group_ids = np.repeat(np.eye(n_distributions), n_points_per_distribution, axis=0) 17 | group_ids = group_ids[indices] 18 | embeds = dc.fit_transform(data_graph.W, group_ids) 19 | dis_matrix = pairwise_distances(embeds) 20 | return dis_matrix -------------------------------------------------------------------------------- /toy/utils/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import torch 4 | import math 5 | import os 6 | import json 7 | from datetime import datetime 8 | 9 | def to_np(x): 10 | return x.detach().cpu().numpy() 11 | 12 | def to_tensor(x, device="cuda"): 13 | if isinstance(x, np.ndarray): 14 | x = torch.from_numpy(x).to(device) 15 | else: 16 | x = x.to(device) 17 | return x 18 | 19 | def flat(x): 20 | n, m = x.shape[:2] 21 | return x.reshape(n * m, *x.shape[2:]) 22 | 23 | def read_pickle(name): 24 | with open(name, "rb") as f: 25 | data = pickle.load(f) 26 | return data 27 | 28 | def write_pickle(data, name): 29 | with open(name, "wb") as f: 30 | pickle.dump(data, f) 31 | 32 | def setup_experiment(args, opt): 33 | # log file 34 | current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') 35 | outfolder = os.path.join('runs', args.dataset, args.model, current_time) 36 | 37 | # make directories 38 | if not os.path.exists(outfolder): 39 | os.makedirs(outfolder) 40 | 41 | # TODO: copy config file into the directory 42 | with open(os.path.join(outfolder, 'config.json'), 'w') as outfile: 43 | json.dump(opt, outfile, indent=4) 44 | 45 | return outfolder 46 | 47 | def projection_simplex(v, z=1): 48 | """ 49 | Old implementation for test and benchmark purposes. 50 | The arguments v and z should be a vector and a scalar, respectively. 
51 | """ 52 | n_features = v.shape[0] 53 | u = np.sort(v)[::-1] 54 | cssv = np.cumsum(u) - z 55 | ind = np.arange(n_features) + 1 56 | cond = u - cssv / ind > 0 57 | rho = ind[cond][-1] 58 | theta = cssv[cond][-1] / float(rho) 59 | w = np.maximum(v - theta, 0) 60 | return w 61 | 62 | def read_pickle(name): 63 | with open(name, 'rb') as f: 64 | data = pickle.load(f) 65 | return data 66 | 67 | def arr2str(arr): 68 | return [str(i) for i in arr] 69 | 70 | def rotate(origin, point, angle): 71 | """ 72 | Rotate a point counterclockwise by a given angle around a given origin. 73 | 74 | The angle should be given in radians. 75 | """ 76 | ox, oy = origin 77 | px, py = point 78 | 79 | qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy) 80 | qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy) 81 | return np.asarray([qx, qy]) --------------------------------------------------------------------------------
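A quick, illustrative check of `projection_simplex` from `toy/utils/utils.py` (this example is not part of the repository; the input vector is arbitrary): projecting `[0.5, 1.2, -0.3]` onto the unit simplex shifts the positive entries down by a common threshold and clips at zero, so the result is non-negative and sums to one.

```python
import numpy as np
# assumes toy/utils is importable, e.g. when run from the toy/ directory
from utils.utils import projection_simplex

w = projection_simplex(np.array([0.5, 1.2, -0.3]))
print(w)        # [0.15 0.85 0.  ]
print(w.sum())  # 1.0
```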