# -*- coding: utf-8 -*-
"""Stable_Linear_Model_Learning.

Implementation of the AAAI 2020 paper "Stable Learning via Sample Reweighting".

Example::

    # Load data into a numpy array X with shape n (samples) by p (features)
    ...
    # Calculate new sample weights based on the decorrelation operator
    weights = decorrelation(X)

    # Incorporate the sample weights into downstream tasks,
    # e.g. Weighted Least Squares
    ...
"""
import numpy as np
import pandas as pd
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression


def column_wise_resampling(x, replacement=False, random_state=0, **options):
    """Resample selected columns independently to break the joint distribution of p(x).

    Every selected column is shuffled (or bootstrapped) on its own, so the
    per-feature values are kept while the row-wise pairing between features is
    destroyed. In practice, resampling without replacement (a.k.a. permutation)
    is preferred because it retains all the data points of feature x_j.

    Parameters
    ----------
    x : array-like, shape (n, p)
        Input samples.
    replacement : bool, default False
        If True, draw each selected column with replacement; otherwise apply a
        random permutation to it.
    random_state : int, default 0
        Seed passed to ``np.random.RandomState``.
    **options
        ``sensitive_variables`` : iterable of int, optional
            Column indices to resample. Defaults to all ``p`` columns. Columns
            that are not listed are returned unchanged. This lets a practitioner
            encode priors about which features should be permuted.

    Returns
    -------
    numpy.ndarray of float, shape (n, p)
        The column-wise resampled copy of ``x``. ``x`` itself is not modified.
    """
    rng = np.random.RandomState(random_state)
    x = np.asarray(x)
    n, p = x.shape
    sensitive_variables = options.get('sensitive_variables')
    if sensitive_variables is None:
        sensitive_variables = range(p)
    # Start from a float copy of x, so columns that are NOT resampled keep their
    # original values. Zero-initialising here would silently wipe them.
    x_decorrelation = np.array(x, dtype=float)
    if n == 0:
        # Nothing to resample. This also avoids randint(0, 0), which raises.
        return x_decorrelation
    for i in sensitive_variables:
        var = x[:, i]
        if replacement:
            # Sampling with replacement: n independent uniform row indices.
            x_decorrelation[:, i] = var[rng.randint(0, n, size=n)]
        else:
            # Sampling without replacement, i.e. a permutation.
            x_decorrelation[:, i] = var[rng.permutation(n)]
    return x_decorrelation


def decorrelation(x, solver='adam', hidden_layer_sizes=(5, 5), max_iter=500,
                  random_state=0, **options):
    """Calculate new sample weights by density-ratio estimation.

              q(x)     P(x belongs to q(x) | x)
        w(x) = ---- = ------------------------
              p(x)     P(x belongs to p(x) | x)

    p(x) is the empirical distribution of ``x``. q(x) is its column-wise
    resampled (decorrelated) version. A multi-layer perceptron is trained to
    tell the two apart. Because both sets contain the same number of samples,
    the classifier odds directly estimate the density ratio.

    Parameters
    ----------
    x : array-like, shape (n, p)
        Input samples.
    solver, hidden_layer_sizes, max_iter, random_state
        Forwarded to ``sklearn.neural_network.MLPClassifier``. ``random_state``
        also seeds the column-wise resampling.
    **options
        Forwarded to ``column_wise_resampling``, e.g. ``replacement`` or
        ``sensitive_variables``.

    Returns
    -------
    numpy.ndarray, shape (n, 1)
        Non-negative sample weights, normalised to have mean 1. If the
        classifier gives every weight as 0, uniform weights of 1 are returned.
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    x_decorrelation = column_wise_resampling(x, random_state=random_state,
                                             **options)
    # Rows of x are the source distribution (label 1).
    # Rows of the decorrelated copy are the target distribution (label 0).
    Z = np.vstack([x, x_decorrelation])
    labels = np.concatenate([np.ones(n, dtype=int), np.zeros(n, dtype=int)])
    # Train a multi-layer perceptron to classify source vs. target samples.
    clf = MLPClassifier(solver=solver, hidden_layer_sizes=hidden_layer_sizes,
                        max_iter=max_iter, random_state=random_state)
    clf.fit(Z, labels)
    # Column 1 is P(label == 1 | x), i.e. P(x belongs to p(x) | x).
    proba = clf.predict_proba(x)[:, 1]
    # Guard against proba == 0: 1/proba would be inf, and the normalisation
    # below would then turn every weight into NaN.
    proba = np.clip(proba, np.finfo(float).eps, 1.0)
    weights = (1. / proba) - 1.  # density ratio q(x)/p(x)
    mean_weight = np.mean(weights)
    if mean_weight > 0:
        weights /= mean_weight  # normalise the weights to average 1
    else:
        # Degenerate case (every proba == 1): fall back to uniform weights.
        weights = np.ones(n)
    return np.reshape(weights, [n, 1])