├── KernelHRM_sim1.py ├── KernelHRM_sim2.py ├── README.md └── reproduce.sh /KernelHRM_sim1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.decomposition import TruncatedSVD 3 | import torch 4 | from tqdm import tqdm 5 | from torch.autograd import grad 6 | import copy 7 | from torch import nn 8 | import argparse 9 | from sklearn.utils.extmath import randomized_svd 10 | import random 11 | import torch.optim as optim 12 | import math 13 | import torch.nn.functional as F 14 | 15 | np.set_printoptions(precision=4) 16 | 17 | def pretty(vector): 18 | if type(vector) is list: 19 | vlist = vector 20 | elif type(vector) is np.ndarray: 21 | vlist = vector.reshape(-1).tolist() 22 | else: 23 | vlist = vector.view(-1).tolist() 24 | return "[" + ", ".join("{:+.4f}".format(vi) for vi in vlist) + "]" 25 | 26 | 27 | def setup_seed(seed): 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | np.random.seed(seed) 31 | random.seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | 34 | 35 | def generate_data(num_data=10000, d=5, bias=0.5, scramble=0, sigma_s=3.0, sigma_v=0.3): 36 | from scipy.stats import ortho_group 37 | S = np.float32(ortho_group.rvs(size=1, dim=2*d, random_state=1)) 38 | y = np.random.choice([1, -1], size=(num_data, 1)) 39 | X = np.random.randn(num_data, d * 2) 40 | X[:, :d] *= sigma_s 41 | X[:, d:] *= sigma_v 42 | flip = np.random.choice([1, -1], size=(num_data, 1), p=[bias, 1. - bias]) * y 43 | X[:, :d] += y 44 | X[:, d:] += flip 45 | if scramble == 1: 46 | X = np.matmul(X, S) 47 | X, y = torch.from_numpy(X).float(), torch.from_numpy(y).float() 48 | return X, y 49 | 50 | 51 | def generate_data_list(args): 52 | X_list, y_list = [], [] 53 | for i, r in enumerate(args.r_list): 54 | X, y = generate_data(num_data=args.num_list[i], bias=r, scramble=args.scramble) 55 | X_list.append(X.to(args.device)) 56 | y_list.append(y.to(args.device)) 57 | return X_list, y_list 58 | 59 | 60 | class MLP(nn.Module): 61 | def __init__(self, m=1024): 62 | super().__init__() 63 | self.layer1 = nn.Linear(10, m) 64 | self.layer2 = nn.Linear(m, 1) 65 | self.relu = nn.ReLU(True) 66 | 67 | def forward(self, x): 68 | x = self.relu(self.layer1(x)) 69 | x = self.layer2(x) 70 | return x 71 | 72 | 73 | def compute_num_params(model, verbose): 74 | num_params = 0 75 | for p in model.parameters(): 76 | num_params += len(p.view(-1).detach().cpu().numpy()) 77 | if verbose: 78 | print("Number of parameters is: %d" % num_params) 79 | return num_params 80 | 81 | 82 | def compute_NTF(model, X, num_params, args): 83 | model.zero_grad() 84 | y = model(X).squeeze() 85 | ret = torch.zeros(len(y), num_params).to(args.device) 86 | for i, loss in (enumerate(y)): 87 | loss.backward(retain_graph=True) 88 | gradients = [] 89 | for p in model.parameters(): 90 | gradients.append(p.grad.view(-1)) 91 | gradients = torch.cat(gradients, dim=-1) - torch.sum(ret, dim=0) 92 | assert len(gradients) == num_params 93 | ret[i, :] = gradients 94 | return ret.detach().cpu().numpy() 95 | 96 | 97 | def main_Compute_NTF(args): 98 | X_list, _ = generate_data_list(args) 99 | X_list = torch.cat(X_list, dim=0) 100 | 101 | model = MLP().to(args.device) 102 | 103 | num_params = compute_num_params(model, False) 104 | 105 | NTF = compute_NTF(model, X_list, num_params, args) 106 | U, S, VT = randomized_svd(NTF, n_components=50, n_iter=10, random_state=42) 107 | print((np.mean(U[:1000, :], axis=0) - np.mean(U[1000:, :], axis=0))[0:21]) 108 | return 109 | 110 | 111 | 
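# Note on compute_NTF above: model.zero_grad() is called only once, so p.grad
# accumulates across the repeated loss.backward(retain_graph=True) calls;
# subtracting the running sum torch.sum(ret, dim=0) recovers each sample's own
# gradient, i.e. its neural tangent feature d f(x_i) / d theta at the current
# (randomly initialized) weights. A slower but more explicit way to compute the
# same per-sample gradients, zeroing the gradients for every sample, is the
# following sketch (a hypothetical helper, not used by the scripts in this repo):
def compute_NTF_explicit(model, X):
    rows = []
    for i in range(X.shape[0]):
        model.zero_grad()
        model(X[i:i + 1]).squeeze().backward()
        rows.append(torch.cat([p.grad.view(-1) for p in model.parameters()]))
    return torch.stack(rows).detach().cpu().numpy()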
112 | class LinearRegression(nn.Module): 113 | def __init__(self, input_dim, output_dim=1): 114 | super(LinearRegression, self).__init__() 115 | self.linear = nn.Linear(input_dim, output_dim, bias=True) 116 | self.weight_init() 117 | 118 | def weight_init(self): 119 | torch.nn.init.xavier_uniform_(self.linear.weight) 120 | 121 | def forward(self, x): 122 | return self.linear(x) 123 | 124 | class OLS: 125 | def __init__(self, X, y, args): 126 | self.model = LinearRegression(X.shape[1], 1) 127 | self.X = X 128 | self.y = y 129 | self.loss = nn.MSELoss() 130 | self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3) 131 | self.device = args.device 132 | 133 | def to_cuda(self): 134 | self.model.cuda(self.device) 135 | self.X = self.X.to(self.device) 136 | self.y = self.y.to(self.device) 137 | 138 | def train(self): 139 | # self.to_cuda() 140 | self.model.weight_init() 141 | epochs = 3000 142 | 143 | for epoch in range(epochs): 144 | self.optimizer.zero_grad() 145 | pred = self.model(self.X) 146 | loss = self.loss(pred, self.y) \ 147 | + 1e-2 * torch.mean(torch.abs(self.model.linear.weight)) 148 | loss.backward(retain_graph=True) 149 | self.optimizer.step() 150 | # if epoch % 100 == 0: 151 | # print("Epoch %d | Loss = %.4f" % (epoch, loss)) 152 | return self.model.linear.weight.clone().cpu().detach(), self.model.linear.bias.clone().cpu().detach() 153 | 154 | class Cluster: 155 | def __init__(self, feature, y, K, args): 156 | self.feature = feature.cpu() 157 | self.label = y.cpu() 158 | self.K = K 159 | self.args = args 160 | self.center = None 161 | self.bias = None 162 | self.domain = None 163 | 164 | # run weighted lasso for each cluster and get new coefs and biases 165 | def ols(self): 166 | for i in range(self.K): 167 | index = torch.where(self.domain == i)[0] 168 | tempx = (self.feature[index, :]).reshape(-1, self.feature.shape[1]) 169 | tempy = (self.label[index, :]).reshape(-1, 1) 170 | clf = OLS(tempx, tempy, self.args) 171 | self.center[i, :], self.bias[i] = clf.train() 172 | 173 | def clustering(self, past_domains=None): 174 | # init 175 | self.center = torch.tensor(np.zeros((self.K, self.feature.shape[1]), dtype=np.float32)) 176 | self.bias = torch.tensor(np.zeros(self.K, dtype=np.float32)) 177 | 178 | # using last domains as the initialization 179 | if past_domains is None: 180 | self.domain = torch.tensor(np.random.randint(0, self.K, self.feature.shape[0])) 181 | else: 182 | self.domain = past_domains 183 | assert self.domain.shape[0] == self.feature.shape[0] 184 | 185 | # flags 186 | iter = 0 187 | end_flag = False 188 | delta_threshold = 0.1 * self.feature.shape[0]/self.K 189 | while not end_flag: 190 | iter += 1 191 | self.ols() 192 | ols_error = [] 193 | 194 | for i in range(self.K): 195 | coef = self.center[i].reshape(-1, 1) 196 | error = torch.abs(torch.mm(self.feature, coef) + self.bias[i] - self.label) 197 | assert error.shape == (self.feature.shape[0], 1) 198 | ols_error.append(error) 199 | ols_error = torch.stack(ols_error, dim=0).reshape(self.K, self.feature.shape[0]) 200 | 201 | new_domain = torch.argmin(ols_error, dim=0) 202 | assert new_domain.shape[0] == self.feature.shape[0] 203 | diff = self.domain.reshape(-1, 1) - new_domain.reshape(-1, 1) 204 | diff[diff != 0] = 1 205 | delta = torch.sum(diff) 206 | if iter % 10 == 9: 207 | print("Iter %d | Delta = %d" % (iter, delta)) 208 | if delta <= delta_threshold: 209 | end_flag = True 210 | self.domain = new_domain 211 | 212 | 213 | return self.domain 214 | 215 | def main_KernelHRM(args): 216 | print("Kernel HRM") 
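    # Overview of main_KernelHRM (as implemented below):
    #   1. Build neural tangent features (NTF): per-sample gradients of a
    #      randomly initialized MLP, reduced to args.k dimensions with a
    #      truncated (randomized) SVD.
    #   2. Frontend: cluster the current features into args.cluster_num latent
    #      environments, where each cluster center is an L1-regularized linear
    #      fit and points are reassigned to the cluster with the smallest
    #      absolute residual.
    #   3. Backend: on the inferred environments, fit a linear head on the NTF
    #      (on top of the frozen MLP output f_w0) with an IRM-style penalty to
    #      obtain an invariant direction theta_inv.
    #   4. Project theta_inv out of the features and repeat steps 2-3 for
    #      args.whole_epoch rounds.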
217 | 218 | class Linear_Model(nn.Module): 219 | def __init__(self, d=30): 220 | super().__init__() 221 | self.linear = nn.Linear(d, 1, bias=False) 222 | # nn.init.xavier_uniform_(self.linear.weight, gain=nn.init.calculate_gain('relu')) 223 | nn.init.xavier_uniform_(self.linear.weight, gain=0.1) 224 | 225 | def forward(self, f_w0, X): 226 | return f_w0 + self.linear(X) 227 | 228 | 229 | 230 | train_record = [] 231 | test_record = [] 232 | # data 233 | X_list, y_list = generate_data_list(args) 234 | train_X, train_y = torch.cat([X_list[0], X_list[1]], dim=0), torch.cat([y_list[0], y_list[1]], dim=0) 235 | test_X, test_y = X_list[2], y_list[2] 236 | 237 | model = MLP().to(args.device) 238 | init_params = torch.cat([p.view(-1) for p in model.parameters()], 0) 239 | criterion = torch.nn.MSELoss() 240 | NTF = compute_NTF(model, train_X, compute_num_params(model, False), args) 241 | test_NTF = compute_NTF(model, test_X, compute_num_params(model, False), args) 242 | 243 | U, S, VT = randomized_svd(NTF, n_components=args.k, n_iter=10, random_state=42) 244 | U, S, VT = torch.from_numpy(U).float().to(args.device), torch.from_numpy(S).float().to( 245 | args.device), torch.from_numpy(VT).float().to(args.device) 246 | U_train = torch.matmul(U, torch.diag(S)) 247 | U_test = torch.from_numpy(test_NTF).float().to(args.device) 248 | U_test = torch.matmul(U_test, VT.permute(1, 0)) 249 | train_feature = copy.deepcopy(U_train) 250 | U_train_sum = torch.sum(U_train.pow(2), dim=1) 251 | print(U_train_sum.shape) 252 | U_train_norm = torch.mean(torch.sqrt(U_train_sum)) 253 | print("U_train norm is %.4f" % U_train_norm.data) 254 | 255 | # whole iteration 256 | past_domains = None 257 | for epoch in range(args.whole_epoch): 258 | print('--------------epoch %d---------------' % epoch) 259 | # frontend 260 | cluster_model = Cluster(train_feature, train_y, args.cluster_num, args) 261 | cluster_results = cluster_model.clustering(past_domains) 262 | past_domains = cluster_results 263 | index0 = torch.where(cluster_results==0)[0] 264 | index1 = torch.where(cluster_results==1)[0] 265 | 266 | # calculate envs 267 | env_num_list = [] 268 | for i in range(args.cluster_num): 269 | idx = torch.where(cluster_results[:1000, ] == i)[0] 270 | env_num_list.append(idx.shape[0]) 271 | print('The first environment is split into : %s', pretty(env_num_list)) 272 | 273 | env_num_list = [] 274 | for i in range(args.cluster_num): 275 | idx = torch.where(cluster_results[1000:, ] == i)[0] 276 | env_num_list.append(idx.shape[0]) 277 | print('The second environment is split into : %s', pretty(env_num_list)) 278 | 279 | # backend 280 | correct1 = 0.0 281 | correct2 = 0.0 282 | flag = True 283 | theta_inv = None 284 | 285 | while flag: 286 | print("Step 1: Linear MIP") 287 | model_IRM = Linear_Model(d=U_train.shape[1]).to(args.device) 288 | model.eval() 289 | f_w0 = model(train_X).detach() 290 | opt_IRM = torch.optim.Adam(model_IRM.parameters(), lr=args.lr) 291 | 292 | for epoch in (range(1, args.epochs + 1)): 293 | model_IRM.train() 294 | scale = torch.tensor(1.).to(args.device).requires_grad_() 295 | 296 | yhat = model_IRM(f_w0[index0], U_train[index0,:]) 297 | loss_1 = criterion(yhat, train_y[index0]) 298 | grad_1 = grad(criterion(yhat * scale, train_y[index0]), [scale], create_graph=True)[0] 299 | 300 | yhat = model_IRM(f_w0[index1], U_train[index1,:]) 301 | loss_2 = criterion(yhat, train_y[index1]) 302 | grad_2 = grad(criterion(yhat * scale, train_y[index1]), [scale], create_graph=True)[0] 303 | 304 | penalty = (grad_1-grad_2).pow(2).mean() 
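                # grad_1 / grad_2 are the gradients of each environment's risk
                # with respect to the dummy scale factor (the IRMv1 trick);
                # penalizing their squared difference pushes the risk gradients
                # of the two inferred environments to match, i.e. it acts as an
                # invariance penalty across environments.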
305 | 306 | IRM_lam = args.IRM_lam if epoch > args.IRM_ann else 1.0 307 | loss = (loss_1 + loss_2) / 2 + IRM_lam * penalty 308 | 309 | opt_IRM.zero_grad() 310 | loss.backward() 311 | opt_IRM.step() 312 | 313 | pred_train = 2 * ((model_IRM(f_w0, U_train) > 0).float() - 0.5) 314 | correct1 = float(pred_train[index0].eq(train_y[index0]).sum().item())/len(index0) 315 | correct2 = float(pred_train[index1].eq(train_y[index1]).sum().item())/len(index1) 316 | 317 | 318 | correct = float(pred_train.eq(train_y).sum().item()) 319 | total = pred_train.size(0) 320 | train_acc = correct / total 321 | 322 | model_IRM.eval() 323 | yhat = model_IRM(model(test_X), U_test) 324 | pred_test = 2 * ((yhat > 0).float() - 0.5) 325 | correct = float(pred_test.eq(test_y).sum().item()) 326 | total = pred_test.size(0) 327 | test_acc = correct / total 328 | 329 | if epoch % args.epochs == 0: 330 | print("Linear MIP epoch: %d, Train Acc: %f, Test Acc: %f" % (epoch, train_acc, test_acc)) 331 | print("Env 1 %.4f Env 2 %.4f" % (correct1, correct2)) 332 | 333 | theta_inv = copy.deepcopy(model_IRM.linear.weight.data) 334 | flag = False 335 | 336 | train_record.append(train_acc) 337 | test_record.append(test_acc) 338 | theta_inv = theta_inv/(torch.sqrt(torch.sum(theta_inv.pow(2)))) 339 | print(theta_inv) 340 | print(torch.sum(theta_inv.pow(2))) 341 | inner_product = torch.matmul(U_train, theta_inv.reshape(-1,1)) 342 | assert inner_product.shape[1]==1 and inner_product.shape[0]==U_train.shape[0] 343 | train_feature = U_train - torch.matmul(inner_product, theta_inv.reshape(1,-1)) 344 | 345 | print(train_feature.shape) 346 | 347 | return train_acc, test_acc, train_record, test_record 348 | 349 | 350 | 351 | 352 | 353 | if __name__ == "__main__": 354 | parser = argparse.ArgumentParser(description='Kernelized-HRM') 355 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') 356 | parser.add_argument('--k', type=int, default=60, help='k for SVD') 357 | parser.add_argument('--IRM_lam', type=float, default=6e1, help='IRM lambda') # 3-5e1 for IGD # 1e3 for IRM 358 | parser.add_argument('--IRM_ann', type=int, default=500, help='IRM annealing') # 200-400 for IGD, 400 for IRM 359 | parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train') 360 | parser.add_argument('--device', type=str, default='0') 361 | parser.add_argument('--r_list', type=float, nargs='+', default=[0.8, 0.9, 0.1]) 362 | parser.add_argument('--num_list', type=int, nargs='+', default=[1000, 1000, 1000]) 363 | parser.add_argument('--seed', type=int, default=0) 364 | parser.add_argument('--whole_epoch', type=int, default=5) 365 | parser.add_argument('--cluster_num', type=int, default=2) 366 | parser.add_argument('--scramble', type=int, default=0) 367 | args = parser.parse_args() 368 | 369 | args.device = torch.device("cuda:" + args.device if torch.cuda.is_available() and int(args.device)>0 else "cpu") 370 | 371 | setup_seed(args.seed) 372 | 373 | train_acc_list = [] 374 | test_acc_list = [] 375 | train_all = [] 376 | test_all = [] 377 | for seed in range(10): 378 | print("-----------------seed %d ----------------" % seed) 379 | setup_seed(seed) 380 | result = main_KernelHRM(args) 381 | train_acc_list.append(result[0]) 382 | test_acc_list.append(result[1]) 383 | train_all.append(result[2]) 384 | test_all.append(result[3]) 385 | train_acc_list = np.array(train_acc_list) 386 | test_acc_list = np.array(test_acc_list) 387 | print("MIP Train Mean %.4f std %.4f" % (np.mean(train_acc_list), np.std(train_acc_list))) 388 | 
print("MIP Test Mean %.4f std %.4f" % (np.mean(test_acc_list), np.std(test_acc_list))) 389 | print(train_all) 390 | print(test_all) 391 | print(np.mean(np.array(train_all), axis=0)) 392 | print(np.std(np.array(train_all), axis=0)) 393 | print(np.mean(np.array(test_all), axis=0)) 394 | print(np.std(np.array(test_all), axis=0)) 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | -------------------------------------------------------------------------------- /KernelHRM_sim2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.decomposition import TruncatedSVD 3 | import torch 4 | from tqdm import tqdm 5 | from torch.autograd import grad 6 | import copy 7 | from torch import nn 8 | import argparse 9 | from sklearn.utils.extmath import randomized_svd 10 | import random 11 | import torch.optim as optim 12 | import math 13 | import torch.nn.functional as F 14 | from EIIL import LearnedEnvInvariantRiskMinimization 15 | 16 | np.set_printoptions(precision=4) 17 | 18 | from multiprocessing import cpu_count 19 | import os 20 | cpu_num = 15 21 | os.environ ['OMP_NUM_THREADS'] = str(cpu_num) 22 | os.environ ['OPENBLAS_NUM_THREADS'] = str(cpu_num) 23 | os.environ ['MKL_NUM_THREADS'] = str(cpu_num) 24 | os.environ ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num) 25 | os.environ ['NUMEXPR_NUM_THREADS'] = str(cpu_num) 26 | torch.set_num_threads(cpu_num) 27 | 28 | def pretty(vector): 29 | if type(vector) is list: 30 | vlist = vector 31 | elif type(vector) is np.ndarray: 32 | vlist = vector.reshape(-1).tolist() 33 | else: 34 | vlist = vector.view(-1).tolist() 35 | return "[" + ", ".join("{:+.4f}".format(vi) for vi in vlist) + "]" 36 | 37 | def sign(x): 38 | if x > 0: 39 | return 1 40 | if x < 0: 41 | return -1 42 | return 0 43 | 44 | 45 | def setup_seed(seed): 46 | torch.manual_seed(seed) 47 | torch.cuda.manual_seed_all(seed) 48 | np.random.seed(seed) 49 | random.seed(seed) 50 | torch.backends.cudnn.deterministic = True 51 | 52 | def data_generation(n1, n2, ps, pvb, pv, r, scramble): 53 | S = np.random.normal(0, 2, [n1, ps]) 54 | V = np.random.normal(0, 2, [n1, pvb + pv]) 55 | 56 | Z = np.random.normal(0, 1, [n1, ps + 1]) 57 | for i in range(ps): 58 | S[:, i:i + 1] = 0.8 * Z[:, i:i + 1] + 0.2 * Z[:, i + 1:i + 2] 59 | 60 | beta = np.zeros((ps, 1)) 61 | for i in range(ps): 62 | beta[i] = (-1) ** i * (i % 3 + 1) * 1.0/2 63 | 64 | noise = np.random.normal(0, 1.0, [n1, 1]) 65 | 66 | Y = np.dot(S, beta) + noise + 5 * S[:, 0:1] * S[:, 1:2] * S[:, 2:3] 67 | index_pre = np.ones([n1, 1], dtype=bool) 68 | for i in range(pvb): 69 | D = np.abs(V[:, pv + i:pv + i + 1] * sign(r) - Y) 70 | pro = np.power(np.abs(r), -D * 5) 71 | selection_bias = np.random.random([n1, 1]) 72 | index_pre = index_pre & ( 73 | selection_bias < pro) 74 | index = np.where(index_pre == True) 75 | S_re = S[index[0], :] 76 | V_re = V[index[0], :] 77 | Y_re = Y[index[0]] 78 | n, p = S_re.shape 79 | index_s = np.random.permutation(n) 80 | 81 | X_re = np.hstack((S_re, V_re)) 82 | beta_X = np.vstack((beta, np.zeros((pv + pvb, 1)))) 83 | 84 | X = torch.from_numpy(X_re[index_s[0:n2], :]).float() 85 | y = torch.from_numpy(Y_re[index_s[0:n2], :]).float() 86 | 87 | from scipy.stats import ortho_group 88 | S = np.float32(ortho_group.rvs(size=1, dim=X.shape[1], random_state=1)) 89 | if scramble == 1: 90 | X = torch.matmul(X, torch.Tensor(S)) 91 | 92 | return X, y 93 | 94 | def generate_data_list(args): 95 | n1 = 1000000 96 | p = 10 97 | ps = int(p * 0.5) 98 | pvb = 
int(p * 0.1) 99 | pv = p - ps - pvb 100 | 101 | X_list, y_list = [], [] 102 | for i, r in enumerate(args.r_list): 103 | X, y = data_generation(n1, args.num_list[i], ps, pvb, pv, args.r_list[i], args.scramble) 104 | X_list.append(X.to(args.device)) 105 | y_list.append(y.to(args.device)) 106 | return X_list, y_list 107 | 108 | 109 | def generate_test_data_list(args): 110 | n1 = 1000000 111 | p = 10 112 | ps = int(p * 0.5) 113 | pvb = int(p * 0.1) 114 | pv = p - ps - pvb 115 | 116 | X_list, y_list = [], [] 117 | for r in [-2.9, -2.7, -2.5, -2.3, -2.1, -1.9]: 118 | X, y = data_generation(n1, 1000, ps, pvb, pv, r, args.scramble) 119 | X_list.append(X.to(args.device)) 120 | y_list.append(y.to(args.device)) 121 | return X_list, y_list 122 | 123 | 124 | class MLP(nn.Module): 125 | def __init__(self, m=1024): 126 | super().__init__() 127 | self.layer1 = nn.Linear(10, m) 128 | self.layer2 = nn.Linear(m, 1) 129 | self.relu = nn.ReLU(True) 130 | 131 | def forward(self, x): 132 | x = self.relu(self.layer1(x)) 133 | x = self.layer2(x) 134 | return x 135 | 136 | 137 | def compute_num_params(model, verbose): 138 | num_params = 0 139 | for p in model.parameters(): 140 | num_params += len(p.view(-1).detach().cpu().numpy()) 141 | if verbose: 142 | print("Number of parameters is: %d" % num_params) 143 | return num_params 144 | 145 | 146 | def compute_NTF(model, X, num_params, args): 147 | model.zero_grad() 148 | y = model(X).squeeze() 149 | ret = torch.zeros(len(y), num_params).to(args.device) 150 | for i, loss in (enumerate(y)): 151 | loss.backward(retain_graph=True) 152 | gradients = [] 153 | for p in model.parameters(): 154 | gradients.append(p.grad.view(-1)) 155 | gradients = torch.cat(gradients, dim=-1) - torch.sum(ret, dim=0) 156 | assert len(gradients) == num_params 157 | ret[i, :] = gradients 158 | return ret.detach().cpu().numpy() 159 | 160 | 161 | def main_Compute_NTF(args): 162 | X_list, _ = generate_data_list(args) 163 | X_list = torch.cat(X_list, dim=0) 164 | 165 | model = MLP().to(args.device) 166 | 167 | num_params = compute_num_params(model, False) 168 | 169 | NTF = compute_NTF(model, X_list, num_params, args) 170 | U, S, VT = randomized_svd(NTF, n_components=50, n_iter=10, random_state=42) 171 | print((np.mean(U[:1000, :], axis=0) - np.mean(U[1000:, :], axis=0))[0:21]) 172 | return 173 | 174 | 175 | 176 | class LinearRegression(nn.Module): 177 | def __init__(self, input_dim, output_dim=1): 178 | super(LinearRegression, self).__init__() 179 | self.linear = nn.Linear(input_dim, output_dim, bias=True) 180 | self.weight_init() 181 | 182 | def weight_init(self): 183 | torch.nn.init.xavier_uniform_(self.linear.weight) 184 | 185 | def forward(self, x): 186 | return self.linear(x) 187 | 188 | class OLS: 189 | def __init__(self, X, y, args): 190 | self.model = LinearRegression(X.shape[1], 1) 191 | self.X = X 192 | self.y = y 193 | self.loss = nn.MSELoss() 194 | self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3) 195 | self.device = args.device 196 | 197 | def to_cuda(self): 198 | self.model.cuda(self.device) 199 | self.X = self.X.to(self.device) 200 | self.y = self.y.to(self.device) 201 | 202 | def train(self): 203 | self.model.weight_init() 204 | epochs = 3000 205 | 206 | for epoch in range(epochs): 207 | self.optimizer.zero_grad() 208 | pred = self.model(self.X) 209 | loss = self.loss(pred, self.y) \ 210 | + 1e-2 * torch.mean(torch.abs(self.model.linear.weight)) 211 | loss.backward(retain_graph=True) 212 | self.optimizer.step() 213 | return self.model.linear.weight.clone().cpu().detach(), 
self.model.linear.bias.clone().cpu().detach() 214 | 215 | class Cluster: 216 | def __init__(self, feature, y, K, args): 217 | self.feature = feature.cpu() 218 | self.label = y.cpu() 219 | self.K = K 220 | self.args = args 221 | self.center = None 222 | self.bias = None 223 | self.domain = None 224 | 225 | # run weighted lasso for each cluster and get new coefs and biases 226 | def ols(self): 227 | for i in range(self.K): 228 | index = torch.where(self.domain == i)[0] 229 | tempx = (self.feature[index, :]).reshape(-1, self.feature.shape[1]) 230 | tempy = (self.label[index, :]).reshape(-1, 1) 231 | clf = OLS(tempx, tempy, self.args) 232 | self.center[i, :], self.bias[i] = clf.train() 233 | 234 | def clustering(self, past_domains=None): 235 | # init 236 | self.center = torch.tensor(np.zeros((self.K, self.feature.shape[1]), dtype=np.float32)) 237 | self.bias = torch.tensor(np.zeros(self.K, dtype=np.float32)) 238 | 239 | # using last domains as the initialization 240 | if past_domains is None: 241 | self.domain = torch.tensor(np.random.randint(0, self.K, self.feature.shape[0])) 242 | else: 243 | self.domain = past_domains 244 | assert self.domain.shape[0] == self.feature.shape[0] 245 | 246 | # flags 247 | iter = 0 248 | end_flag = False 249 | delta_threshold = 0.1 * self.feature.shape[0]/self.K 250 | while not end_flag: 251 | iter += 1 252 | self.ols() 253 | ols_error = [] 254 | 255 | for i in range(self.K): 256 | coef = self.center[i].reshape(-1, 1) 257 | error = torch.abs(torch.mm(self.feature, coef) + self.bias[i] - self.label) 258 | assert error.shape == (self.feature.shape[0], 1) 259 | ols_error.append(error) 260 | ols_error = torch.stack(ols_error, dim=0).reshape(self.K, self.feature.shape[0]) 261 | 262 | new_domain = torch.argmin(ols_error, dim=0) 263 | assert new_domain.shape[0] == self.feature.shape[0] 264 | diff = self.domain.reshape(-1, 1) - new_domain.reshape(-1, 1) 265 | diff[diff != 0] = 1 266 | delta = torch.sum(diff) 267 | if iter % 10 == 9: 268 | print("Iter %d | Delta = %d" % (iter, delta)) 269 | if delta <= delta_threshold: 270 | end_flag = True 271 | self.domain = new_domain 272 | 273 | 274 | return self.domain 275 | 276 | def main_KernelHRM(args): 277 | print("Kernel HRM") 278 | 279 | class Linear_Model(nn.Module): 280 | def __init__(self, d=30): 281 | super().__init__() 282 | self.linear = nn.Linear(d, 1, bias=False) 283 | nn.init.xavier_uniform_(self.linear.weight, gain=0.1) 284 | 285 | def forward(self, f_w0, X): 286 | return f_w0 + self.linear(X) 287 | 288 | 289 | 290 | train_record = np.zeros(args.whole_epoch) 291 | test_record = np.zeros(args.whole_epoch) 292 | mean_stable_record = np.zeros(args.whole_epoch) 293 | std_stable_record = np.zeros(args.whole_epoch) 294 | # data 295 | X_list, y_list = generate_data_list(args) 296 | train_X, train_y = torch.cat([X_list[0], X_list[1]], dim=0), torch.cat([y_list[0], y_list[1]], dim=0) 297 | test_X, test_y = X_list[2], y_list[2] 298 | 299 | test_X_list, test_y_list = generate_test_data_list(args) 300 | 301 | model = MLP().to(args.device) 302 | init_params = torch.cat([p.view(-1) for p in model.parameters()], 0) 303 | criterion = torch.nn.MSELoss() 304 | NTF = compute_NTF(model, train_X, compute_num_params(model, False), args) 305 | test_NTF = compute_NTF(model, test_X, compute_num_params(model, False), args) 306 | U, S, VT = randomized_svd(NTF, n_components=args.k, n_iter=10, random_state=42) 307 | U, S, VT = torch.from_numpy(U).float().to(args.device), torch.from_numpy(S).float().to( 308 | args.device), 
torch.from_numpy(VT).float().to(args.device) 309 | U_train = torch.matmul(U, torch.diag(S)) 310 | U_test = torch.from_numpy(test_NTF).float().to(args.device) 311 | U_test = torch.matmul(U_test, VT.permute(1, 0)) 312 | train_feature = copy.deepcopy(U_train) 313 | U_train_sum = torch.sum(U_train.pow(2), dim=1) 314 | print(U_train_sum.shape) 315 | U_train_norm = torch.mean(torch.sqrt(U_train_sum)) 316 | print("U_train norm is %.4f" % U_train_norm.data) 317 | 318 | 319 | tu_list = [] 320 | for idx, tx in enumerate(test_X_list): 321 | tu = compute_NTF(model, tx, compute_num_params(model, False), args) 322 | tu = torch.from_numpy(tu).float().to(args.device) 323 | tu = torch.matmul(tu, VT.permute(1, 0)) 324 | tu_list.append(tu) 325 | 326 | # whole iteration 327 | past_domains = None 328 | for whole_epoch in range(args.whole_epoch): 329 | print('--------------epoch %d---------------' % whole_epoch) 330 | # frontend 331 | cluster_model = Cluster(train_feature, train_y, args.cluster_num, args) 332 | cluster_results = cluster_model.clustering(past_domains) 333 | past_domains = cluster_results 334 | index0 = torch.where(cluster_results==0)[0] 335 | index1 = torch.where(cluster_results==1)[0] 336 | 337 | # calculate envs 338 | env_num_list = [] 339 | for i in range(args.cluster_num): 340 | idx = torch.where(cluster_results[:1000, ] == i)[0] 341 | env_num_list.append(idx.shape[0]) 342 | print('The first environment is split into : %s', pretty(env_num_list)) 343 | 344 | env_num_list = [] 345 | for i in range(args.cluster_num): 346 | idx = torch.where(cluster_results[1000:, ] == i)[0] 347 | env_num_list.append(idx.shape[0]) 348 | print('The second environment is split into : %s', pretty(env_num_list)) 349 | 350 | # backend 351 | flag = True 352 | theta_inv = None 353 | 354 | while flag: 355 | print("Step 1: Linear MIP") 356 | model_IRM = Linear_Model(d=U_train.shape[1]).to(args.device) 357 | model.eval() 358 | f_w0 = model(train_X).detach() 359 | opt_IRM = torch.optim.Adam(model_IRM.parameters(), lr=args.lr) 360 | 361 | for epoch in (range(1, args.epochs + 1)): 362 | model_IRM.train() 363 | 364 | yhat = model_IRM(f_w0[index0], U_train[index0,:]) 365 | loss_1 = criterion(yhat, train_y[index0]) 366 | grad_1 = grad(criterion(yhat, train_y[index0]), model_IRM.parameters(), create_graph=True)[0] 367 | 368 | yhat = model_IRM(f_w0[index1], U_train[index1,:]) 369 | loss_2 = criterion(yhat, train_y[index1]) 370 | grad_2 = grad(criterion(yhat, train_y[index1]), model_IRM.parameters(), create_graph=True)[0] 371 | 372 | penalty = (grad_1-grad_2).pow(2).mean() 373 | 374 | IRM_lam = args.IRM_lam if epoch > args.IRM_ann else 0.6 375 | loss = (loss_1 + loss_2) / 2 + IRM_lam * penalty 376 | 377 | opt_IRM.zero_grad() 378 | loss.backward() 379 | opt_IRM.step() 380 | 381 | model_IRM.eval() 382 | yhat = model_IRM(model(train_X), U_train) 383 | train_error = criterion(yhat, train_y) 384 | yhat = model_IRM(model(test_X), U_test) 385 | test_error = criterion(yhat, test_y) 386 | 387 | if epoch % 100 == 0: 388 | print("Linear MIP epoch: %d, Train Error: %f, Test Error: %f" % (epoch, train_error, test_error)) 389 | 390 | theta_inv = copy.deepcopy(model_IRM.linear.weight.data) 391 | flag = False 392 | train_record[whole_epoch] = train_error.data 393 | test_record[whole_epoch] = test_error.data 394 | theta_inv = theta_inv/(torch.sqrt(torch.sum(theta_inv.pow(2)))) 395 | inner_product = torch.matmul(U_train, theta_inv.reshape(-1,1)) 396 | assert inner_product.shape[1]==1 and inner_product.shape[0]==U_train.shape[0] 397 | train_feature 
= U_train - torch.matmul(inner_product, theta_inv.reshape(1,-1)) 398 | 399 | print(train_feature.shape) 400 | 401 | # testing stage 402 | stable_test_error_list = [] 403 | for idx, tu in enumerate(tu_list): 404 | model_IRM.eval() 405 | yhat = model_IRM(model(test_X_list[idx]), tu) 406 | s_error = criterion(yhat, test_y_list[idx]) 407 | stable_test_error_list.append(s_error.data) 408 | stable_test_error_list = np.array(stable_test_error_list) 409 | mean_stable_error = np.mean(stable_test_error_list) 410 | std_stable_error = np.std(stable_test_error_list) 411 | mean_stable_record[whole_epoch] = mean_stable_error 412 | std_stable_record[whole_epoch] = std_stable_error 413 | print('Whole Epoch % d Mean %.4f Std %.4f' % (whole_epoch, np.mean(stable_test_error_list), np.std(stable_test_error_list))) 414 | 415 | 416 | return train_error.data, test_error.data, train_record, test_record, mean_stable_record, std_stable_record 417 | 418 | 419 | 420 | 421 | 422 | if __name__ == "__main__": 423 | parser = argparse.ArgumentParser(description='Kernelized-HRM') 424 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') 425 | parser.add_argument('--k', type=int, default=60, help='k for SVD') 426 | parser.add_argument('--IRM_lam', type=float, default=6e1, help='IRM lambda') 427 | parser.add_argument('--IRM_ann', type=int, default=500, help='IRM annealing') 428 | parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train') 429 | parser.add_argument('--device', type=str, default='0') 430 | parser.add_argument('--r_list', type=float, nargs='+', default=[0.8, 0.9, 0.1]) 431 | parser.add_argument('--num_list', type=int, nargs='+', default=[1000, 1000, 1000]) 432 | parser.add_argument('--method', type=str, default='KIRM') 433 | parser.add_argument('--seed', type=int, default=0) 434 | parser.add_argument('--whole_epoch', type=int, default=5) 435 | parser.add_argument('--cluster_num', type=int, default=2) 436 | parser.add_argument('--scramble', type=int, default=0) 437 | args = parser.parse_args() 438 | 439 | args.device = torch.device("cuda:" + args.device if torch.cuda.is_available() and int(args.device)>0 else "cpu") 440 | 441 | setup_seed(args.seed) 442 | 443 | train_acc_list = [] 444 | test_acc_list = [] 445 | train_all = [] 446 | test_all = [] 447 | mean_all = [] 448 | std_all = [] 449 | for seed in range(9): 450 | print("-----------------seed %d ----------------" % seed) 451 | setup_seed(seed) 452 | result = main_KernelHRM(args) 453 | train_acc_list.append(result[0]) 454 | test_acc_list.append(result[1]) 455 | train_all.append(result[2]) 456 | test_all.append(result[3]) 457 | mean_all.append(result[4]) 458 | std_all.append(result[5]) 459 | train_acc_list = np.vstack(train_acc_list) 460 | test_acc_list = np.vstack(test_acc_list) 461 | 462 | print(train_acc_list) 463 | print(test_acc_list) 464 | 465 | print('===========mean=============') 466 | print(mean_all) 467 | print(std_all) 468 | 469 | 470 | print("MIP Train Mean %.4f std %.4f" % (np.mean(train_acc_list), np.std(train_acc_list))) 471 | print("MIP Test Mean %.4f std %.4f" % (np.mean(test_acc_list), np.std(test_acc_list))) 472 | 473 | print(train_all) 474 | print(test_all) 475 | 476 | print(np.mean(np.array(train_all), axis=0)) 477 | print(np.std(np.array(train_all), axis=0)) 478 | print(np.mean(np.array(test_all), axis=0)) 479 | print(np.std(np.array(test_all), axis=0)) 480 | 481 | print(np.mean(np.array(mean_all), axis=0)) 482 | print(np.std(np.array(mean_all), axis=0)) 483 | 
print(np.mean(np.array(std_all), axis=0))
484 | print(np.std(np.array(std_all), axis=0))
485 | 
486 | 
487 | 
488 | 
489 | 
490 | 
491 | 
492 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Kernelized-HRM
2 | > Jiashuo Liu, Zheyuan Hu
3 | 
4 | 
5 | > The code for our NeurIPS 2021 paper "Kernelized Heterogeneous Risk Minimization"[1]. This repo contains the code for our **Classification with Spurious Correlation** and **Regression with Selection Bias** simulated experiments, including the data generation process, the whole Kernelized-HRM algorithm and the testing process.
6 | 
7 | ### Details
8 | There are two files, named `KernelHRM_sim1.py` and `KernelHRM_sim2.py`, which contain the code for the classification simulation experiment and the regression simulation experiment, respectively.
9 | The main functions are:
10 | 
11 | * `generate_data_list`: generates data according to the given bias parameters `args.r_list`.
12 | 
13 | * `generate_test_data_list`: generates the test data for the **Selection Bias** experiment, where the bias parameter is pre-defined to range over [-2.9, -2.7, ..., -1.9].
14 | 
15 | * `main_KernelHRM`: the whole framework of our Kernelized-HRM algorithm.
16 | 
17 | 
18 | ### Hyperparameters
19 | There are many hyper-parameters to be tuned for the whole framework; they differ across tasks and require careful tuning. Note that although we provide the hyper-parameters for the simulated experiments, the results may not be exactly the same as ours, which may be due to randomness or other factors.
20 | 
21 | Generally, the following hyper-parameters need to be tuned carefully:
22 | 
23 | * k: controls the dimension of the reduced neural tangent features
24 | * whole_epoch: controls the overall number of iterations between the frontend and the backend
25 | * epochs: controls the number of epochs for optimizing the invariant learning module in each iteration
26 | * IRM_lam: controls the strength of the regularizer for invariant learning
27 | * lr: learning rate
28 | * cluster_num: controls the number of clusters
29 | 
30 | Further, for the experimental settings, the following parameters need to be specified:
31 | 
32 | * r_list: controls the strength of the spurious correlations
33 | * scramble: similar to IRM[2], whether to mix (scramble) the raw features
34 | * num_list: controls the number of data points from each environment
35 | 
36 | The hyper-parameters we used for our simulation experiments are given in the `reproduce.sh` file.
37 | 
38 | 
39 | ### Others
40 | Similar to HRM[3], we view the proposed Kernelized-HRM as a framework: it converts non-linear and complicated data into linear, raw-feature data via the neural tangent kernel, and it consists of a clustering module and an invariant prediction module (see the sketch below). In practice, one can replace each module with any model of one's choice to the same effect.
41 | 
42 | Though I hate to mention it, our method has the following shortcomings:
43 | 
44 | * Just like the original HRM[3], the convergence of the frontend module cannot be guaranteed, and we notice that there may be cases where the next iteration does not improve the current results or even hurts them.
45 | * Hyper-parameters for different tasks may be quite different and need to be tuned carefully.
46 | * Whether this algorithm can be extended to more complicated image data, such as PACS and NICO, remains to be seen. (Maybe we will have a try later?)
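For orientation, the frontend/backend loop described above can be sketched as follows. This is a minimal, schematic example (not part of the original scripts): it assumes the helpers defined in `KernelHRM_sim1.py` are importable, and the backend IRM step is only indicated by a comment.

```python
import argparse
import torch
from sklearn.utils.extmath import randomized_svd
from KernelHRM_sim1 import MLP, Cluster, compute_NTF, compute_num_params, generate_data_list

# Hypothetical settings mirroring the defaults of KernelHRM_sim1.py.
args = argparse.Namespace(device=torch.device("cpu"), r_list=[0.8, 0.9, 0.1],
                          num_list=[1000, 1000, 1000], scramble=0, k=15, cluster_num=2)

# Pool the two training environments.
X_list, y_list = generate_data_list(args)
train_X = torch.cat(X_list[:2], dim=0)
train_y = torch.cat(y_list[:2], dim=0)

# Neural tangent features of a randomly initialized MLP, reduced to k dimensions.
model = MLP().to(args.device)
ntf = compute_NTF(model, train_X, compute_num_params(model, False), args)
U, S, VT = randomized_svd(ntf, n_components=args.k, n_iter=10, random_state=42)
features = torch.from_numpy(U * S).float()   # equivalent to U @ diag(S)

# Frontend: split the pooled data into latent environments by clustering
# with per-cluster L1-regularized linear fits.
domains = Cluster(features, train_y, args.cluster_num, args).clustering()

# Backend (not shown): fit Linear_Model on the NTF with an IRM-style penalty
# across the inferred environments, normalize the learned direction theta_inv,
# project it out of `features`, and alternate frontend/backend for
# args.whole_epoch rounds -- see main_KernelHRM for the full loop.
```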
47 | 
48 | 
49 | ### Reference
50 | [1] Jiashuo Liu, Zheyuan Hu, Peng Cui, Bo Li, Zheyan Shen. Kernelized Heterogeneous Risk Minimization. *In NeurIPS 2021*.
51 | 
52 | [2] Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, David Lopez-Paz. Invariant Risk Minimization. *arXiv preprint arXiv:1907.02893*.
53 | 
54 | [3] Jiashuo Liu, Zheyuan Hu, Peng Cui, Bo Li, Zheyan Shen. Heterogeneous Risk Minimization. *In ICML 2021*.
--------------------------------------------------------------------------------
/reproduce.sh:
--------------------------------------------------------------------------------
1 | # For the Classification with Spurious Correlation data
2 | # r = 0.7
3 | python3 KernelHRM_sim1.py --k 15 --scramble 1 --whole_epoch 19 --device 6 --epochs 2000
4 | 
5 | # r = 0.75
6 | python3 KernelHRM_sim1.py --k 15 --scramble 1 --whole_epoch 10 --device 6 --epochs 3000
7 | 
8 | # r = 0.8
9 | python3 KernelHRM_sim1.py --k 10 --scramble 1 --whole_epoch 5 --device 4 --epochs 1000
10 | 
11 | 
12 | 
13 | 
14 | # For the Regression with Selection Bias data
15 | # Scenario 1
16 | # r = 1.5
17 | python3 KernelHRM_sim2.py --r_list 1.5 -1.1 -2.5 --num_list 1000 100 1000 --method IGD \
18 | --epochs 3000 --lr 7e-3 --k 40 --IRM_lam 0.1 --whole_epoch 20 --scramble 0 --IRM_ann 500
19 | 
20 | # r = 1.9
21 | python3 KernelHRM_sim2.py --r_list 1.9 -1.1 -2.5 --num_list 1000 100 1000 --method IGD \
22 | --epochs 3000 --lr 7e-3 --k 30 --IRM_lam 0.1 --whole_epoch 20 --scramble 0 --IRM_ann 500
23 | 
24 | # r = 2.3
25 | python3 KernelHRM_sim2.py --r_list 2.3 -1.1 -2.5 --num_list 1000 100 1000 --method IGD \
26 | --epochs 3000 --lr 7e-3 --k 20 --IRM_lam 0.1 --whole_epoch 20 --scramble 0 --IRM_ann 500
27 | 
28 | 
29 | # Scenario 2
30 | # r = 1.5
31 | python3 KernelHRM_sim2.py --r_list 1.5 -1.1 -2.5 --num_list 1000 100 1000 --method IGD \
32 | --epochs 3000 --lr 7e-3 --k 40 --IRM_lam 0.1 --whole_epoch 20 --scramble 1 --IRM_ann 500
33 | 
34 | # r = 1.9
35 | python3 KernelHRM_sim2.py --r_list 1.9 -1.1 -2.5 --num_list 1000 100 1000 --method IGD \
36 | --epochs 3000 --lr 7e-3 --k 30 --IRM_lam 0.1 --whole_epoch 20 --scramble 1 --IRM_ann 500
37 | 
38 | 
39 | # r = 2.3
40 | python3 KernelHRM_sim2.py --r_list 2.3 -1.1 -2.5 --num_list 1000 100 1000 --method IGD \
41 | --epochs 3000 --lr 7e-3 --k 20 --IRM_lam 0.1 --whole_epoch 20 --scramble 1 --IRM_ann 500
--------------------------------------------------------------------------------
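A note on the `--device` flag used throughout `reproduce.sh`: in both scripts CUDA is selected only when the numeric device id is greater than 0, so `--device 0` falls back to CPU even when a GPU is available; to actually run on a GPU, pass an id of 1 or higher (e.g. `--device 1`).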