├── Trot
│   ├── __init__.py
│   ├── Distances.py
│   ├── Unregularized_OT.py
│   ├── Generators.py
│   ├── Projections.py
│   ├── Projected_gradient.py
│   ├── Florida_inference.py
│   ├── Evaluation.py
│   └── Tsallis.py
├── .gitignore
├── Data
│   ├── baseline.pkl
│   ├── joints.pkl
│   ├── joints_M.pkl
│   └── joints_gallup.pkl
├── README.md
└── Notebooks
    ├── Tsallis Plots.ipynb
    ├── Ecological Inference.ipynb
    └── .ipynb_checkpoints
        └── Ecological Inference-checkpoint.ipynb

/Trot/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.csv
2 | *.pyc
3 | run_ecological_inference.py
4 | __pycache__/
5 | 
--------------------------------------------------------------------------------
/Data/baseline.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/baseline.pkl
--------------------------------------------------------------------------------
/Data/joints.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/joints.pkl
--------------------------------------------------------------------------------
/Data/joints_M.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/joints_M.pkl
--------------------------------------------------------------------------------
/Data/joints_gallup.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/joints_gallup.pkl
--------------------------------------------------------------------------------
/Trot/Distances.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | #Gaussian kernel
4 | def rbf(x, y, gamma=10):
5 |     return np.exp(- gamma * np.linalg.norm(x - y))
6 | 
7 | #Dissimilarity measure based on rbf
8 | def dist_1(x, y):
9 |     return 1 - rbf(x, y)
10 | 
11 | #Euclidean metric in the rbf Hilbert space
12 | def dist_2(x, y):
13 |     return np.sqrt(2 - 2 * rbf(x, y))
14 | 
15 | 
16 | #The Frobenius inner product for matrices
17 | def inner_Frobenius(P,Q):
18 |     return np.multiply(P,Q).sum()
--------------------------------------------------------------------------------
/Trot/Unregularized_OT.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Sep 5 16:20:04 2016
4 | 
5 | @author: boris
6 | """
7 | from scipy.optimize import linprog
8 | import numpy as np
9 | 
10 | 
11 | #Function using a linear solver to solve unregularized OT
12 | def Unregularized_OT(M,r,c):
13 | 
14 |     n = M.shape[0]
15 |     m = M.shape[1]
16 | 
17 |     #Create the constraint matrix
18 |     Ac = np.zeros((m,n*m))
19 |     for i in range(m):
20 |         for j in range(n):
21 |             Ac[i,i+j*m] = 1
22 | 
23 |     Ar = np.zeros((n,n*m))
24 |     for i in range(n):
25 |         for j in range(m):
26 |             Ar[i,m*i+j] = 1
27 | 
28 |     res = linprog(M.flatten(), A_eq = np.vstack((Ar,Ac)), b_eq=np.hstack((r,c)))
29 | 
30 |     return res.x.reshape((n,m))
31 | 
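The constraint matrix built above stacks the row-sum (`Ar`) and column-sum (`Ac`) marginal constraints of the transport polytope, and `linprog` minimizes the Frobenius inner product of the cost matrix with the plan over that polytope. As a quick sanity check of the formulation, here is a small editorial sketch (not part of the repository) on a 2x2 problem whose solution is obvious: with zero cost on the diagonal, all mass stays in place. It assumes the snippet is run from the repository root, with `Trot/` added to the path the same way the notebooks do.

```python
import sys
import numpy as np

sys.path.append('.')        # repo root, so the Trot package is importable
sys.path.append('./Trot')

from Trot.Unregularized_OT import Unregularized_OT

M = np.array([[0., 1.],
              [1., 0.]])    # moving mass off the diagonal costs 1
r = np.array([0.3, 0.7])    # source marginal (row sums of the plan)
c = np.array([0.3, 0.7])    # target marginal (column sums of the plan)

P = Unregularized_OT(M, r, c)
print(P)                    # approximately [[0.3, 0.0], [0.0, 0.7]]
```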
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TROT
2 | 
3 | This is a Python implementation of Tsallis Regularized Optimal Transport (TROT) for ecological inference, following
4 | 
5 | > Boris Muzellec, Richard Nock, Giorgio Patrini, Frank Nielsen. Tsallis Regularized Optimal Transport and Ecological Inference. [arXiv:1609.04495](https://arxiv.org/pdf/1609.04495v1.pdf)
6 | 
7 | It contains both scripts implementing algorithms for solving TROT and notebooks that reproduce the ecological inference pipeline from the article.
8 | 
9 | 
10 | # Dependencies
11 | 
12 | ```
13 | numpy, scipy, pandas, matplotlib
14 | ```
15 | 
16 | # Usage
17 | 
18 | To run the Ecological Inference notebook, you will first want to download the Florida dataset. It may be accessed [here](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/SVY5VF).
19 | 
20 | It should then be stored in the root folder of the repo as `Fl_Data.csv` (the notebooks read it from `../Fl_Data.csv`).
21 | 
22 | You can then run `Notebooks/Ecological Inference.ipynb` for a reproduction of the article's ecological inference pipeline, and `Notebooks/Tsallis Plots.ipynb` for a visualization of the impact of the parameters $q$ and $\lambda$.
23 | 
24 | The code under `Trot/` contains the basics for building a TROT-based application.
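For a quick smoke test of the solver itself, independent of the Florida data, the helpers under `Trot/` are enough. The sketch below is an editorial example rather than part of the original repository: it assumes it is saved as a script in the repository root, and the two `sys.path` lines mirror what the notebooks do (the modules in `Trot/` import each other without the package prefix). The signatures `TROT(q, M, r, c, l, precision)`, `rand_marginal(n)` and `euc_costs(n, scale)` are the ones defined in `Trot/Tsallis.py` and `Trot/Generators.py`.

```python
import sys
import numpy as np

sys.path.append('.')        # repo root: makes `Trot` importable (Trot/__init__.py is present)
sys.path.append('./Trot')   # lets intra-module imports such as `from Projections import ...` resolve

from Trot.Tsallis import TROT
from Trot.Generators import rand_marginal, euc_costs

n = 50
r = rand_marginal(n)        # random source marginal on the open unit simplex
c = rand_marginal(n)        # random target marginal
M = euc_costs(n, 1)         # squared Euclidean costs on a regular 1-D grid

# q = 1 recovers classical entropic OT (solved by Sinkhorn); other values of q change
# the shape of the Tsallis regularizer -- see the Tsallis Plots notebook.
P = TROT(0.8, M, r, c, 1.0, 1e-2)

print(P.shape)                           # (50, 50) transport plan
print(np.abs(P.sum(axis=1) - r).max())   # row sums approximately match r
print(np.abs(P.sum(axis=0) - c).max())   # column sums approximately match c
```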
projection" 38 | 39 | p = np.sum(A,axis = 1) 40 | q = np.sum(A,axis = 0) 41 | 42 | H = np.matrix(np.full((n,n),1)) 43 | 44 | count = 0 45 | 46 | while not (check(p,q,r,c,precision)) and count <=4000: 47 | 48 | #Projection on the transportation constraints 49 | A = A + (np.tile(r,(n,1)).transpose() +np.tile(c,(n,1)))/n - (np.tile(p,(n,1)).transpose() + np.tile(q,(n,1)))/n + (A.sum() - 1)* H/(n*n) 50 | 51 | #Projection on the positivity constraint 52 | A = np.maximum(A,0) 53 | 54 | p = np.sum(A,axis = 1) 55 | q = np.sum(A,axis = 0) 56 | 57 | count = count +1 58 | 59 | if count >= 4000: 60 | print('Unable to perform Euclidean projection') 61 | return A 62 | 63 | 64 | #Returns true iff p and q approximate respectively r and c to a 'prec' ratio in infinite norm 65 | def check(p,q,r,c,prec): 66 | if (np.linalg.norm(np.divide(p,r)-1,np.inf)>prec) or (np.linalg.norm(np.divide(q,c)-1,np.inf)>prec): 67 | return False 68 | else: return True 69 | 70 | 71 | -------------------------------------------------------------------------------- /Trot/Projected_gradient.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 7 16:29:32 2016 4 | 5 | @author: boris 6 | """ 7 | 8 | from copy import deepcopy 9 | import numpy as np 10 | import math 11 | import Projections as proj 12 | from Tsallis import q_obj, q_grad 13 | 14 | #Euclidean projected gradient descent 15 | def euc_proj_descent(q,M,r,c,l ,precision,T , rate = None, rate_type = None): 16 | 17 | if (q<=1): print("Warning: Projected gradient methods only work when q>1") 18 | 19 | ind_map = np.asarray(np.matrix(r).transpose().dot(np.matrix(c))) 20 | P = deepcopy(ind_map) 21 | 22 | new_score = q_obj(q,P,M,l) 23 | scores = [] 24 | scores.append(new_score) 25 | 26 | best_score = new_score 27 | best_P = P 28 | 29 | count = 1 30 | 31 | while count<=T: 32 | 33 | G = q_grad(q,P,M,l) 34 | 35 | if rate is None: 36 | P = proj.euc_projection(P - (math.sqrt(2*math.sqrt(2)/T)/np.linalg.norm(G))*G,r,c,precision) #Absolute horizon 37 | #P = proj.euc_projection(P - (math.sqrt(2./count)/np.linalg.norm(G))*G,r,c,precision) #Rolling horizon 38 | elif rate_type is "constant": 39 | P = proj.euc_projection(P - rate*G,r,c,precision) 40 | elif rate_type is "constant_length": 41 | P = proj.euc_projection(P - rate*np.linalg.norm(G)*G,r,c,precision) 42 | elif rate_type is "diminishing": 43 | P = proj.euc_projection(P - rate/math.sqrt(count)*G,r,c,precision) 44 | elif rate_type is "square_summable": 45 | P = proj.euc_projection(P - rate/count*G,r,c,precision) 46 | 47 | #Update score list 48 | new_score = q_obj(q,P,M,l) 49 | scores.append(new_score) 50 | 51 | #Keep track of the best solution so far 52 | if (new_score < best_score): 53 | best_score = new_score 54 | best_P = P 55 | 56 | count+=1 57 | return best_P,scores 58 | 59 | 60 | #Nesterov's accelerated gradient 61 | def Nesterov_grad(q,M,r,c,l ,precision,T): 62 | 63 | if (q<2): print("Warning: Nesterov's accelerated gradient only works when q>2") 64 | 65 | ind_map = np.matrix(r).transpose().dot(np.matrix(c)) 66 | P = deepcopy(ind_map) 67 | 68 | #Estimation of the gradient Lipschitz constant 69 | L = q/(l*(q-1)*(q-1)) 70 | 71 | 72 | new_score = q_obj(q,P,M,l) 73 | scores = [] 74 | scores.append(new_score) 75 | 76 | best_score = new_score 77 | best_Y = P 78 | 79 | #Negative cumulative weighted gradient sum 80 | grad_sum = np.zeros(P.shape) 81 | count = 1 82 | 83 | while count<=T: 84 | 85 | G =q_grad(q,P,M,l) 86 | grad_sum-=count/2*G 87 | 88 | Y = 
88 |         Y = proj.euc_projection(P-G/L,r,c,precision)
89 |         Z = proj.Sinkhorn(np.multiply(ind_map,np.exp(grad_sum/L)),r,c,precision)
90 |         P = (2*Z + (count)*Y)/(count+2)
91 | 
92 |         #Update score list
93 |         new_score = q_obj(q,Y,M,l)
94 |         scores.append(new_score)
95 | 
96 |         #Keep track of the best solution so far
97 |         if (new_score < best_score):
98 |             best_score = new_score
99 |             best_Y = Y
100 | 
101 |         count+=1
102 | 
103 |     return best_Y, scores
104 | 
105 | 
--------------------------------------------------------------------------------
/Trot/Florida_inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Aug 18 13:44:15 2016
4 | 
5 | @author: boris
6 | """
7 | 
8 | # import pandas as pd
9 | import numpy as np
10 | from Tsallis import TROT
11 | from Unregularized_OT import Unregularized_OT
12 | from Evaluation import KL
13 | 
14 | 
15 | # County-level inference routines (useful for cross-validation)
16 | 
17 | def CV_Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, counties,q,l, filename = None):
18 | 
19 |     best_kl = np.Inf
20 |     q_best = q[0]
21 |     l_best = l[0]
22 |     variance_best = 0
23 | 
24 |     if filename is not None:
25 |         file = open('{0}.txt'.format(filename), "w")
26 |         file.write('q, l, kl\n')
27 | 
28 |     for j in range(len(q)):
29 |         for i in range(len(l)):
30 | 
31 |             J_inferred = {}
32 |             for county in counties:
33 | 
34 |                 r = Ethnicity_Marginals[county]
35 |                 c = Party_Marginals[county]
36 | 
37 |                 Infered_Distrib = TROT(q[j],M,r,c,l[i],1E-2)
38 | 
39 |                 J_inferred[county] = Infered_Distrib / Infered_Distrib.sum()
40 | 
41 |             kl, std = KL(J, J_inferred, counties)
42 |             print('q: %.2f, lambda: %.4f, KL: %.4g, STD: %.4g' % (q[j], l[i], kl, std))
43 |             if filename is not None:
44 |                 file.write('%.2f, %.4f, %.4g\n' % (q[j], l[i], kl))
45 | 
46 |             if kl < best_kl:
47 |                 q_best = q[j]
48 |                 l_best = l[i]
49 |                 variance_best = std
50 |                 best_kl = kl
51 | 
52 | 
53 |     print('Best score: %.4g, Best q: %.2f, Best lambda: %.4f\t STD: %.4g\n' % (best_kl, q_best, l_best, variance_best))
54 | 
55 |     if filename is not None:
56 |         file.close()
57 | 
58 |     return best_kl, q_best, l_best
59 | 
60 | 
61 | def Unreg_Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, counties):
62 | 
63 |     J_inferred = {}
64 |     for county in counties:
65 | 
66 |         r = Ethnicity_Marginals[county]
67 |         c = Party_Marginals[county]
68 | 
69 |         J_inferred[county] = Unregularized_OT(M, r, c)
70 |         J_inferred[county] /= J_inferred[county].sum()
71 | 
72 | 
73 |     return J_inferred
74 | 
75 | 
76 | def CV_Cross_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, Ref_county, CV_counties,q,l):
77 | 
78 | 
79 |     best_kl = np.Inf
80 |     q_best = q[0]
81 |     l_best = l[0]
82 |     variance_best = 0
83 | 
84 |     for j in range(len(q)):
85 |         for i in range(len(l)):
86 | 
87 |             print('q= {0}, l= {1}\n'.format(q[j],l[i]))
88 | 
89 |             J_inferred = {}
90 |             for county in CV_counties:
91 | 
92 |                 r = Ethnicity_Marginals[county]
93 |                 c = Party_Marginals[county]
94 | 
95 |                 Infered_Distrib = TROT(q[j],M,r,c,l[i],1E-2)
96 | 
97 |                 J_inferred[county] = Infered_Distrib / Infered_Distrib.sum()
98 | 
99 |             kl, std = KL(J, J_inferred, CV_counties)
100 |             #print('KL: {0}\t Variance: {1}\n'.format(kl, std))
101 |             print('KL: %g\t Variance: %g\n' % (kl, std))
102 |             if kl < best_kl:
103 |                 q_best = q[j]
104 |                 l_best = l[i]
105 |                 variance_best = std
106 |                 best_kl = kl
107 | 
108 |     print('Best score: {0}, Best q: {1}, Best lambda: {2}, Variance: {3}\n'.format(best_kl, q_best, l_best, variance_best))
109 | 
110 |     return best_kl, q_best, l_best
111 | 
112 | 
113 | def Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, counties,q,l):
114 | 
115 |     J_inferred = {}
116 |     for county in counties:
117 | 
118 |         r = Ethnicity_Marginals[county]
119 |         c = Party_Marginals[county]
120 | 
121 |         Infered_Distrib = TROT(q,M,r,c,l,1E-2)
122 | 
123 |         J_inferred[county] = Infered_Distrib / Infered_Distrib.sum()
124 | 
125 |     return J_inferred
126 | 
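The inference helpers above all share one calling convention: the marginals and the ground-truth joints are dictionaries keyed by county, and the cost matrix has one row per ethnicity and one column per party. The toy example below is an editorial sketch, not part of the original pipeline: the county labels, the random cost matrix and the uniform marginals are made up purely to illustrate that convention, and `Voters_By_County` is passed as a placeholder because `Local_Inference` never reads it.

```python
import sys
import numpy as np

sys.path.append('.')
sys.path.append('./Trot')

from Trot.Florida_inference import Local_Inference
from Trot.Evaluation import KL

counties = ['A', 'B']                    # hypothetical county labels
np.random.seed(0)
M = np.random.rand(6, 3)                 # 6 ethnicity rows x 3 party columns, as in the notebook

# Dictionaries keyed by county, mirroring what the Ecological Inference notebook builds.
Ethnicity_Marginals = {c: np.full(6, 1 / 6) for c in counties}
Party_Marginals = {c: np.full(3, 1 / 3) for c in counties}
J_true = {c: np.full((6, 3), 1 / 18) for c in counties}   # toy ground-truth joints

J_inferred = Local_Inference(None, M, J_true, Ethnicity_Marginals, Party_Marginals,
                             counties, 0.9, 1.0)          # q = 0.9, lambda = 1.0
kl_mean, kl_std = KL(J_true, J_inferred, counties)
print(kl_mean, kl_std)
```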
--------------------------------------------------------------------------------
/Trot/Evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | 
4 | 
5 | def MSE(J_true, J_inferred, counties, save_to_file=None):
6 |     """Compute the per-county MSE (averaged over cells) and its STD across counties.
7 |     Input: dictionaries of joint probabilities with counties as keys."""
8 | 
9 |     assert len(counties) == len(J_inferred.keys())
10 | 
11 |     l = []
12 | 
13 |     for c in counties:
14 |         diff = np.array(J_true[c] - J_inferred[c]).flatten()
15 |         l.append(diff)
16 | 
17 |     assert len(counties) == len(J_inferred.keys())
18 | 
19 |     mse_list = []
20 | 
21 |     for c in counties:
22 | 
23 |         mse_county = np.array(J_true[c] - J_inferred[c]) ** 2
24 |         mse_county = mse_county.mean()
25 | 
26 |         mse_list.append(mse_county)
27 | 
28 |     if save_to_file:
29 |         f = open(save_to_file, 'wb')
30 |         pickle.dump(mse_list, f)
31 | 
32 |     mse = np.asarray(mse_list)
33 |     return mse.mean(), mse.std()
34 | 
35 | 
36 | def KL(J_true, J_inferred, counties, save_to_file=False, compute_abs_err=False):
37 |     """Compute the per-county KL divergence and its STD across counties.
38 |     Input: dictionaries of joint probabilities with counties as keys."""
39 | 
40 |     EPS = 1e-18  # avoid KL +inf
41 | 
42 |     assert len(counties) == len(J_inferred.keys())
43 | 
44 |     kl_list = []
45 |     if compute_abs_err:
46 |         abs_list = []
47 | 
48 |     for c in counties:
49 | 
50 |         J_inferred[c] /= J_inferred[c].sum()
51 | 
52 |         kl_county = 0.0
53 |         for i in np.arange(J_true[c].shape[0]):
54 |             for j in np.arange(J_true[c].shape[1]):
55 |                 kl_county += J_inferred[c][i, j] * \
56 |                     np.log(J_inferred[c][i, j] / np.maximum(J_true[c][i, j], EPS))
57 | 
58 |         kl_list.append(kl_county)
59 | 
60 |         if compute_abs_err:
61 |             abs_list.append(np.abs(J_inferred[c] - J_true[c]).mean())
62 | 
63 |     if save_to_file:
64 |         f = open(save_to_file + '.pkl', 'wb')
65 |         pickle.dump(kl_list, f)
66 |         f.close()
67 | 
68 |         if compute_abs_err:
69 |             f = open(save_to_file + '_abs.pkl', 'wb')
70 |             pickle.dump(abs_list, f)
71 |             f.close()
72 | 
73 |     if compute_abs_err:
74 |         err = np.asarray(abs_list)
75 |         print('Absolute error', err.mean(), ' + ', err.std())
76 | 
77 |     kl = np.asarray(kl_list)
78 |     return kl.mean(), kl.std()
79 | 
80 | 
81 | def National_Average_Baseline(Data, counties):
82 |     """This baseline predicts the national average contingency table for every
83 |     county."""
84 | 
85 |     National_Average = np.zeros((6, 3))
86 |     Total_Num_Voters = Data.shape[0]
87 | 
88 |     National_Average[0,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.WHI']==1)].shape[0]
89 |     National_Average[0,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.WHI']==1)].shape[0]
90 |     National_Average[0,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.WHI']==1)].shape[0]
91 | 
92 |     National_Average[1,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.BLA']==1)].shape[0]
93 |     National_Average[1,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.BLA']==1)].shape[0]
94 |     National_Average[1,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.BLA']==1)].shape[0]
95 | 
96 |     National_Average[2,0] = 
Data.loc[(Data['Other'] ==1) & (Data['SR.HIS']==1)].shape[0] 97 | National_Average[2,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.HIS']==1)].shape[0] 98 | National_Average[2,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.HIS']==1)].shape[0] 99 | 100 | National_Average[3,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.ASI']==1)].shape[0] 101 | National_Average[3,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.ASI']==1)].shape[0] 102 | National_Average[3,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.ASI']==1)].shape[0] 103 | 104 | National_Average[4,0] = Data.loc[(Data['Other'] ==1) &(Data['SR.NAT']==1)].shape[0] 105 | National_Average[4,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.NAT']==1)].shape[0] 106 | National_Average[4,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.NAT']==1)].shape[0] 107 | 108 | National_Average[5,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.OTH']==1)].shape[0] 109 | National_Average[5,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.OTH']==1)].shape[0] 110 | National_Average[5,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.OTH']==1)].shape[0] 111 | 112 | National_Average = National_Average / Total_Num_Voters 113 | 114 | # replicate by CV_counties 115 | replica = {} 116 | for c in counties: 117 | replica[c] = National_Average 118 | 119 | return replica 120 | -------------------------------------------------------------------------------- /Trot/Tsallis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jun 15 11:38:13 2016 4 | 5 | @author: boris 6 | """ 7 | 8 | 9 | import numpy as np 10 | import math 11 | from copy import deepcopy 12 | from Projections import check, Sinkhorn 13 | from Distances import inner_Frobenius 14 | 15 | 16 | #Tsallis generalization of the exponential 17 | def q_exp (q,u): 18 | return np.power(1+(1-q)*u,1/(1-q)) 19 | 20 | #Tsallis generalization of the logarithm 21 | def q_log (q,u): 22 | return (np.power(u,1-q) - 1)/(1-q) 23 | 24 | #Objective function of the TROT minimization problem 25 | def q_obj(q,P,M,l): 26 | P_q = np.power(P,q) 27 | return inner_Frobenius(P,M) - np.sum(P_q - P)/((1-q)*l) 28 | 29 | #Gradient of the objective function in q-space 30 | def q_grad(q,P,M,l): 31 | return M + (q*np.power(P,q-1) - 1)/(l*(1-q)) 32 | 33 | def sign(x): 34 | if (x>0): 35 | return 1 36 | else: 37 | return -1 38 | 39 | #A wrapper for the three TROT-solving algorithms, depending on the value of q 40 | #q must be positive 41 | def TROT(q,M,r,c,l,precision): 42 | 43 | assert (q >= 0),"Invalid parameter q: q must be strictly positive" 44 | 45 | if np.isclose(q,1): 46 | #Add multipliers to rescale A and avoid dividing by zero 47 | A = deepcopy(l*M) 48 | A = A-np.amin(A,axis = 0) 49 | A = (A.T-np.amin(A,axis = 1)).T 50 | 51 | return Sinkhorn(np.exp(-A),r,c,precision) 52 | 53 | elif q<1: 54 | return second_order_sinkhorn(q,M,r,c,l,precision)[0] 55 | 56 | else: 57 | return KL_proj_descent(q,M,r,c,l,precision, 50, rate = 1, rate_type = "square_summable")[0] 58 | 59 | 60 | #A TROT optimizer using first order approximations (less efficient than second order) 61 | def first_order_sinkhorn(q,M,r,c,l,precision): 62 | 63 | q1 = q_exp(q,-1) 64 | 65 | P = q1/q_exp(q,l*M) 66 | A = deepcopy(l*M) 67 | 68 | p = P.sum(axis = 1) 69 | s = P.sum(axis = 0) 70 | 71 | count = 0 72 | alpha = np.zeros(M.shape[0]) 73 | beta = np.zeros(M.shape[1]) 74 | 75 | while not (check(p,s,r,c,precision)) and count <= 1000: 76 | 77 | alpha = 
np.divide(p-r,np.sum(np.divide(P,(1+(1-q)*A)),axis = 1)) 78 | A = (A.transpose() + alpha).transpose() 79 | 80 | P = q1/q_exp(q,A) 81 | s = P.sum(axis = 0) 82 | 83 | beta = np.divide(s-c,np.sum(np.divide(P,(1+(1-q)*A)),axis = 0)) 84 | A += beta 85 | 86 | 87 | P = q1/q_exp(q,A) 88 | p = P.sum(axis = 1) 89 | s = P.sum(axis = 0) 90 | 91 | 92 | 93 | count +=1 94 | 95 | return P, count, q_obj(q,P,M,l) 96 | 97 | 98 | #The TROT optimizing algorithm from the paper -- to use when q is in (0,1) 99 | def second_order_sinkhorn(q,M,r,c,l,precision): 100 | 101 | n = M.shape[0] 102 | m = M.shape[1] 103 | q1 = q_exp(q,-1) 104 | 105 | A = deepcopy(l*M) 106 | 107 | #Add multipliers to make sure that A is not too large 108 | #We subtract the minimum on columns and then on rows 109 | #This does not change the solution of TROT 110 | A = A-np.amin(A,axis = 0) 111 | A = (A.T-np.amin(A,axis = 1)).T 112 | 113 | P = q1/q_exp(q,A) 114 | 115 | p = P.sum(axis = 1) 116 | s = P.sum(axis = 0) 117 | 118 | count = 0 119 | alpha = np.zeros(M.shape[0]) 120 | beta = np.zeros(M.shape[1]) 121 | 122 | while not (check(p,s,r,c,precision)) and count <= 1000: 123 | 124 | 125 | A_q2 = np.divide(P,(1+(1-q)*A)) 126 | a = (1-q/2)*(np.sum(np.divide(A_q2,(1+(1-q)*A)),axis = 1)) 127 | b = - np.sum(A_q2,axis = 1) 128 | d = p-r 129 | delta = np.multiply(b,b) - 4*np.multiply(a,d) 130 | 131 | 132 | for i in range(n): 133 | if (delta[i] >=0 and d[i]<0 and a[i]>0): 134 | alpha[i] = - (b[i] + math.sqrt(delta[i]))/(2*a[i]) 135 | elif (b[i] != 0): 136 | alpha[i] = 2*d[i]/(-b[i]) 137 | else: alpha[i] = 0 138 | 139 | #Check that the multiplier is not too large 140 | if abs(alpha[i]) > 1/((2-2*q)*max(1+(1-q)*A[i,:])): 141 | alpha[i] = sign(d[i])*1/((2-2*q)*max(1+(1-q)*A[i,:])) 142 | 143 | A = (A.transpose() + alpha).transpose() 144 | 145 | 146 | P = q1/q_exp(q,A) 147 | 148 | s = P.sum(axis = 0) 149 | 150 | A_q2 = np.divide(P,(1+(1-q)*A)) 151 | a = (1-q/2)*(np.sum(np.divide(A_q2,(1+(1-q)*A)),axis = 0)) 152 | b = - np.sum(A_q2,axis = 0) 153 | d = s - c 154 | delta = np.multiply(b,b) - 4*np.multiply(a,d) 155 | 156 | 157 | for i in range(m): 158 | if (delta[i] >=0 and d[i]<0 and a[i]>0): 159 | beta[i] = - (b[i] + math.sqrt(delta[i]))/(2*a[i]) 160 | elif (b[i] != 0): 161 | beta[i] = 2*d[i]/(-b[i]) 162 | else: beta[i] = 0 163 | 164 | #Check that the multiplier is not too large 165 | if abs(beta[i]) > 1/((2-2*q)*max(1+(1-q)*A[:,i])): 166 | beta[i] = sign(d[i])*1/((2-2*q)*max(1+(1-q)*A[:,i])) 167 | A += beta 168 | 169 | 170 | P = q1/q_exp(q,A) 171 | 172 | p = P.sum(axis = 1) 173 | s = P.sum(axis = 0) 174 | 175 | count +=1 176 | 177 | #print(P.sum()) 178 | 179 | return P, count, q_obj(q,P,M,l) 180 | 181 | 182 | #Kullback-Leibler projected gradient method, to use when q > 1 183 | def KL_proj_descent(q,M,r,c,l ,precision,T , rate = None, rate_type = None): 184 | 185 | if (q<=1): print("Warning: Projected gradient methods only work when q>1") 186 | 187 | omega = math.sqrt(2*np.log(M.shape[0])) 188 | ind_map = np.asarray(np.matrix(r).transpose().dot(np.matrix(c))) 189 | 190 | P = deepcopy(ind_map) 191 | 192 | new_score = q_obj(q,P,M,l) 193 | scores = [] 194 | scores.append(new_score) 195 | 196 | best_score = new_score 197 | best_P= P 198 | 199 | count = 1 200 | 201 | while count<=T: 202 | 203 | G = q_grad(q,P,M,l) 204 | 205 | 206 | if rate is None: 207 | tmp = np.exp(G*(-omega/(math.sqrt(T)*np.linalg.norm(G,np.inf)))) #Absolute horizon 208 | #tmp = np.exp(G*(-omega/(math.sqrt(count)*np.linalg.norm(G,np.inf)))) #Rolling horizon 209 | elif rate_type is "constant": 
210 | tmp = np.exp(G*(-rate)) 211 | elif rate_type is "constant_length": 212 | tmp = np.exp(G*(-rate*np.linalg.norm(G,np.inf))) 213 | elif rate_type is "diminishing": 214 | tmp = np.exp(G*(-rate/math.sqrt(count))) 215 | elif rate_type is "square_summable": 216 | tmp = np.exp(G*(-rate/count)) 217 | 218 | 219 | P = np.multiply(P,tmp) 220 | P = Sinkhorn(P,r,c,precision) 221 | 222 | #Update score list 223 | new_score = q_obj(q,P,M,l) 224 | scores.append(new_score) 225 | 226 | #Keep track of the best solution so far 227 | if (new_score < best_score): 228 | best_score = new_score 229 | best_P = P 230 | 231 | 232 | count+=1 233 | 234 | return best_P, scores -------------------------------------------------------------------------------- /Notebooks/Tsallis Plots.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import matplotlib.gridspec as gridspec\n", 14 | "from matplotlib.patches import Rectangle\n", 15 | "from scipy.stats import poisson\n", 16 | "\n", 17 | "import sys\n", 18 | "sys.path.append('..')\n", 19 | "sys.path.append('../Trot')\n", 20 | "sys.path.append('../Data')\n", 21 | "\n", 22 | "from Trot.Tsallis import TROT, q_log\n", 23 | "from Trot.Generators import euc_costs " 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Influence of parameters $q$ and $\\lambda$\n", 31 | "\n", 32 | "We reproduce Figure 2. from (cite our paper)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "First, choose values for $q$ and $\\lambda$ to test, sample size n, modes $\\mu$ and mixing coefficients $t$." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "collapsed": false 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "n = 50\n", 51 | "q = np.arange(0.5,2.5,0.5)\n", 52 | "l = [0.01,0.1,1,5]\n", 53 | "mu1 = [10,30]\n", 54 | "mu2 = [5,20,35]\n", 55 | "t1 = [0.5,0.5]\n", 56 | "t2 = [0.2,0.8,0.2]" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "Run TROT on selected parameters and marginals and produce the figure." 
64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "x = range(n)\n", 75 | " \n", 76 | "r_tmp = [] \n", 77 | "for mode in mu1:\n", 78 | " r_tmp.append(poisson.pmf(x,mode))\n", 79 | " \n", 80 | "c_tmp = [] \n", 81 | "for mode in mu2:\n", 82 | " c_tmp.append(poisson.pmf(x,mode))\n", 83 | " \n", 84 | "r = np.dot(t1,r_tmp)\n", 85 | "r = r/r.sum()\n", 86 | " \n", 87 | "c = np.dot(t2,c_tmp)\n", 88 | "c = c/c.sum()\n", 89 | " \n", 90 | " \n", 91 | "M = euc_costs(n,n)\n", 92 | " \n", 93 | "P = []\n", 94 | " \n", 95 | "nq = len(q)\n", 96 | "nl = len(l)\n", 97 | " \n", 98 | " \n", 99 | "for j in range(nq):\n", 100 | " for i in range(nl):\n", 101 | " P_tmp = TROT(q[j],M,r,c,l[i],1E-2)\n", 102 | " P.append(P_tmp)\n", 103 | " \n", 104 | " \n", 105 | "fig = plt.figure(figsize=(8, 8))\n", 106 | " \n", 107 | "outer_grid = gridspec.GridSpec(2, 2, width_ratios=[1,5], height_ratios=[1,5])\n", 108 | "outer_grid.update(wspace=0.01, hspace=0.01)\n", 109 | "# gridspec inside gridspec\n", 110 | "outer_joint = gridspec.GridSpecFromSubplotSpec(nq,nl, subplot_spec=outer_grid[1,1],wspace=0.02, hspace=0.02)\n", 111 | "outer_row_marg = gridspec.GridSpecFromSubplotSpec(nq,1, subplot_spec=outer_grid[1,0],wspace=0.02, hspace=0.02)\n", 112 | "outer_col_marg = gridspec.GridSpecFromSubplotSpec(1,nl, subplot_spec=outer_grid[0,1],wspace=0.02, hspace=0.02)\n", 113 | " \n", 114 | " \n", 115 | "for b in range(nl):\n", 116 | " for a in range (nq):\n", 117 | " ax = plt.Subplot(fig, outer_joint[a,b])\n", 118 | " ax.imshow(P[nl*a + b], origin='upper', interpolation = None, aspect = 'auto', cmap = 'Greys')\n", 119 | " rect = Rectangle((0, 0), n-1, n-1, fc='none', ec='black') \n", 120 | " rect.set_width(0.8)\n", 121 | " rect.set_bounds(0,0,n-1,n-1)\n", 122 | " ax.add_patch(rect)\n", 123 | " ax.set_xticks([])\n", 124 | " ax.set_yticks([])\n", 125 | " fig.add_subplot(ax)\n", 126 | " ax.set_axis_bgcolor('white')\n", 127 | " \n", 128 | "for i in range(nq):\n", 129 | " ax_row = plt.Subplot(fig,outer_row_marg[i], sharey = ax)\n", 130 | " ax_row.plot(1-r, x)\n", 131 | " fig.add_subplot(ax_row)\n", 132 | " \n", 133 | " ax_row.axes.get_xaxis().set_visible(False)\n", 134 | " ax_row.axes.get_yaxis().set_visible(False)\n", 135 | " bottom, height = .25, .5\n", 136 | " top = bottom + height\n", 137 | " ax_row.text(-0.05, 0.5*(bottom+top), 'q = %.2f' % q[i], horizontalalignment='right', verticalalignment='center', rotation='vertical',transform=ax_row.transAxes, fontsize='medium')\n", 138 | " \n", 139 | " ax_row.set_axis_bgcolor('white')\n", 140 | " \n", 141 | "for j in range(nl):\n", 142 | " ax_col = plt.Subplot(fig,outer_col_marg[j], sharex = ax)\n", 143 | " ax_col.plot(x,c)\n", 144 | " fig.add_subplot(ax_col) \n", 145 | " bottom, height = .25, .5\n", 146 | " ax_col.axes.get_xaxis().set_visible(False)\n", 147 | " ax_col.axes.get_yaxis().set_visible(False)\n", 148 | " ax_col.set_title(r'$\\lambda$'+' = {0}'.format(l[j]),fontsize='medium')\n", 149 | " ax_col.set_axis_bgcolor('white')\n", 150 | " \n", 151 | "fig.show()" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "# Tsallis entropies\n", 159 | "\n", 160 | "We plot the Tsallis entropies of a Bernoulli random variable with parameter $p$ for different values of $q$." 
161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "collapsed": false 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "def Tsallis_Bernoulli(q,p):\n", 172 | " if np.isclose(q,1):\n", 173 | " return p*np.log(1/p) + (1-p)*np.log(1/(1-p))\n", 174 | " else:\n", 175 | " return p*q_log(q,1/p) + (1-p)*q_log(q,1/(1-p))\n", 176 | "\n", 177 | "x = np.arange(1E-6,1,0.001)\n", 178 | "\n", 179 | "plt.figure()\n", 180 | "\n", 181 | "plt.plot(x,Tsallis_Bernoulli(0.1,x), label = 'q = 0.1')\n", 182 | "plt.plot(x,Tsallis_Bernoulli(0.5,x), label = 'q = 0.5')\n", 183 | "plt.plot(x,Tsallis_Bernoulli(1,x), label = 'q = 1')\n", 184 | "plt.plot(x,Tsallis_Bernoulli(1.5,x), label = 'q = 1.5')\n", 185 | "plt.plot(x,Tsallis_Bernoulli(5,x), label = 'q = 5')\n", 186 | "\n", 187 | "plt.legend()\n", 188 | "\n", 189 | "plt.show()" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "Notice how the choice of $q$ impacts the notion of randomness. When $q$ goes to 0, the entropy converges to the constant function $x=1$, except at the borders $p=0$ and $p=1$. Low values of q are anti-sparsity-inducing, but since they are rather flat around $p=0.5$ they do not require the distribution to be exactly uniform to have high-entropy. When $q$ goes to infinity, the entropy converges to the zero fonction." 197 | ] 198 | } 199 | ], 200 | "metadata": { 201 | "kernelspec": { 202 | "display_name": "Python 3", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.5.1" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 0 221 | } 222 | -------------------------------------------------------------------------------- /Notebooks/Ecological Inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Ecological Inference through Tsallis Regularized Optimal Transport (TROT)\n", 10 | "This notebook presents the pipeline used in our paper to perform ecological inference on the Florida dataset.\n", 11 | "\n", 12 | "You will first want to download the Florida dataset - see the README. 
" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": { 19 | "collapsed": false, 20 | "scrolled": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import pandas as pd\n", 25 | "import numpy as np\n", 26 | "import pickle\n", 27 | "from matplotlib import pyplot as plt\n", 28 | "from matplotlib.pylab import savefig\n", 29 | "\n", 30 | "import sys\n", 31 | "sys.path.append('..')\n", 32 | "sys.path.append('../Trot')\n", 33 | "sys.path.append('../Data')\n", 34 | "\n", 35 | "from Trot import Distances as dist\n", 36 | "from Trot.Evaluation import KL\n", 37 | "from Trot.Florida_inference import CV_Local_Inference, Local_Inference" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Data Loading and Processing" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "collapsed": false 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "FlData = pd.read_csv('../Fl_Data.csv', usecols = ['District', 'County','Voters_Age', 'Voters_Gender', 'PID', 'vote08', \n", 56 | " 'SR.WHI', 'SR.BLA', 'SR.HIS', 'SR.ASI', 'SR.NAT', 'SR.OTH']) \n", 57 | "\n", 58 | "FlData = FlData.dropna()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Change gender values to numerical values" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "collapsed": true 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "FlData['Voters_Gender'] = FlData['Voters_Gender'].map({'M': 1, 'F': 0})" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "Renormalize the age so that it takes values between 0 and 1" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "collapsed": true 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "FlData['Voters_Age'] = ((FlData['Voters_Age'] -\n", 95 | " FlData['Voters_Age'].min()) /\n", 96 | " (FlData['Voters_Age'].max() -\n", 97 | " FlData['Voters_Age'].min()))\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "One-hot party subscriptions (PID)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "collapsed": true 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "#Get one hot encoding of column PID\n", 116 | "one_hot = pd.get_dummies(FlData['PID'])\n", 117 | "# Drop column PID as it is now encoded\n", 118 | "FlData = FlData.drop('PID', axis=1)\n", 119 | "# Join the encoded df\n", 120 | "FlData = FlData.join(one_hot)\n", 121 | "# Rename the new columns\n", 122 | "FlData.rename(columns={0: 'Other', 1: 'Democrat', 2: 'Republican'},\n", 123 | " inplace=True)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "collapsed": false 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "FlData.describe()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "# Compute Marginals and Joint Distributions" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "Create a county dictionnary" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "collapsed": false 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "Voters_By_County = {}\n", 160 | "all_counties = 
FlData.County.unique()\n", 161 | "for county in all_counties:\n", 162 | " Voters_By_County[county] = FlData[FlData['County'] == county]" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "Compute the ground truth joint distribution" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "collapsed": false 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "J = {}\n", 181 | "for county in all_counties:\n", 182 | " J[county] = np.zeros((6, 3))\n", 183 | "\n", 184 | " J[county][0,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 185 | " J[county][0,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 186 | " J[county][0,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 187 | "\n", 188 | " J[county][1,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 189 | " J[county][1,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 190 | " J[county][1,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 191 | "\n", 192 | " J[county][2,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 193 | " J[county][2,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 194 | " J[county][2,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 195 | "\n", 196 | " J[county][3,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 197 | " J[county][3,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 198 | " J[county][3,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 199 | "\n", 200 | " J[county][4,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) &(Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 201 | " J[county][4,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 202 | " J[county][4,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 203 | "\n", 204 | " J[county][5,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 205 | " J[county][5,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 206 | " J[county][5,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 207 | "\n", 208 | " J[county] /= J[county].sum()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": { 215 | "collapsed": 
false 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "print(J[12])" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "Compute the party marginals" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "collapsed": false 234 | }, 235 | "outputs": [], 236 | "source": [ 237 | "Party_Marginals = {}\n", 238 | "parties = ['Other', 'Democrat', 'Republican']\n", 239 | "for county in all_counties:\n", 240 | " Party_Marginals[county] = pd.Series([J[county][:, i].sum()\n", 241 | " for i in np.arange(3)])\n", 242 | " Party_Marginals[county].index = parties" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "Compute the ethnicity marginals" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": { 256 | "collapsed": false 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "Ethnicity_Marginals = {}\n", 261 | "ethnies = ['SR.WHI', 'SR.BLA', 'SR.HIS', 'SR.ASI', 'SR.NAT', 'SR.OTH']\n", 262 | "for county in all_counties:\n", 263 | " Ethnicity_Marginals[county] = pd.Series([J[county][i, :].sum()\n", 264 | " for i in np.arange(6)])\n", 265 | " Ethnicity_Marginals[county].index = ethnies" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "# Compute the cost matrix\n", 273 | "Using only age, gender, and 2008 vote or abstention" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": { 280 | "collapsed": false 281 | }, 282 | "outputs": [], 283 | "source": [ 284 | "features = ['Voters_Age', 'Voters_Gender', 'vote08']\n", 285 | "e_len, p_len = len(ethnies), len(parties)\n", 286 | "M = np.zeros((e_len, p_len))\n", 287 | "for i, e in enumerate(ethnies):\n", 288 | " data_e = FlData[FlData[e] == 1.0]\n", 289 | " average_by_e = data_e[features].mean(axis=0)\n", 290 | " for j, p in enumerate(parties):\n", 291 | " data_p = FlData[FlData[p] == 1.0]\n", 292 | " average_by_p = data_p[features].mean(axis=0)\n", 293 | "\n", 294 | " M[i, j] = np.array(dist.dist_2(average_by_e, average_by_p))" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "# Start the inference" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "Use a specific county or district to select the best parameters" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": { 315 | "collapsed": true 316 | }, 317 | "outputs": [], 318 | "source": [ 319 | "CV_counties = FlData[FlData['District'] == 3].County.unique()\n", 320 | "del FlData" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "Find the best parameters" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": { 334 | "collapsed": false 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "q = np.arange(0.5, 3.1, 0.1)\n", 339 | "l = [0.01, 0.1, 1., 10., 100.] 
\n", 340 | "\n", 341 | "best_score, best_q, best_l = CV_Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals,\n", 342 | " CV_counties,q,l)" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": { 348 | "collapsed": true 349 | }, 350 | "source": [ 351 | "Use selected parameters on the rest of the dataset" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": { 358 | "collapsed": false 359 | }, 360 | "outputs": [], 361 | "source": [ 362 | "J_inferred = Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, all_counties, best_q, best_l)\n", 363 | "kl, std = KL(J, J_inferred, all_counties, save_to_file=False, compute_abs_err=True)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": {}, 369 | "source": [ 370 | "# Plot the results" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": { 377 | "collapsed": false 378 | }, 379 | "outputs": [], 380 | "source": [ 381 | "diag = np.linspace(-0.1, 1.0, 100)\n", 382 | "\n", 383 | "# pickle results\n", 384 | "f = open('../Data/joints_gallup.pkl', 'rb')\n", 385 | "J_true, J = pickle.load(f)\n", 386 | "\n", 387 | "f = open('../Data/baseline.pkl', 'rb')\n", 388 | "J_baseline = pickle.load(f)\n", 389 | "\n", 390 | "j_true, j, j_baseline = [], [], []\n", 391 | "for c in all_counties:\n", 392 | " j_true.append(np.array(J_true[c]).flatten())\n", 393 | " j.append(np.array(J_inferred[c]).flatten())\n", 394 | " j_baseline.append(np.array(J_baseline[c]).flatten())\n", 395 | "\n", 396 | "j_true = np.array(j_true).flatten()\n", 397 | "j = np.array(j).flatten()\n", 398 | "j_baseline = np.array(j_baseline).flatten()\n" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "Plot the correlation between the ground truth for the joint distribution and the infered distribution (the closer to the $x = y$ diagonal axis, the better)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "collapsed": false 413 | }, 414 | "outputs": [], 415 | "source": [ 416 | "plt.figure()\n", 417 | "plt.scatter(j_true, j, alpha=0.5)\n", 418 | "plt.xlabel('Ground truth')\n", 419 | "plt.ylabel('TROT (RBF)')\n", 420 | "plt.plot(diag, diag, 'r--')\n", 421 | "\n", 422 | "#plt.show()" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "Plot the distribution of the error (the more packed around the origin of the $x$-axis, the better)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "collapsed": false 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "plt.figure()\n", 441 | "bins = np.arange(-.3, .6, 0.01)\n", 442 | "plt.hist(j_true - j, bins=bins, alpha=0.5, label='TROT')\n", 443 | "plt.hist(j_true - j_baseline, bins=bins, alpha=0.5, label='Florida-average')\n", 444 | "plt.legend()\n", 445 | "plt.xlabel('Difference between inference and ground truth')\n", 446 | "\n", 447 | "#plt.show()" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": {}, 453 | "source": [ 454 | "# Survey-based ecological inference\n", 455 | "Same pipeline, but using a cost matrix computed thanks to the 2013 Gallup survey. 
(http://www.gallup.com/poll/160373/democrats-racially-diverse-republicans-mostly-white.aspx)\n", 456 | "\n", 457 | "We assume that Gallup's Other = {Native, Other}\n", 458 | "\n", 459 | "The cost matrix M is computed as $1-p_{ij}$, where $p_{ij}$ is the proportion of people registered to party $j$ belonging to group $i$." 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "metadata": { 466 | "collapsed": true 467 | }, 468 | "outputs": [], 469 | "source": [ 470 | "M_sur = np.array([\n", 471 | " [.38, .26, .35],\n", 472 | " [.29, .64, .05],\n", 473 | " [.50, .32, .13],\n", 474 | " [.46, .36, .17],\n", 475 | " [.49, .32, .18],\n", 476 | " [.49, .32, .18]\n", 477 | " ])\n", 478 | "M_sur = (1. - M_sur)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "Once again, find the best parameters" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": { 492 | "collapsed": false 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "best_score, best_q, best_l = CV_Local_Inference(Voters_By_County, M_sur, J, Ethnicity_Marginals, Party_Marginals,\n", 497 | " CV_counties,q,l)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "Using these parameters, run the inference on the rest of the dataset" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "metadata": { 511 | "collapsed": false 512 | }, 513 | "outputs": [], 514 | "source": [ 515 | "J_sur = Local_Inference(Voters_By_County, M_sur, J, Ethnicity_Marginals, Party_Marginals, all_counties, best_q, best_l)\n", 516 | "kl, std = KL(J, J_sur, all_counties, save_to_file=False, compute_abs_err=True)" 517 | ] 518 | }, 519 | { 520 | "cell_type": "markdown", 521 | "metadata": {}, 522 | "source": [ 523 | "Plot correlation with ground truth" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": { 530 | "collapsed": true 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "j_sur = []\n", 535 | "for c in all_counties:\n", 536 | " j_sur.append(np.array(J_sur[c]).flatten())\n", 537 | "\n", 538 | "j_sur = np.array(j_sur).flatten()\n", 539 | "\n", 540 | "plt.figure()\n", 541 | "plt.scatter(j_true, j_sur, alpha=0.5)\n", 542 | "plt.xlabel('Ground truth')\n", 543 | "plt.ylabel('TROT (survey)')\n", 544 | "plt.plot(diag, diag, 'r--')\n", 545 | "\n", 546 | "#plt.show()\n", 547 | " " 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | "Plot error distribution (compared with Florida average)" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": null, 560 | "metadata": { 561 | "collapsed": false 562 | }, 563 | "outputs": [], 564 | "source": [ 565 | "plt.figure()\n", 566 | "bins = np.arange(-.3, .6, 0.01)\n", 567 | "plt.hist(j_true - j_sur, bins=bins, alpha=0.5, label='TROT (survey)')\n", 568 | "plt.hist(j_true - j_baseline, bins=bins, alpha=0.5, label='Florida-average')\n", 569 | "plt.legend()\n", 570 | "plt.xlabel('Difference between inference and ground truth')\n", 571 | "\n", 572 | "#plt.show()" 573 | ] 574 | } 575 | ], 576 | "metadata": { 577 | "anaconda-cloud": {}, 578 | "kernelspec": { 579 | "display_name": "Python 3", 580 | "language": "python", 581 | "name": "python3" 582 | }, 583 | "language_info": { 584 | "codemirror_mode": { 585 | "name": "ipython", 586 | "version": 3 587 | }, 588 | "file_extension": ".py", 589 | "mimetype": 
"text/x-python", 590 | "name": "python", 591 | "nbconvert_exporter": "python", 592 | "pygments_lexer": "ipython3", 593 | "version": "3.5.1" 594 | } 595 | }, 596 | "nbformat": 4, 597 | "nbformat_minor": 0 598 | } 599 | -------------------------------------------------------------------------------- /Notebooks/.ipynb_checkpoints/Ecological Inference-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Ecological Inference through Tsallis Regularized Optimal Transport (TROT)\n", 10 | "This notebook presents the pipeline used in (cite our paper) to perform ecological inference on the Florida dataset.\n", 11 | "\n", 12 | "You will first want to download the dataset from (url to the dataset)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": { 19 | "collapsed": false 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import pandas as pd\n", 24 | "import numpy as np\n", 25 | "import pickle\n", 26 | "from matplotlib import pyplot as plt\n", 27 | "from matplotlib.pylab import savefig\n", 28 | "\n", 29 | "import sys\n", 30 | "sys.path.append('..')\n", 31 | "sys.path.append('../Trot')\n", 32 | "sys.path.append('../Data')\n", 33 | "\n", 34 | "from Trot import Distances as dist\n", 35 | "from Trot.Evaluation import KL\n", 36 | "from Trot.Florida_inference import CV_Local_Inference, Local_Inference" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Data Loading and Processing" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "FlData = pd.read_csv('../Fl_Data.csv', usecols = ['District', 'County','Voters_Age', 'Voters_Gender', 'PID', 'vote08', \n", 55 | " 'SR.WHI', 'SR.BLA', 'SR.HIS', 'SR.ASI', 'SR.NAT', 'SR.OTH']) \n", 56 | "\n", 57 | "FlData = FlData.dropna()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "Change gender values to numerical values" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "FlData['Voters_Gender'] = FlData['Voters_Gender'].map({'M': 1, 'F': 0})" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "Renormalize the age so that it takes values between 0 and 1" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "FlData['Voters_Age'] = ((FlData['Voters_Age'] -\n", 94 | " FlData['Voters_Age'].min()) /\n", 95 | " (FlData['Voters_Age'].max() -\n", 96 | " FlData['Voters_Age'].min()))\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "One-hot party subscriptions (PID)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "#Get one hot encoding of column PID\n", 115 | "one_hot = pd.get_dummies(FlData['PID'])\n", 116 | "# Drop column PID as it is now encoded\n", 117 | "FlData = FlData.drop('PID', axis=1)\n", 118 | "# Join the encoded df\n", 119 | "FlData = FlData.join(one_hot)\n", 120 | "# Rename the new columns\n", 121 | 
"FlData.rename(columns={0: 'Other', 1: 'Democrat', 2: 'Republican'},\n", 122 | " inplace=True)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 29, 128 | "metadata": { 129 | "collapsed": false 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/html": [ 135 | "
| \n", 140 | " | District | \n", 141 | "County | \n", 142 | "Voters_Age | \n", 143 | "Voters_Gender | \n", 144 | "vote08 | \n", 145 | "SR.WHI | \n", 146 | "SR.BLA | \n", 147 | "SR.HIS | \n", 148 | "SR.ASI | \n", 149 | "SR.NAT | \n", 150 | "SR.OTH | \n", 151 | "Other | \n", 152 | "Democrat | \n", 153 | "Republican | \n", 154 | "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | \n", 159 | "18205.000000 | \n", 160 | "18205 | \n", 161 | "18205.000000 | \n", 162 | "18205.000000 | \n", 163 | "18205.000000 | \n", 164 | "18205.000000 | \n", 165 | "18205.000000 | \n", 166 | "18205.000000 | \n", 167 | "18205.000000 | \n", 168 | "18205.000000 | \n", 169 | "18205.000000 | \n", 170 | "18205.000000 | \n", 171 | "18205.000000 | \n", 172 | "18205.000000 | \n", 173 | "
| mean | \n", 176 | "5.429827 | \n", 177 | "12 | \n", 178 | "0.361123 | \n", 179 | "0.473002 | \n", 180 | "0.820159 | \n", 181 | "0.768745 | \n", 182 | "0.103927 | \n", 183 | "0.049767 | \n", 184 | "0.026476 | \n", 185 | "0.003516 | \n", 186 | "0.049602 | \n", 187 | "0.199396 | \n", 188 | "0.543257 | \n", 189 | "0.257347 | \n", 190 | "
| std | \n", 193 | "1.177072 | \n", 194 | "0 | \n", 195 | "0.240638 | \n", 196 | "0.499284 | \n", 197 | "0.384065 | \n", 198 | "0.421647 | \n", 199 | "0.305175 | \n", 200 | "0.217468 | \n", 201 | "0.160551 | \n", 202 | "0.059189 | \n", 203 | "0.217127 | \n", 204 | "0.399557 | \n", 205 | "0.498139 | \n", 206 | "0.437184 | \n", 207 | "
| min | \n", 210 | "3.000000 | \n", 211 | "12 | \n", 212 | "0.000000 | \n", 213 | "0.000000 | \n", 214 | "0.000000 | \n", 215 | "0.000000 | \n", 216 | "0.000000 | \n", 217 | "0.000000 | \n", 218 | "0.000000 | \n", 219 | "0.000000 | \n", 220 | "0.000000 | \n", 221 | "0.000000 | \n", 222 | "0.000000 | \n", 223 | "0.000000 | \n", 224 | "
| 25% | \n", 227 | "6.000000 | \n", 228 | "12 | \n", 229 | "0.125000 | \n", 230 | "0.000000 | \n", 231 | "1.000000 | \n", 232 | "1.000000 | \n", 233 | "0.000000 | \n", 234 | "0.000000 | \n", 235 | "0.000000 | \n", 236 | "0.000000 | \n", 237 | "0.000000 | \n", 238 | "0.000000 | \n", 239 | "0.000000 | \n", 240 | "0.000000 | \n", 241 | "
| 50% | \n", 244 | "6.000000 | \n", 245 | "12 | \n", 246 | "0.350000 | \n", 247 | "0.000000 | \n", 248 | "1.000000 | \n", 249 | "1.000000 | \n", 250 | "0.000000 | \n", 251 | "0.000000 | \n", 252 | "0.000000 | \n", 253 | "0.000000 | \n", 254 | "0.000000 | \n", 255 | "0.000000 | \n", 256 | "1.000000 | \n", 257 | "0.000000 | \n", 258 | "
| 75% | \n", 261 | "6.000000 | \n", 262 | "12 | \n", 263 | "0.550000 | \n", 264 | "1.000000 | \n", 265 | "1.000000 | \n", 266 | "1.000000 | \n", 267 | "0.000000 | \n", 268 | "0.000000 | \n", 269 | "0.000000 | \n", 270 | "0.000000 | \n", 271 | "0.000000 | \n", 272 | "0.000000 | \n", 273 | "1.000000 | \n", 274 | "1.000000 | \n", 275 | "
| max | \n", 278 | "6.000000 | \n", 279 | "12 | \n", 280 | "1.000000 | \n", 281 | "1.000000 | \n", 282 | "1.000000 | \n", 283 | "1.000000 | \n", 284 | "1.000000 | \n", 285 | "1.000000 | \n", 286 | "1.000000 | \n", 287 | "1.000000 | \n", 288 | "1.000000 | \n", 289 | "1.000000 | \n", 290 | "1.000000 | \n", 291 | "1.000000 | \n", 292 | "