├── Trot ├── __init__.py ├── Distances.py ├── Unregularized_OT.py ├── Generators.py ├── Projections.py ├── Projected_gradient.py ├── Florida_inference.py ├── Evaluation.py └── Tsallis.py ├── .gitignore ├── Data ├── baseline.pkl ├── joints.pkl ├── joints_M.pkl └── joints_gallup.pkl ├── README.md └── Notebooks ├── Tsallis Plots.ipynb ├── Ecological Inference.ipynb └── .ipynb_checkpoints └── Ecological Inference-checkpoint.ipynb /Trot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | *.pyc 3 | run_ecological_inference.py 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /Data/baseline.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/baseline.pkl -------------------------------------------------------------------------------- /Data/joints.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/joints.pkl -------------------------------------------------------------------------------- /Data/joints_M.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/joints_M.pkl -------------------------------------------------------------------------------- /Data/joints_gallup.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BorisMuzellec/TROT/HEAD/Data/joints_gallup.pkl -------------------------------------------------------------------------------- /Trot/Distances.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | #Gaussian kernel 4 | def rbf(x, y, gamma=10): 5 | return np.exp(- gamma * np.linalg.norm(x - y)) 6 | 7 | #Dissimilarity measure based on rbf 8 | def dist_1(x, y): 9 | return 1 - rbf(x, y) 10 | 11 | #Euclidean metric in the rbf Hilbert space 12 | def dist_2(x, y): 13 | return np.sqrt(2 - 2 * rbf(x, y)) 14 | 15 | 16 | #The Frobenius inner product for matrices 17 | def inner_Frobenius(P,Q): 18 | return np.multiply(P,Q).sum() -------------------------------------------------------------------------------- /Trot/Unregularized_OT.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Sep 5 16:20:04 2016 4 | 5 | @author: boris 6 | """ 7 | from scipy.optimize import linprog 8 | import numpy as np 9 | 10 | 11 | #Function using a linear solver to solve unregularized OT 12 | def Unregularized_OT(M,r,c): 13 | 14 | n = M.shape[0] 15 | m = M.shape[1] 16 | 17 | #Create the constraint matrix 18 | Ac = np.zeros((m,n*m)) 19 | for i in range(m): 20 | for j in range(n): 21 | Ac[i,i+j*m] = 1 22 | 23 | Ar = np.zeros((n,n*m)) 24 | for i in range(n): 25 | for j in range(m): 26 | Ar[i,m*i+j] = 1 27 | 28 | res = linprog(M.flatten(), A_eq = np.vstack((Ar,Ac)), b_eq=np.hstack((r,c))) 29 | 30 | return res.x.reshape((n,m)) 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TROT 2 | 3 | This is a Python 
implementation of Tsallis Regularized Optimal Transport (TROT) for ecological inference, following 4 | 5 | > Boris Muzellec, Richard Nock, Giorgio Patrini, Frank Nielsen. Tsallis Regularized Optimal Transport and Ecological Inference. [arXiv:1609.04495](https://arxiv.org/pdf/1609.04495v1.pdf) 6 | 7 | It contains both scripts implementing algorithms for solving TROT, and notebooks which reproduce the ecological inference pipeline from the article. 8 | 9 | 10 | # Dependencies 11 | 12 | ``` 13 | numpy, scipy, pickle 14 | ``` 15 | 16 | # Usage 17 | 18 | To run the Ecological Inference notebook, you will first want to download the Florida dataset. It may be accessed [here](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/SVY5VF). 19 | 20 | It should then be stored in the root folder of the repo. 21 | 22 | You can then run `Notebooks/Ecological/Inference.ipynb` for a reproduction of the article's ecological inference pipeline, and `Notebooks/Tsallis/Plots.ipynb` for a visualization of the impact of parameter $q$ and $\lambda$. 23 | 24 | The code under `Trot/` contains the basics for building a TROT-based application. 25 | -------------------------------------------------------------------------------- /Trot/Generators.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu May 26 10:42:22 2016 4 | 5 | @author: boris 6 | """ 7 | 8 | import numpy as np 9 | from math import ceil 10 | 11 | maxval = 100000 12 | 13 | #Uniform sampling over the open unit simplex (i.e. no 0s allowed) 14 | def rand_marginal(n): 15 | x = np.random.choice(np.arange(1,maxval-1), n, replace = False) 16 | p = np.zeros(n) 17 | x[n-1] = maxval 18 | x = np.sort(x) 19 | for i in range(1,n): 20 | p[i] = (x[i]-x[i-1])/maxval 21 | p[0] = x[0]/maxval 22 | return p 23 | 24 | 25 | #Same cost matrix as in Cuturi 2013 (for scale = 1), not necessarily square 26 | def rand_costs(n,m,scale): 27 | X = np.random.multivariate_normal(np.zeros(ceil(n/10)),np.identity(ceil(n/10)),max(n,m))*scale 28 | M = np.zeros((n,m)) 29 | 30 | for i in range(n): 31 | for j in range(m): 32 | M[i,j] = np.linalg.norm(X[i]-X[j]) 33 | 34 | return M/np.median(M) 35 | 36 | #Squared Euclidean distance cost matrix 37 | def euc_costs(n,scale): 38 | M = np.zeros((n,n)) 39 | for i in range(n): 40 | for j in range(n): 41 | M[i,j]=(i/n - j/n)*(i/n - j/n) 42 | return M*scale -------------------------------------------------------------------------------- /Trot/Projections.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 18 10:12:51 2016 4 | 5 | @author: boris 6 | """ 7 | import numpy as np 8 | from Distances import inner_Frobenius 9 | 10 | #Sinkhorn-Knopp's algorithm - performs Kullback-Leibler projection on U(r,c) 11 | def Sinkhorn(A,r,c,precision): 12 | 13 | p = np.sum(A,axis = 1) 14 | q = np.sum(A,axis = 0) 15 | count = 0 16 | 17 | while not (check(p,q,r,c,precision)) and count <= 4000: 18 | 19 | A = np.diag(np.divide(r,p)).dot(A) 20 | q = np.sum(A,axis = 0) 21 | A = A.dot(np.diag(np.divide(c,q))) 22 | p = np.sum(A,axis = 1) 23 | count+=1 24 | 25 | if count >= 4000: 26 | print('Unable to perform Sinkhorn-Knopp projection') 27 | return A 28 | 29 | 30 | 31 | #Euclidean projection on U(r,c) -- Assumes that A is a square matrix 32 | def euc_projection(A,r,c,precision): 33 | 34 | n = A.shape[0] 35 | m = A.shape[0] 36 | 37 | assert (m == n),"Non-square matrix in euclidean 
projection" 38 | 39 | p = np.sum(A,axis = 1) 40 | q = np.sum(A,axis = 0) 41 | 42 | H = np.matrix(np.full((n,n),1)) 43 | 44 | count = 0 45 | 46 | while not (check(p,q,r,c,precision)) and count <=4000: 47 | 48 | #Projection on the transportation constraints 49 | A = A + (np.tile(r,(n,1)).transpose() +np.tile(c,(n,1)))/n - (np.tile(p,(n,1)).transpose() + np.tile(q,(n,1)))/n + (A.sum() - 1)* H/(n*n) 50 | 51 | #Projection on the positivity constraint 52 | A = np.maximum(A,0) 53 | 54 | p = np.sum(A,axis = 1) 55 | q = np.sum(A,axis = 0) 56 | 57 | count = count +1 58 | 59 | if count >= 4000: 60 | print('Unable to perform Euclidean projection') 61 | return A 62 | 63 | 64 | #Returns true iff p and q approximate respectively r and c to a 'prec' ratio in infinite norm 65 | def check(p,q,r,c,prec): 66 | if (np.linalg.norm(np.divide(p,r)-1,np.inf)>prec) or (np.linalg.norm(np.divide(q,c)-1,np.inf)>prec): 67 | return False 68 | else: return True 69 | 70 | 71 | -------------------------------------------------------------------------------- /Trot/Projected_gradient.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 7 16:29:32 2016 4 | 5 | @author: boris 6 | """ 7 | 8 | from copy import deepcopy 9 | import numpy as np 10 | import math 11 | import Projections as proj 12 | from Tsallis import q_obj, q_grad 13 | 14 | #Euclidean projected gradient descent 15 | def euc_proj_descent(q,M,r,c,l ,precision,T , rate = None, rate_type = None): 16 | 17 | if (q<=1): print("Warning: Projected gradient methods only work when q>1") 18 | 19 | ind_map = np.asarray(np.matrix(r).transpose().dot(np.matrix(c))) 20 | P = deepcopy(ind_map) 21 | 22 | new_score = q_obj(q,P,M,l) 23 | scores = [] 24 | scores.append(new_score) 25 | 26 | best_score = new_score 27 | best_P = P 28 | 29 | count = 1 30 | 31 | while count<=T: 32 | 33 | G = q_grad(q,P,M,l) 34 | 35 | if rate is None: 36 | P = proj.euc_projection(P - (math.sqrt(2*math.sqrt(2)/T)/np.linalg.norm(G))*G,r,c,precision) #Absolute horizon 37 | #P = proj.euc_projection(P - (math.sqrt(2./count)/np.linalg.norm(G))*G,r,c,precision) #Rolling horizon 38 | elif rate_type is "constant": 39 | P = proj.euc_projection(P - rate*G,r,c,precision) 40 | elif rate_type is "constant_length": 41 | P = proj.euc_projection(P - rate*np.linalg.norm(G)*G,r,c,precision) 42 | elif rate_type is "diminishing": 43 | P = proj.euc_projection(P - rate/math.sqrt(count)*G,r,c,precision) 44 | elif rate_type is "square_summable": 45 | P = proj.euc_projection(P - rate/count*G,r,c,precision) 46 | 47 | #Update score list 48 | new_score = q_obj(q,P,M,l) 49 | scores.append(new_score) 50 | 51 | #Keep track of the best solution so far 52 | if (new_score < best_score): 53 | best_score = new_score 54 | best_P = P 55 | 56 | count+=1 57 | return best_P,scores 58 | 59 | 60 | #Nesterov's accelerated gradient 61 | def Nesterov_grad(q,M,r,c,l ,precision,T): 62 | 63 | if (q<2): print("Warning: Nesterov's accelerated gradient only works when q>2") 64 | 65 | ind_map = np.matrix(r).transpose().dot(np.matrix(c)) 66 | P = deepcopy(ind_map) 67 | 68 | #Estimation of the gradient Lipschitz constant 69 | L = q/(l*(q-1)*(q-1)) 70 | 71 | 72 | new_score = q_obj(q,P,M,l) 73 | scores = [] 74 | scores.append(new_score) 75 | 76 | best_score = new_score 77 | best_Y = P 78 | 79 | #Negative cumulative weighted gradient sum 80 | grad_sum = np.zeros(P.shape) 81 | count = 1 82 | 83 | while count<=T: 84 | 85 | G =q_grad(q,P,M,l) 86 | grad_sum-=count/2*G 87 | 88 | Y = 
proj.euc_projection(P-G/L,r,c,precision) 89 | Z = proj.Sinkhorn(np.multiply(ind_map,np.exp(grad_sum/L)),r,c,precision) 90 | P = (2*Z + (count)*Y)/(count+2) 91 | 92 | #Update score list 93 | new_score = q_obj(q,Y,M,l) 94 | scores.append(new_score) 95 | 96 | #Keep track of the best solution so far 97 | if (new_score < best_score): 98 | best_score = new_score 99 | best_Y = Y 100 | 101 | count+=1 102 | 103 | return best_Y, scores 104 | 105 | -------------------------------------------------------------------------------- /Trot/Florida_inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Aug 18 13:44:15 2016 4 | 5 | @author: boris 6 | """ 7 | 8 | # import pandas as pd 9 | import numpy as np 10 | from Tsallis import TROT 11 | from Unregularized_OT import Unregularized_OT 12 | from Evaluation import KL 13 | 14 | 15 | # Compute the marginals and cost matrices in each county (usefull for CV) 16 | 17 | def CV_Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, counties,q,l, filename = None): 18 | 19 | best_kl = np.Inf 20 | q_best = q[0] 21 | l_best = l[0] 22 | variance_best = 0 23 | 24 | if filename is not None: 25 | file = open('{0}.txt'.format(filename), "w") 26 | file.write('q, l, kl\n') 27 | 28 | for j in range(len(q)): 29 | for i in range(len(l)): 30 | 31 | J_inferred = {} 32 | for county in counties: 33 | 34 | r = Ethnicity_Marginals[county] 35 | c = Party_Marginals[county] 36 | 37 | Infered_Distrib = TROT(q[j],M,r,c,l[i],1E-2) 38 | 39 | J_inferred[county] = Infered_Distrib / Infered_Distrib.sum() 40 | 41 | kl, std = KL(J, J_inferred, counties) 42 | print('q: %.2f, lambda: %.4f, KL: %.4g, STD: %.4g' % (q[j], l[i], kl, std)) 43 | if filename is not None: 44 | file.write('%.2f, %.4f, %.4g\n' % (q[j], l[i], kl)) 45 | 46 | if kl < best_kl: 47 | q_best = q[j] 48 | l_best = l[i] 49 | variance_best = std 50 | best_kl = kl 51 | 52 | 53 | print('Best score: %.4g, Best q: %.2f, Best lambda: %.4f\t Standard Variance: %.4g\n' % (best_kl, q_best, l_best, variance_best)) 54 | 55 | if filename is not None: 56 | file.close() 57 | 58 | return best_kl, q_best, l_best 59 | 60 | 61 | def Unreg_Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, counties): 62 | 63 | J_inferred = {} 64 | for county in counties: 65 | 66 | r = Ethnicity_Marginals[county] 67 | c = Party_Marginals[county] 68 | 69 | J_inferred[county] = Unregularized_OT(M, r, c) 70 | J_inferred[county] /= J_inferred[county].sum() 71 | 72 | 73 | return J_inferred 74 | 75 | 76 | def CV_Cross_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, Ref_county, CV_counties,q,l): 77 | 78 | 79 | best_kl = np.Inf 80 | q_best = q[0] 81 | l_best = l[0] 82 | variance_best = 0 83 | 84 | for j in range(len(q)): 85 | for i in range(len(l)): 86 | 87 | print('q= {0}, l= {1}\n'.format(q[j],l[i])) 88 | 89 | J_inferred = {} 90 | for county in CV_counties: 91 | 92 | r = Ethnicity_Marginals[county] 93 | c = Party_Marginals[county] 94 | 95 | Infered_Distrib = TROT(q[j],M,r,c,l[i],1E-2) 96 | 97 | J_inferred[county] = Infered_Distrib / Infered_Distrib.sum() 98 | 99 | kl, std = KL(J, J_inferred, CV_counties) 100 | #print('KL: {0}\t Variance: {1}\n'.format(kl, std)) 101 | print('KL: %g\t Variance: %g\n' % kl, std) 102 | if kl < best_kl: 103 | q_best = q[j] 104 | l_best = l[i] 105 | variance_best = std 106 | best_kl = kl 107 | 108 | print('Best score: {0}, Best q: {1}, Best lambda: {2}\n'.format(best_kl, q_best, 
l_best, variance_best)) 109 | 110 | return best_kl, q_best, l_best 111 | 112 | 113 | def Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, counties,q,l): 114 | 115 | J_inferred = {} 116 | for county in counties: 117 | 118 | r = Ethnicity_Marginals[county] 119 | c = Party_Marginals[county] 120 | 121 | Infered_Distrib = TROT(q,M,r,c,l,1E-2) 122 | 123 | J_inferred[county] = Infered_Distrib / Infered_Distrib.sum() 124 | 125 | return J_inferred 126 | -------------------------------------------------------------------------------- /Trot/Evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | 5 | def MSE(J_true, J_inferred, counties, save_to_file=None): 6 | """Compute MSE and its STD by cell. 7 | Input: dictionary of joints probability with counties as keys.""" 8 | 9 | assert len(counties) == len(J_inferred.keys()) 10 | 11 | l = [] 12 | 13 | for c in counties: 14 | diff = np.array(J_true[c] - J_inferred[c]).flatten() 15 | l.append(diff) 16 | 17 | assert len(counties) == len(J_inferred.keys()) 18 | 19 | mse_list = [] 20 | 21 | for c in counties: 22 | 23 | mse_county = np.array(J_true[c] - J_inferred[c]) ** 2 24 | mse_county = mse_county.mean() 25 | 26 | mse_list.append(mse_county) 27 | 28 | if save_to_file: 29 | f = open(save_to_file, 'wb') 30 | pickle.dump(mse_list, f) 31 | 32 | mse = np.asarray(mse_list) 33 | return mse.mean(), mse.std() 34 | 35 | 36 | def KL(J_true, J_inferred, counties, save_to_file=False, compute_abs_err=False): 37 | """Compute KL and its STD by cell. 38 | Input: dictionary of joints probability with counties as keys.""" 39 | 40 | EPS = 1e-18 # avoid KL +inf 41 | 42 | assert len(counties) == len(J_inferred.keys()) 43 | 44 | kl_list = [] 45 | if compute_abs_err: 46 | abs_list = [] 47 | 48 | for c in counties: 49 | 50 | J_inferred[c] /= J_inferred[c].sum() 51 | 52 | kl_county = 0.0 53 | for i in np.arange(J_true[c].shape[0]): 54 | for j in np.arange(J_true[c].shape[1]): 55 | kl_county += J_inferred[c][i, j] * \ 56 | np.log(J_inferred[c][i, j] / np.maximum(J_true[c][i, j], EPS)) 57 | 58 | kl_list.append(kl_county) 59 | 60 | if compute_abs_err: 61 | abs_list.append(np.abs(J_inferred[c] - J_true[c]).mean()) 62 | 63 | if save_to_file: 64 | f = open(save_to_file + '.pkl', 'wb') 65 | pickle.dump(kl_list, f) 66 | f.close() 67 | 68 | if compute_abs_err: 69 | f = open(save_to_file + '_abs.pkl', 'wb') 70 | pickle.dump(abs_list, f) 71 | f.close() 72 | 73 | if compute_abs_err: 74 | err = np.asarray(abs_list) 75 | print('Absolute error', err.mean(), ' + ', err.std()) 76 | 77 | kl = np.asarray(kl_list) 78 | return kl.mean(), kl.std() 79 | 80 | 81 | def National_Average_Baseline(Data, counties): 82 | """This baseline predicts the national average contingency table for every 83 | county.""" 84 | 85 | National_Average = np.zeros((6, 3)) 86 | Total_Num_Voters = Data.shape[0] 87 | 88 | National_Average[0,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.WHI']==1)].shape[0] 89 | National_Average[0,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.WHI']==1)].shape[0] 90 | National_Average[0,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.WHI']==1)].shape[0] 91 | 92 | National_Average[1,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.BLA']==1)].shape[0] 93 | National_Average[1,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.BLA']==1)].shape[0] 94 | National_Average[1,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.BLA']==1)].shape[0] 95 | 96 | National_Average[2,0] = 
Data.loc[(Data['Other'] ==1) & (Data['SR.HIS']==1)].shape[0] 97 | National_Average[2,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.HIS']==1)].shape[0] 98 | National_Average[2,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.HIS']==1)].shape[0] 99 | 100 | National_Average[3,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.ASI']==1)].shape[0] 101 | National_Average[3,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.ASI']==1)].shape[0] 102 | National_Average[3,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.ASI']==1)].shape[0] 103 | 104 | National_Average[4,0] = Data.loc[(Data['Other'] ==1) &(Data['SR.NAT']==1)].shape[0] 105 | National_Average[4,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.NAT']==1)].shape[0] 106 | National_Average[4,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.NAT']==1)].shape[0] 107 | 108 | National_Average[5,0] = Data.loc[(Data['Other'] ==1) & (Data['SR.OTH']==1)].shape[0] 109 | National_Average[5,1] = Data.loc[(Data['Democrat'] ==1) & (Data['SR.OTH']==1)].shape[0] 110 | National_Average[5,2] = Data.loc[(Data['Republican'] ==1) & (Data['SR.OTH']==1)].shape[0] 111 | 112 | National_Average = National_Average / Total_Num_Voters 113 | 114 | # replicate by CV_counties 115 | replica = {} 116 | for c in counties: 117 | replica[c] = National_Average 118 | 119 | return replica 120 | -------------------------------------------------------------------------------- /Trot/Tsallis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jun 15 11:38:13 2016 4 | 5 | @author: boris 6 | """ 7 | 8 | 9 | import numpy as np 10 | import math 11 | from copy import deepcopy 12 | from Projections import check, Sinkhorn 13 | from Distances import inner_Frobenius 14 | 15 | 16 | #Tsallis generalization of the exponential 17 | def q_exp (q,u): 18 | return np.power(1+(1-q)*u,1/(1-q)) 19 | 20 | #Tsallis generalization of the logarithm 21 | def q_log (q,u): 22 | return (np.power(u,1-q) - 1)/(1-q) 23 | 24 | #Objective function of the TROT minimization problem 25 | def q_obj(q,P,M,l): 26 | P_q = np.power(P,q) 27 | return inner_Frobenius(P,M) - np.sum(P_q - P)/((1-q)*l) 28 | 29 | #Gradient of the objective function in q-space 30 | def q_grad(q,P,M,l): 31 | return M + (q*np.power(P,q-1) - 1)/(l*(1-q)) 32 | 33 | def sign(x): 34 | if (x>0): 35 | return 1 36 | else: 37 | return -1 38 | 39 | #A wrapper for the three TROT-solving algorithms, depending on the value of q 40 | #q must be positive 41 | def TROT(q,M,r,c,l,precision): 42 | 43 | assert (q >= 0),"Invalid parameter q: q must be strictly positive" 44 | 45 | if np.isclose(q,1): 46 | #Add multipliers to rescale A and avoid dividing by zero 47 | A = deepcopy(l*M) 48 | A = A-np.amin(A,axis = 0) 49 | A = (A.T-np.amin(A,axis = 1)).T 50 | 51 | return Sinkhorn(np.exp(-A),r,c,precision) 52 | 53 | elif q<1: 54 | return second_order_sinkhorn(q,M,r,c,l,precision)[0] 55 | 56 | else: 57 | return KL_proj_descent(q,M,r,c,l,precision, 50, rate = 1, rate_type = "square_summable")[0] 58 | 59 | 60 | #A TROT optimizer using first order approximations (less efficient than second order) 61 | def first_order_sinkhorn(q,M,r,c,l,precision): 62 | 63 | q1 = q_exp(q,-1) 64 | 65 | P = q1/q_exp(q,l*M) 66 | A = deepcopy(l*M) 67 | 68 | p = P.sum(axis = 1) 69 | s = P.sum(axis = 0) 70 | 71 | count = 0 72 | alpha = np.zeros(M.shape[0]) 73 | beta = np.zeros(M.shape[1]) 74 | 75 | while not (check(p,s,r,c,precision)) and count <= 1000: 76 | 77 | alpha = 
np.divide(p-r,np.sum(np.divide(P,(1+(1-q)*A)),axis = 1)) 78 | A = (A.transpose() + alpha).transpose() 79 | 80 | P = q1/q_exp(q,A) 81 | s = P.sum(axis = 0) 82 | 83 | beta = np.divide(s-c,np.sum(np.divide(P,(1+(1-q)*A)),axis = 0)) 84 | A += beta 85 | 86 | 87 | P = q1/q_exp(q,A) 88 | p = P.sum(axis = 1) 89 | s = P.sum(axis = 0) 90 | 91 | 92 | 93 | count +=1 94 | 95 | return P, count, q_obj(q,P,M,l) 96 | 97 | 98 | #The TROT optimizing algorithm from the paper -- to use when q is in (0,1) 99 | def second_order_sinkhorn(q,M,r,c,l,precision): 100 | 101 | n = M.shape[0] 102 | m = M.shape[1] 103 | q1 = q_exp(q,-1) 104 | 105 | A = deepcopy(l*M) 106 | 107 | #Add multipliers to make sure that A is not too large 108 | #We subtract the minimum on columns and then on rows 109 | #This does not change the solution of TROT 110 | A = A-np.amin(A,axis = 0) 111 | A = (A.T-np.amin(A,axis = 1)).T 112 | 113 | P = q1/q_exp(q,A) 114 | 115 | p = P.sum(axis = 1) 116 | s = P.sum(axis = 0) 117 | 118 | count = 0 119 | alpha = np.zeros(M.shape[0]) 120 | beta = np.zeros(M.shape[1]) 121 | 122 | while not (check(p,s,r,c,precision)) and count <= 1000: 123 | 124 | 125 | A_q2 = np.divide(P,(1+(1-q)*A)) 126 | a = (1-q/2)*(np.sum(np.divide(A_q2,(1+(1-q)*A)),axis = 1)) 127 | b = - np.sum(A_q2,axis = 1) 128 | d = p-r 129 | delta = np.multiply(b,b) - 4*np.multiply(a,d) 130 | 131 | 132 | for i in range(n): 133 | if (delta[i] >=0 and d[i]<0 and a[i]>0): 134 | alpha[i] = - (b[i] + math.sqrt(delta[i]))/(2*a[i]) 135 | elif (b[i] != 0): 136 | alpha[i] = 2*d[i]/(-b[i]) 137 | else: alpha[i] = 0 138 | 139 | #Check that the multiplier is not too large 140 | if abs(alpha[i]) > 1/((2-2*q)*max(1+(1-q)*A[i,:])): 141 | alpha[i] = sign(d[i])*1/((2-2*q)*max(1+(1-q)*A[i,:])) 142 | 143 | A = (A.transpose() + alpha).transpose() 144 | 145 | 146 | P = q1/q_exp(q,A) 147 | 148 | s = P.sum(axis = 0) 149 | 150 | A_q2 = np.divide(P,(1+(1-q)*A)) 151 | a = (1-q/2)*(np.sum(np.divide(A_q2,(1+(1-q)*A)),axis = 0)) 152 | b = - np.sum(A_q2,axis = 0) 153 | d = s - c 154 | delta = np.multiply(b,b) - 4*np.multiply(a,d) 155 | 156 | 157 | for i in range(m): 158 | if (delta[i] >=0 and d[i]<0 and a[i]>0): 159 | beta[i] = - (b[i] + math.sqrt(delta[i]))/(2*a[i]) 160 | elif (b[i] != 0): 161 | beta[i] = 2*d[i]/(-b[i]) 162 | else: beta[i] = 0 163 | 164 | #Check that the multiplier is not too large 165 | if abs(beta[i]) > 1/((2-2*q)*max(1+(1-q)*A[:,i])): 166 | beta[i] = sign(d[i])*1/((2-2*q)*max(1+(1-q)*A[:,i])) 167 | A += beta 168 | 169 | 170 | P = q1/q_exp(q,A) 171 | 172 | p = P.sum(axis = 1) 173 | s = P.sum(axis = 0) 174 | 175 | count +=1 176 | 177 | #print(P.sum()) 178 | 179 | return P, count, q_obj(q,P,M,l) 180 | 181 | 182 | #Kullback-Leibler projected gradient method, to use when q > 1 183 | def KL_proj_descent(q,M,r,c,l ,precision,T , rate = None, rate_type = None): 184 | 185 | if (q<=1): print("Warning: Projected gradient methods only work when q>1") 186 | 187 | omega = math.sqrt(2*np.log(M.shape[0])) 188 | ind_map = np.asarray(np.matrix(r).transpose().dot(np.matrix(c))) 189 | 190 | P = deepcopy(ind_map) 191 | 192 | new_score = q_obj(q,P,M,l) 193 | scores = [] 194 | scores.append(new_score) 195 | 196 | best_score = new_score 197 | best_P= P 198 | 199 | count = 1 200 | 201 | while count<=T: 202 | 203 | G = q_grad(q,P,M,l) 204 | 205 | 206 | if rate is None: 207 | tmp = np.exp(G*(-omega/(math.sqrt(T)*np.linalg.norm(G,np.inf)))) #Absolute horizon 208 | #tmp = np.exp(G*(-omega/(math.sqrt(count)*np.linalg.norm(G,np.inf)))) #Rolling horizon 209 | elif rate_type is "constant": 
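            # Fixed multiplicative step; the remaining branches instead rescale the step by the
            # gradient sup-norm, by 1/sqrt(count), or by 1/count ("square_summable" schedule).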
210 | tmp = np.exp(G*(-rate)) 211 | elif rate_type is "constant_length": 212 | tmp = np.exp(G*(-rate*np.linalg.norm(G,np.inf))) 213 | elif rate_type is "diminishing": 214 | tmp = np.exp(G*(-rate/math.sqrt(count))) 215 | elif rate_type is "square_summable": 216 | tmp = np.exp(G*(-rate/count)) 217 | 218 | 219 | P = np.multiply(P,tmp) 220 | P = Sinkhorn(P,r,c,precision) 221 | 222 | #Update score list 223 | new_score = q_obj(q,P,M,l) 224 | scores.append(new_score) 225 | 226 | #Keep track of the best solution so far 227 | if (new_score < best_score): 228 | best_score = new_score 229 | best_P = P 230 | 231 | 232 | count+=1 233 | 234 | return best_P, scores -------------------------------------------------------------------------------- /Notebooks/Tsallis Plots.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import matplotlib.gridspec as gridspec\n", 14 | "from matplotlib.patches import Rectangle\n", 15 | "from scipy.stats import poisson\n", 16 | "\n", 17 | "import sys\n", 18 | "sys.path.append('..')\n", 19 | "sys.path.append('../Trot')\n", 20 | "sys.path.append('../Data')\n", 21 | "\n", 22 | "from Trot.Tsallis import TROT, q_log\n", 23 | "from Trot.Generators import euc_costs " 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Influence of parameters $q$ and $\\lambda$\n", 31 | "\n", 32 | "We reproduce Figure 2. from (cite our paper)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "First, choose values for $q$ and $\\lambda$ to test, sample size n, modes $\\mu$ and mixing coefficients $t$." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "collapsed": false 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "n = 50\n", 51 | "q = np.arange(0.5,2.5,0.5)\n", 52 | "l = [0.01,0.1,1,5]\n", 53 | "mu1 = [10,30]\n", 54 | "mu2 = [5,20,35]\n", 55 | "t1 = [0.5,0.5]\n", 56 | "t2 = [0.2,0.8,0.2]" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "Run TROT on selected parameters and marginals and produce the figure." 
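,
    "\n",
    "\n",
    "As a minimal illustration of the call repeated over the whole $(q, \\lambda)$ grid in the next cell (a hypothetical standalone snippet, assuming `M`, `r` and `c` are built as below):\n",
    "\n",
    "```python\n",
    "P_single = TROT(0.5, M, r, c, 1.0, 1E-2)   # one plan for q = 0.5, lambda = 1\n",
    "print(P_single.sum(axis=1) - r)            # row sums should be close to r\n",
    "print(P_single.sum(axis=0) - c)            # column sums should be close to c\n",
    "```"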
64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "x = range(n)\n", 75 | " \n", 76 | "r_tmp = [] \n", 77 | "for mode in mu1:\n", 78 | " r_tmp.append(poisson.pmf(x,mode))\n", 79 | " \n", 80 | "c_tmp = [] \n", 81 | "for mode in mu2:\n", 82 | " c_tmp.append(poisson.pmf(x,mode))\n", 83 | " \n", 84 | "r = np.dot(t1,r_tmp)\n", 85 | "r = r/r.sum()\n", 86 | " \n", 87 | "c = np.dot(t2,c_tmp)\n", 88 | "c = c/c.sum()\n", 89 | " \n", 90 | " \n", 91 | "M = euc_costs(n,n)\n", 92 | " \n", 93 | "P = []\n", 94 | " \n", 95 | "nq = len(q)\n", 96 | "nl = len(l)\n", 97 | " \n", 98 | " \n", 99 | "for j in range(nq):\n", 100 | " for i in range(nl):\n", 101 | " P_tmp = TROT(q[j],M,r,c,l[i],1E-2)\n", 102 | " P.append(P_tmp)\n", 103 | " \n", 104 | " \n", 105 | "fig = plt.figure(figsize=(8, 8))\n", 106 | " \n", 107 | "outer_grid = gridspec.GridSpec(2, 2, width_ratios=[1,5], height_ratios=[1,5])\n", 108 | "outer_grid.update(wspace=0.01, hspace=0.01)\n", 109 | "# gridspec inside gridspec\n", 110 | "outer_joint = gridspec.GridSpecFromSubplotSpec(nq,nl, subplot_spec=outer_grid[1,1],wspace=0.02, hspace=0.02)\n", 111 | "outer_row_marg = gridspec.GridSpecFromSubplotSpec(nq,1, subplot_spec=outer_grid[1,0],wspace=0.02, hspace=0.02)\n", 112 | "outer_col_marg = gridspec.GridSpecFromSubplotSpec(1,nl, subplot_spec=outer_grid[0,1],wspace=0.02, hspace=0.02)\n", 113 | " \n", 114 | " \n", 115 | "for b in range(nl):\n", 116 | " for a in range (nq):\n", 117 | " ax = plt.Subplot(fig, outer_joint[a,b])\n", 118 | " ax.imshow(P[nl*a + b], origin='upper', interpolation = None, aspect = 'auto', cmap = 'Greys')\n", 119 | " rect = Rectangle((0, 0), n-1, n-1, fc='none', ec='black') \n", 120 | " rect.set_width(0.8)\n", 121 | " rect.set_bounds(0,0,n-1,n-1)\n", 122 | " ax.add_patch(rect)\n", 123 | " ax.set_xticks([])\n", 124 | " ax.set_yticks([])\n", 125 | " fig.add_subplot(ax)\n", 126 | " ax.set_axis_bgcolor('white')\n", 127 | " \n", 128 | "for i in range(nq):\n", 129 | " ax_row = plt.Subplot(fig,outer_row_marg[i], sharey = ax)\n", 130 | " ax_row.plot(1-r, x)\n", 131 | " fig.add_subplot(ax_row)\n", 132 | " \n", 133 | " ax_row.axes.get_xaxis().set_visible(False)\n", 134 | " ax_row.axes.get_yaxis().set_visible(False)\n", 135 | " bottom, height = .25, .5\n", 136 | " top = bottom + height\n", 137 | " ax_row.text(-0.05, 0.5*(bottom+top), 'q = %.2f' % q[i], horizontalalignment='right', verticalalignment='center', rotation='vertical',transform=ax_row.transAxes, fontsize='medium')\n", 138 | " \n", 139 | " ax_row.set_axis_bgcolor('white')\n", 140 | " \n", 141 | "for j in range(nl):\n", 142 | " ax_col = plt.Subplot(fig,outer_col_marg[j], sharex = ax)\n", 143 | " ax_col.plot(x,c)\n", 144 | " fig.add_subplot(ax_col) \n", 145 | " bottom, height = .25, .5\n", 146 | " ax_col.axes.get_xaxis().set_visible(False)\n", 147 | " ax_col.axes.get_yaxis().set_visible(False)\n", 148 | " ax_col.set_title(r'$\\lambda$'+' = {0}'.format(l[j]),fontsize='medium')\n", 149 | " ax_col.set_axis_bgcolor('white')\n", 150 | " \n", 151 | "fig.show()" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "# Tsallis entropies\n", 159 | "\n", 160 | "We plot the Tsallis entropies of a Bernoulli random variable with parameter $p$ for different values of $q$." 
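,
    " Writing $\\ln_q(u) = \\frac{u^{1-q}-1}{1-q}$ for the q-logarithm implemented by `q_log`, the quantity plotted below is\n",
    "\n",
    "$$H_q(p) = p\\,\\ln_q(1/p) + (1-p)\\,\\ln_q(1/(1-p)),$$\n",
    "\n",
    "which recovers the Shannon entropy $p\\log(1/p) + (1-p)\\log(1/(1-p))$ as $q \\to 1$ (the special case handled separately in the code)."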
161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "collapsed": false 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "def Tsallis_Bernoulli(q,p):\n", 172 | " if np.isclose(q,1):\n", 173 | " return p*np.log(1/p) + (1-p)*np.log(1/(1-p))\n", 174 | " else:\n", 175 | " return p*q_log(q,1/p) + (1-p)*q_log(q,1/(1-p))\n", 176 | "\n", 177 | "x = np.arange(1E-6,1,0.001)\n", 178 | "\n", 179 | "plt.figure()\n", 180 | "\n", 181 | "plt.plot(x,Tsallis_Bernoulli(0.1,x), label = 'q = 0.1')\n", 182 | "plt.plot(x,Tsallis_Bernoulli(0.5,x), label = 'q = 0.5')\n", 183 | "plt.plot(x,Tsallis_Bernoulli(1,x), label = 'q = 1')\n", 184 | "plt.plot(x,Tsallis_Bernoulli(1.5,x), label = 'q = 1.5')\n", 185 | "plt.plot(x,Tsallis_Bernoulli(5,x), label = 'q = 5')\n", 186 | "\n", 187 | "plt.legend()\n", 188 | "\n", 189 | "plt.show()" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "Notice how the choice of $q$ impacts the notion of randomness. When $q$ goes to 0, the entropy converges to the constant function $x=1$, except at the borders $p=0$ and $p=1$. Low values of q are anti-sparsity-inducing, but since they are rather flat around $p=0.5$ they do not require the distribution to be exactly uniform to have high-entropy. When $q$ goes to infinity, the entropy converges to the zero fonction." 197 | ] 198 | } 199 | ], 200 | "metadata": { 201 | "kernelspec": { 202 | "display_name": "Python 3", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.5.1" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 0 221 | } 222 | -------------------------------------------------------------------------------- /Notebooks/Ecological Inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Ecological Inference through Tsallis Regularized Optimal Transport (TROT)\n", 10 | "This notebook presents the pipeline used in our paper to perform ecological inference on the Florida dataset.\n", 11 | "\n", 12 | "You will first want to download the Florida dataset - see the README. 
" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": { 19 | "collapsed": false, 20 | "scrolled": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import pandas as pd\n", 25 | "import numpy as np\n", 26 | "import pickle\n", 27 | "from matplotlib import pyplot as plt\n", 28 | "from matplotlib.pylab import savefig\n", 29 | "\n", 30 | "import sys\n", 31 | "sys.path.append('..')\n", 32 | "sys.path.append('../Trot')\n", 33 | "sys.path.append('../Data')\n", 34 | "\n", 35 | "from Trot import Distances as dist\n", 36 | "from Trot.Evaluation import KL\n", 37 | "from Trot.Florida_inference import CV_Local_Inference, Local_Inference" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Data Loading and Processing" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "collapsed": false 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "FlData = pd.read_csv('../Fl_Data.csv', usecols = ['District', 'County','Voters_Age', 'Voters_Gender', 'PID', 'vote08', \n", 56 | " 'SR.WHI', 'SR.BLA', 'SR.HIS', 'SR.ASI', 'SR.NAT', 'SR.OTH']) \n", 57 | "\n", 58 | "FlData = FlData.dropna()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Change gender values to numerical values" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "collapsed": true 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "FlData['Voters_Gender'] = FlData['Voters_Gender'].map({'M': 1, 'F': 0})" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "Renormalize the age so that it takes values between 0 and 1" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "collapsed": true 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "FlData['Voters_Age'] = ((FlData['Voters_Age'] -\n", 95 | " FlData['Voters_Age'].min()) /\n", 96 | " (FlData['Voters_Age'].max() -\n", 97 | " FlData['Voters_Age'].min()))\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "One-hot party subscriptions (PID)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "collapsed": true 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "#Get one hot encoding of column PID\n", 116 | "one_hot = pd.get_dummies(FlData['PID'])\n", 117 | "# Drop column PID as it is now encoded\n", 118 | "FlData = FlData.drop('PID', axis=1)\n", 119 | "# Join the encoded df\n", 120 | "FlData = FlData.join(one_hot)\n", 121 | "# Rename the new columns\n", 122 | "FlData.rename(columns={0: 'Other', 1: 'Democrat', 2: 'Republican'},\n", 123 | " inplace=True)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "collapsed": false 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "FlData.describe()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "# Compute Marginals and Joint Distributions" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "Create a county dictionnary" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "collapsed": false 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "Voters_By_County = {}\n", 160 | "all_counties = 
FlData.County.unique()\n", 161 | "for county in all_counties:\n", 162 | " Voters_By_County[county] = FlData[FlData['County'] == county]" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "Compute the ground truth joint distribution" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "collapsed": false 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "J = {}\n", 181 | "for county in all_counties:\n", 182 | " J[county] = np.zeros((6, 3))\n", 183 | "\n", 184 | " J[county][0,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 185 | " J[county][0,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 186 | " J[county][0,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 187 | "\n", 188 | " J[county][1,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 189 | " J[county][1,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 190 | " J[county][1,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 191 | "\n", 192 | " J[county][2,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 193 | " J[county][2,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 194 | " J[county][2,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 195 | "\n", 196 | " J[county][3,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 197 | " J[county][3,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 198 | " J[county][3,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 199 | "\n", 200 | " J[county][4,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) &(Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 201 | " J[county][4,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 202 | " J[county][4,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 203 | "\n", 204 | " J[county][5,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 205 | " J[county][5,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 206 | " J[county][5,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 207 | "\n", 208 | " J[county] /= J[county].sum()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": { 215 | "collapsed": 
false 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "print(J[12])" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "Compute the party marginals" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "collapsed": false 234 | }, 235 | "outputs": [], 236 | "source": [ 237 | "Party_Marginals = {}\n", 238 | "parties = ['Other', 'Democrat', 'Republican']\n", 239 | "for county in all_counties:\n", 240 | " Party_Marginals[county] = pd.Series([J[county][:, i].sum()\n", 241 | " for i in np.arange(3)])\n", 242 | " Party_Marginals[county].index = parties" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "Compute the ethnicity marginals" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": { 256 | "collapsed": false 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "Ethnicity_Marginals = {}\n", 261 | "ethnies = ['SR.WHI', 'SR.BLA', 'SR.HIS', 'SR.ASI', 'SR.NAT', 'SR.OTH']\n", 262 | "for county in all_counties:\n", 263 | " Ethnicity_Marginals[county] = pd.Series([J[county][i, :].sum()\n", 264 | " for i in np.arange(6)])\n", 265 | " Ethnicity_Marginals[county].index = ethnies" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "# Compute the cost matrix\n", 273 | "Using only age, gender, and 2008 vote or abstention" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": { 280 | "collapsed": false 281 | }, 282 | "outputs": [], 283 | "source": [ 284 | "features = ['Voters_Age', 'Voters_Gender', 'vote08']\n", 285 | "e_len, p_len = len(ethnies), len(parties)\n", 286 | "M = np.zeros((e_len, p_len))\n", 287 | "for i, e in enumerate(ethnies):\n", 288 | " data_e = FlData[FlData[e] == 1.0]\n", 289 | " average_by_e = data_e[features].mean(axis=0)\n", 290 | " for j, p in enumerate(parties):\n", 291 | " data_p = FlData[FlData[p] == 1.0]\n", 292 | " average_by_p = data_p[features].mean(axis=0)\n", 293 | "\n", 294 | " M[i, j] = np.array(dist.dist_2(average_by_e, average_by_p))" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "# Start the inference" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "Use a specific county or district to select the best parameters" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": { 315 | "collapsed": true 316 | }, 317 | "outputs": [], 318 | "source": [ 319 | "CV_counties = FlData[FlData['District'] == 3].County.unique()\n", 320 | "del FlData" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "Find the best parameters" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": { 334 | "collapsed": false 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "q = np.arange(0.5, 3.1, 0.1)\n", 339 | "l = [0.01, 0.1, 1., 10., 100.] 
\n", 340 | "\n", 341 | "best_score, best_q, best_l = CV_Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals,\n", 342 | " CV_counties,q,l)" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": { 348 | "collapsed": true 349 | }, 350 | "source": [ 351 | "Use selected parameters on the rest of the dataset" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": { 358 | "collapsed": false 359 | }, 360 | "outputs": [], 361 | "source": [ 362 | "J_inferred = Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, all_counties, best_q, best_l)\n", 363 | "kl, std = KL(J, J_inferred, all_counties, save_to_file=False, compute_abs_err=True)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": {}, 369 | "source": [ 370 | "# Plot the results" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": { 377 | "collapsed": false 378 | }, 379 | "outputs": [], 380 | "source": [ 381 | "diag = np.linspace(-0.1, 1.0, 100)\n", 382 | "\n", 383 | "# pickle results\n", 384 | "f = open('../Data/joints_gallup.pkl', 'rb')\n", 385 | "J_true, J = pickle.load(f)\n", 386 | "\n", 387 | "f = open('../Data/baseline.pkl', 'rb')\n", 388 | "J_baseline = pickle.load(f)\n", 389 | "\n", 390 | "j_true, j, j_baseline = [], [], []\n", 391 | "for c in all_counties:\n", 392 | " j_true.append(np.array(J_true[c]).flatten())\n", 393 | " j.append(np.array(J_inferred[c]).flatten())\n", 394 | " j_baseline.append(np.array(J_baseline[c]).flatten())\n", 395 | "\n", 396 | "j_true = np.array(j_true).flatten()\n", 397 | "j = np.array(j).flatten()\n", 398 | "j_baseline = np.array(j_baseline).flatten()\n" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "Plot the correlation between the ground truth for the joint distribution and the infered distribution (the closer to the $x = y$ diagonal axis, the better)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "collapsed": false 413 | }, 414 | "outputs": [], 415 | "source": [ 416 | "plt.figure()\n", 417 | "plt.scatter(j_true, j, alpha=0.5)\n", 418 | "plt.xlabel('Ground truth')\n", 419 | "plt.ylabel('TROT (RBF)')\n", 420 | "plt.plot(diag, diag, 'r--')\n", 421 | "\n", 422 | "#plt.show()" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "Plot the distribution of the error (the more packed around the origin of the $x$-axis, the better)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "collapsed": false 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "plt.figure()\n", 441 | "bins = np.arange(-.3, .6, 0.01)\n", 442 | "plt.hist(j_true - j, bins=bins, alpha=0.5, label='TROT')\n", 443 | "plt.hist(j_true - j_baseline, bins=bins, alpha=0.5, label='Florida-average')\n", 444 | "plt.legend()\n", 445 | "plt.xlabel('Difference between inference and ground truth')\n", 446 | "\n", 447 | "#plt.show()" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": {}, 453 | "source": [ 454 | "# Survey-based ecological inference\n", 455 | "Same pipeline, but using a cost matrix computed thanks to the 2013 Gallup survey. 
(http://www.gallup.com/poll/160373/democrats-racially-diverse-republicans-mostly-white.aspx)\n", 456 | "\n", 457 | "We assume that Gallup's Other = {Native, Other}\n", 458 | "\n", 459 | "The cost matrix M is computed as $1-p_{ij}$, where $p_{ij}$ is the proportion of people registered to party $j$ belonging to group $i$." 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "metadata": { 466 | "collapsed": true 467 | }, 468 | "outputs": [], 469 | "source": [ 470 | "M_sur = np.array([\n", 471 | " [.38, .26, .35],\n", 472 | " [.29, .64, .05],\n", 473 | " [.50, .32, .13],\n", 474 | " [.46, .36, .17],\n", 475 | " [.49, .32, .18],\n", 476 | " [.49, .32, .18]\n", 477 | " ])\n", 478 | "M_sur = (1. - M_sur)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "Once again, find the best parameters" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": { 492 | "collapsed": false 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "best_score, best_q, best_l = CV_Local_Inference(Voters_By_County, M_sur, J, Ethnicity_Marginals, Party_Marginals,\n", 497 | " CV_counties,q,l)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "Using these parameters, run the inference on the rest of the dataset" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "metadata": { 511 | "collapsed": false 512 | }, 513 | "outputs": [], 514 | "source": [ 515 | "J_sur = Local_Inference(Voters_By_County, M_sur, J, Ethnicity_Marginals, Party_Marginals, all_counties, best_q, best_l)\n", 516 | "kl, std = KL(J, J_sur, all_counties, save_to_file=False, compute_abs_err=True)" 517 | ] 518 | }, 519 | { 520 | "cell_type": "markdown", 521 | "metadata": {}, 522 | "source": [ 523 | "Plot correlation with ground truth" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": { 530 | "collapsed": true 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "j_sur = []\n", 535 | "for c in all_counties:\n", 536 | " j_sur.append(np.array(J_sur[c]).flatten())\n", 537 | "\n", 538 | "j_sur = np.array(j_sur).flatten()\n", 539 | "\n", 540 | "plt.figure()\n", 541 | "plt.scatter(j_true, j_sur, alpha=0.5)\n", 542 | "plt.xlabel('Ground truth')\n", 543 | "plt.ylabel('TROT (survey)')\n", 544 | "plt.plot(diag, diag, 'r--')\n", 545 | "\n", 546 | "#plt.show()\n", 547 | " " 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | "Plot error distribution (compared with Florida average)" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": null, 560 | "metadata": { 561 | "collapsed": false 562 | }, 563 | "outputs": [], 564 | "source": [ 565 | "plt.figure()\n", 566 | "bins = np.arange(-.3, .6, 0.01)\n", 567 | "plt.hist(j_true - j_sur, bins=bins, alpha=0.5, label='TROT (survey)')\n", 568 | "plt.hist(j_true - j_baseline, bins=bins, alpha=0.5, label='Florida-average')\n", 569 | "plt.legend()\n", 570 | "plt.xlabel('Difference between inference and ground truth')\n", 571 | "\n", 572 | "#plt.show()" 573 | ] 574 | } 575 | ], 576 | "metadata": { 577 | "anaconda-cloud": {}, 578 | "kernelspec": { 579 | "display_name": "Python 3", 580 | "language": "python", 581 | "name": "python3" 582 | }, 583 | "language_info": { 584 | "codemirror_mode": { 585 | "name": "ipython", 586 | "version": 3 587 | }, 588 | "file_extension": ".py", 589 | "mimetype": 
"text/x-python", 590 | "name": "python", 591 | "nbconvert_exporter": "python", 592 | "pygments_lexer": "ipython3", 593 | "version": "3.5.1" 594 | } 595 | }, 596 | "nbformat": 4, 597 | "nbformat_minor": 0 598 | } 599 | -------------------------------------------------------------------------------- /Notebooks/.ipynb_checkpoints/Ecological Inference-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Ecological Inference through Tsallis Regularized Optimal Transport (TROT)\n", 10 | "This notebook presents the pipeline used in (cite our paper) to perform ecological inference on the Florida dataset.\n", 11 | "\n", 12 | "You will first want to download the dataset from (url to the dataset)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": { 19 | "collapsed": false 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import pandas as pd\n", 24 | "import numpy as np\n", 25 | "import pickle\n", 26 | "from matplotlib import pyplot as plt\n", 27 | "from matplotlib.pylab import savefig\n", 28 | "\n", 29 | "import sys\n", 30 | "sys.path.append('..')\n", 31 | "sys.path.append('../Trot')\n", 32 | "sys.path.append('../Data')\n", 33 | "\n", 34 | "from Trot import Distances as dist\n", 35 | "from Trot.Evaluation import KL\n", 36 | "from Trot.Florida_inference import CV_Local_Inference, Local_Inference" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Data Loading and Processing" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "FlData = pd.read_csv('../Fl_Data.csv', usecols = ['District', 'County','Voters_Age', 'Voters_Gender', 'PID', 'vote08', \n", 55 | " 'SR.WHI', 'SR.BLA', 'SR.HIS', 'SR.ASI', 'SR.NAT', 'SR.OTH']) \n", 56 | "\n", 57 | "FlData = FlData.dropna()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "Change gender values to numerical values" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "FlData['Voters_Gender'] = FlData['Voters_Gender'].map({'M': 1, 'F': 0})" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "Renormalize the age so that it takes values between 0 and 1" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "FlData['Voters_Age'] = ((FlData['Voters_Age'] -\n", 94 | " FlData['Voters_Age'].min()) /\n", 95 | " (FlData['Voters_Age'].max() -\n", 96 | " FlData['Voters_Age'].min()))\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "One-hot party subscriptions (PID)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "#Get one hot encoding of column PID\n", 115 | "one_hot = pd.get_dummies(FlData['PID'])\n", 116 | "# Drop column PID as it is now encoded\n", 117 | "FlData = FlData.drop('PID', axis=1)\n", 118 | "# Join the encoded df\n", 119 | "FlData = FlData.join(one_hot)\n", 120 | "# Rename the new columns\n", 121 | 
"FlData.rename(columns={0: 'Other', 1: 'Democrat', 2: 'Republican'},\n", 122 | " inplace=True)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 29, 128 | "metadata": { 129 | "collapsed": false 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/html": [ 135 | "
\n", 136 | "\n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | "
DistrictCountyVoters_AgeVoters_Gendervote08SR.WHISR.BLASR.HISSR.ASISR.NATSR.OTHOtherDemocratRepublican
count18205.0000001820518205.00000018205.00000018205.00000018205.00000018205.00000018205.00000018205.00000018205.00000018205.00000018205.00000018205.00000018205.000000
mean5.429827120.3611230.4730020.8201590.7687450.1039270.0497670.0264760.0035160.0496020.1993960.5432570.257347
std1.17707200.2406380.4992840.3840650.4216470.3051750.2174680.1605510.0591890.2171270.3995570.4981390.437184
min3.000000120.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
25%6.000000120.1250000.0000001.0000001.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
50%6.000000120.3500000.0000001.0000001.0000000.0000000.0000000.0000000.0000000.0000000.0000001.0000000.000000
75%6.000000120.5500001.0000001.0000001.0000000.0000000.0000000.0000000.0000000.0000000.0000001.0000001.000000
max6.000000121.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
\n", 295 | "
" 296 | ], 297 | "text/plain": [ 298 | " District County Voters_Age Voters_Gender vote08 \\\n", 299 | "count 18205.000000 18205 18205.000000 18205.000000 18205.000000 \n", 300 | "mean 5.429827 12 0.361123 0.473002 0.820159 \n", 301 | "std 1.177072 0 0.240638 0.499284 0.384065 \n", 302 | "min 3.000000 12 0.000000 0.000000 0.000000 \n", 303 | "25% 6.000000 12 0.125000 0.000000 1.000000 \n", 304 | "50% 6.000000 12 0.350000 0.000000 1.000000 \n", 305 | "75% 6.000000 12 0.550000 1.000000 1.000000 \n", 306 | "max 6.000000 12 1.000000 1.000000 1.000000 \n", 307 | "\n", 308 | " SR.WHI SR.BLA SR.HIS SR.ASI SR.NAT \\\n", 309 | "count 18205.000000 18205.000000 18205.000000 18205.000000 18205.000000 \n", 310 | "mean 0.768745 0.103927 0.049767 0.026476 0.003516 \n", 311 | "std 0.421647 0.305175 0.217468 0.160551 0.059189 \n", 312 | "min 0.000000 0.000000 0.000000 0.000000 0.000000 \n", 313 | "25% 1.000000 0.000000 0.000000 0.000000 0.000000 \n", 314 | "50% 1.000000 0.000000 0.000000 0.000000 0.000000 \n", 315 | "75% 1.000000 0.000000 0.000000 0.000000 0.000000 \n", 316 | "max 1.000000 1.000000 1.000000 1.000000 1.000000 \n", 317 | "\n", 318 | " SR.OTH Other Democrat Republican \n", 319 | "count 18205.000000 18205.000000 18205.000000 18205.000000 \n", 320 | "mean 0.049602 0.199396 0.543257 0.257347 \n", 321 | "std 0.217127 0.399557 0.498139 0.437184 \n", 322 | "min 0.000000 0.000000 0.000000 0.000000 \n", 323 | "25% 0.000000 0.000000 0.000000 0.000000 \n", 324 | "50% 0.000000 0.000000 1.000000 0.000000 \n", 325 | "75% 0.000000 0.000000 1.000000 1.000000 \n", 326 | "max 1.000000 1.000000 1.000000 1.000000 " 327 | ] 328 | }, 329 | "execution_count": 29, 330 | "metadata": {}, 331 | "output_type": "execute_result" 332 | } 333 | ], 334 | "source": [ 335 | "FlData.describe()" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": {}, 341 | "source": [ 342 | "# Compute Marginals and Joint Distributions" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": {}, 348 | "source": [ 349 | "Create a county dictionnary" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 30, 355 | "metadata": { 356 | "collapsed": false 357 | }, 358 | "outputs": [], 359 | "source": [ 360 | "Voters_By_County = {}\n", 361 | "all_counties = FlData.County.unique()\n", 362 | "for county in all_counties:\n", 363 | " Voters_By_County[county] = FlData[FlData['County'] == county]" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": {}, 369 | "source": [ 370 | "Compute the ground truth joint distribution" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 31, 376 | "metadata": { 377 | "collapsed": false 378 | }, 379 | "outputs": [], 380 | "source": [ 381 | "J = {}\n", 382 | "for county in all_counties:\n", 383 | " J[county] = np.zeros((6, 3))\n", 384 | "\n", 385 | " J[county][0,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 386 | " J[county][0,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 387 | " J[county][0,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.WHI']==1)].shape[0]\n", 388 | "\n", 389 | " J[county][1,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 390 | " J[county][1,1] = 
Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 391 | " J[county][1,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.BLA']==1)].shape[0]\n", 392 | "\n", 393 | " J[county][2,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 394 | " J[county][2,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 395 | " J[county][2,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.HIS']==1)].shape[0]\n", 396 | "\n", 397 | " J[county][3,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 398 | " J[county][3,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 399 | " J[county][3,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.ASI']==1)].shape[0]\n", 400 | "\n", 401 | " J[county][4,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) &(Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 402 | " J[county][4,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 403 | " J[county][4,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.NAT']==1)].shape[0]\n", 404 | "\n", 405 | " J[county][5,0] = Voters_By_County[county].loc[(Voters_By_County[county]['Other'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 406 | " J[county][5,1] = Voters_By_County[county].loc[(Voters_By_County[county]['Democrat'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 407 | " J[county][5,2] = Voters_By_County[county].loc[(Voters_By_County[county]['Republican'] ==1) & (Voters_By_County[county]['SR.OTH']==1)].shape[0]\n", 408 | "\n", 409 | " J[county] /= J[county].sum()" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 32, 415 | "metadata": { 416 | "collapsed": false 417 | }, 418 | "outputs": [ 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "[[ 0.14225414 0.39540621 0.22952527]\n", 424 | " [ 0.01178599 0.08853196 0.00339875]\n", 425 | " [ 0.01584256 0.02318825 0.0106348 ]\n", 426 | " [ 0.01074444 0.01057998 0.00509813]\n", 427 | " [ 0.00076746 0.00197347 0.00076746]\n", 428 | " [ 0.01792567 0.02351716 0.00805833]]\n" 429 | ] 430 | } 431 | ], 432 | "source": [ 433 | "print(J[12])" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "Compute the party marginals" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 33, 446 | "metadata": { 447 | "collapsed": false 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "Party_Marginals = {}\n", 452 | "parties = ['Other', 'Democrat', 'Republican']\n", 453 | "for county in all_counties:\n", 454 | " Party_Marginals[county] = pd.Series([J[county][:, i].sum()\n", 455 | " for i in np.arange(3)])\n", 456 | " Party_Marginals[county].index = parties" 457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": {}, 462 | "source": [ 463 | "Compute the ethnicity marginals" 464 | ] 465 | }, 466 | { 467 | 
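Editor's note on the marginal computations in the neighbouring cells: both dictionaries are nothing more than row and column sums of each county's 6 x 3 joint matrix `J[county]` (rows index ethnic groups, columns index parties, and the entries sum to 1 after the normalization above). The sketch below is an equivalent vectorized restatement assuming plain NumPy/pandas; the helper name `marginals_from_joint` is illustrative and is not part of the repository.

```python
import numpy as np
import pandas as pd

def marginals_from_joint(J, ethnies, parties):
    """Row/column sums of each county's joint distribution.

    Returns two dicts of pandas Series shaped like the notebook's
    Ethnicity_Marginals and Party_Marginals.
    """
    ethnicity_marginals, party_marginals = {}, {}
    for county, joint in J.items():
        joint = np.asarray(joint)
        # Rows are ethnic groups, columns are parties; the joint sums to 1.
        ethnicity_marginals[county] = pd.Series(joint.sum(axis=1), index=ethnies)
        party_marginals[county] = pd.Series(joint.sum(axis=0), index=parties)
    return ethnicity_marginals, party_marginals
```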
"cell_type": "code", 468 | "execution_count": 34, 469 | "metadata": { 470 | "collapsed": false 471 | }, 472 | "outputs": [], 473 | "source": [ 474 | "Ethnicity_Marginals = {}\n", 475 | "ethnies = ['SR.WHI', 'SR.BLA', 'SR.HIS', 'SR.ASI', 'SR.NAT', 'SR.OTH']\n", 476 | "for county in all_counties:\n", 477 | " Ethnicity_Marginals[county] = pd.Series([J[county][i, :].sum()\n", 478 | " for i in np.arange(6)])\n", 479 | " Ethnicity_Marginals[county].index = ethnies" 480 | ] 481 | }, 482 | { 483 | "cell_type": "markdown", 484 | "metadata": {}, 485 | "source": [ 486 | "# Compute the cost matrix\n", 487 | "Using only age, gender, and 2008 vote or abstention" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 35, 493 | "metadata": { 494 | "collapsed": false 495 | }, 496 | "outputs": [], 497 | "source": [ 498 | "features = ['Voters_Age', 'Voters_Gender', 'vote08']\n", 499 | "e_len, p_len = len(ethnies), len(parties)\n", 500 | "M = np.zeros((e_len, p_len))\n", 501 | "for i, e in enumerate(ethnies):\n", 502 | " data_e = FlData[FlData[e] == 1.0]\n", 503 | " average_by_e = data_e[features].mean(axis=0)\n", 504 | " for j, p in enumerate(parties):\n", 505 | " data_p = FlData[FlData[p] == 1.0]\n", 506 | " average_by_p = data_p[features].mean(axis=0)\n", 507 | "\n", 508 | " M[i, j] = np.array(dist.dist_2(average_by_e, average_by_p))" 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": {}, 514 | "source": [ 515 | "# Start the inference" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "Use a specific county or district to select the best parameters" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 36, 528 | "metadata": { 529 | "collapsed": true 530 | }, 531 | "outputs": [], 532 | "source": [ 533 | "CV_counties = FlData[FlData['District'] == 3].County.unique()" 534 | ] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "metadata": {}, 539 | "source": [ 540 | "Find the best parameters" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": 37, 546 | "metadata": { 547 | "collapsed": false 548 | }, 549 | "outputs": [ 550 | { 551 | "name": "stdout", 552 | "output_type": "stream", 553 | "text": [ 554 | "q: 0.50, lambda: 0.0100, KL: 0.07895, STD: 0\n", 555 | "q: 0.50, lambda: 0.1000, KL: 0.07859, STD: 0\n", 556 | "q: 0.50, lambda: 1.0000, KL: 0.07556, STD: 0\n", 557 | "q: 0.50, lambda: 10.0000, KL: 0.1046, STD: 0\n", 558 | "q: 0.50, lambda: 100.0000, KL: 0.3701, STD: 0\n", 559 | "q: 0.60, lambda: 0.0100, KL: 0.07369, STD: 0\n", 560 | "q: 0.60, lambda: 0.1000, KL: 0.07325, STD: 0\n", 561 | "q: 0.60, lambda: 1.0000, KL: 0.06959, STD: 0\n", 562 | "q: 0.60, lambda: 10.0000, KL: 0.1099, STD: 0\n", 563 | "q: 0.60, lambda: 100.0000, KL: 0.3799, STD: 0\n", 564 | "q: 0.70, lambda: 0.0100, KL: 0.06793, STD: 0\n", 565 | "q: 0.70, lambda: 0.1000, KL: 0.06737, STD: 0\n", 566 | "q: 0.70, lambda: 1.0000, KL: 0.0628, STD: 0\n", 567 | "q: 0.70, lambda: 10.0000, KL: 0.1208, STD: 0\n", 568 | "q: 0.70, lambda: 100.0000, KL: 0.3884, STD: 0\n", 569 | "q: 0.80, lambda: 0.0100, KL: 0.06197, STD: 0\n", 570 | "q: 0.80, lambda: 0.1000, KL: 0.06123, STD: 0\n", 571 | "q: 0.80, lambda: 1.0000, KL: 0.05535, STD: 0\n", 572 | "q: 0.80, lambda: 10.0000, KL: 0.1395, STD: 0\n", 573 | "q: 0.80, lambda: 100.0000, KL: 0.3934, STD: 0\n", 574 | "q: 0.90, lambda: 0.0100, KL: 0.05649, STD: 0\n", 575 | "q: 0.90, lambda: 0.1000, KL: 0.05549, STD: 0\n", 576 | "q: 0.90, lambda: 1.0000, KL: 0.04767, STD: 0\n", 
577 | "q: 0.90, lambda: 10.0000, KL: 0.1685, STD: 0\n", 578 | "q: 0.90, lambda: 100.0000, KL: 0.3953, STD: 0\n", 579 | "q: 1.00, lambda: 0.0100, KL: 0.05186, STD: 0\n", 580 | "q: 1.00, lambda: 0.1000, KL: 0.05049, STD: 0\n", 581 | "q: 1.00, lambda: 1.0000, KL: 0.04006, STD: 0\n", 582 | "q: 1.00, lambda: 10.0000, KL: 0.2093, STD: 0\n", 583 | "q: 1.00, lambda: 100.0000, KL: 0.3979, STD: 0\n", 584 | "q: 1.10, lambda: 0.0100, KL: 0.05202, STD: 0\n", 585 | "q: 1.10, lambda: 0.1000, KL: 0.05202, STD: 0\n", 586 | "q: 1.10, lambda: 1.0000, KL: 0.04534, STD: 0\n", 587 | "q: 1.10, lambda: 10.0000, KL: 0.09287, STD: 0\n", 588 | "q: 1.10, lambda: 100.0000, KL: 0.06961, STD: 0\n", 589 | "q: 1.20, lambda: 0.0100, KL: 0.05202, STD: 0\n", 590 | "q: 1.20, lambda: 0.1000, KL: 0.05202, STD: 0\n", 591 | "q: 1.20, lambda: 1.0000, KL: 0.05202, STD: 0\n", 592 | "q: 1.20, lambda: 10.0000, KL: 0.09251, STD: 0\n", 593 | "q: 1.20, lambda: 100.0000, KL: 0.06962, STD: 0\n", 594 | "q: 1.30, lambda: 0.0100, KL: 0.05202, STD: 0\n", 595 | "q: 1.30, lambda: 0.1000, KL: 0.05202, STD: 0\n", 596 | "q: 1.30, lambda: 1.0000, KL: 0.05202, STD: 0\n", 597 | "q: 1.30, lambda: 10.0000, KL: 0.09182, STD: 0\n", 598 | "q: 1.30, lambda: 100.0000, KL: 0.0696, STD: 0\n", 599 | "q: 1.40, lambda: 0.0100, KL: 0.05202, STD: 0\n", 600 | "q: 1.40, lambda: 0.1000, KL: 0.05202, STD: 0\n", 601 | "q: 1.40, lambda: 1.0000, KL: 0.05202, STD: 0\n", 602 | "q: 1.40, lambda: 10.0000, KL: 0.09087, STD: 0\n", 603 | "q: 1.40, lambda: 100.0000, KL: 0.06955, STD: 0\n", 604 | "q: 1.50, lambda: 0.0100, KL: 0.05202, STD: 0\n", 605 | "q: 1.50, lambda: 0.1000, KL: 0.05202, STD: 0\n", 606 | "q: 1.50, lambda: 1.0000, KL: 0.05202, STD: 0\n", 607 | "q: 1.50, lambda: 10.0000, KL: 0.08974, STD: 0\n", 608 | "q: 1.50, lambda: 100.0000, KL: 0.06948, STD: 0\n", 609 | "q: 1.60, lambda: 0.0100, KL: 0.05202, STD: 0\n", 610 | "q: 1.60, lambda: 0.1000, KL: 0.05202, STD: 0\n", 611 | "q: 1.60, lambda: 1.0000, KL: 0.05202, STD: 0\n", 612 | "q: 1.60, lambda: 10.0000, KL: 0.0885, STD: 0\n", 613 | "q: 1.60, lambda: 100.0000, KL: 0.06939, STD: 0\n", 614 | "q: 1.70, lambda: 0.0100, KL: 0.05202, STD: 0\n", 615 | "q: 1.70, lambda: 0.1000, KL: 0.05202, STD: 0\n", 616 | "q: 1.70, lambda: 1.0000, KL: 0.05202, STD: 0\n", 617 | "q: 1.70, lambda: 10.0000, KL: 0.0872, STD: 0\n", 618 | "q: 1.70, lambda: 100.0000, KL: 0.06929, STD: 0\n", 619 | "q: 1.80, lambda: 0.0100, KL: 0.05202, STD: 0\n", 620 | "q: 1.80, lambda: 0.1000, KL: 0.05202, STD: 0\n", 621 | "q: 1.80, lambda: 1.0000, KL: 0.05202, STD: 0\n", 622 | "q: 1.80, lambda: 10.0000, KL: 0.08589, STD: 0\n", 623 | "q: 1.80, lambda: 100.0000, KL: 0.06919, STD: 0\n", 624 | "q: 1.90, lambda: 0.0100, KL: 0.05202, STD: 0\n", 625 | "q: 1.90, lambda: 0.1000, KL: 0.05202, STD: 0\n", 626 | "q: 1.90, lambda: 1.0000, KL: 0.05202, STD: 0\n", 627 | "q: 1.90, lambda: 10.0000, KL: 0.08459, STD: 0\n", 628 | "q: 1.90, lambda: 100.0000, KL: 0.06908, STD: 0\n", 629 | "q: 2.00, lambda: 0.0100, KL: 0.05202, STD: 0\n", 630 | "q: 2.00, lambda: 0.1000, KL: 0.05202, STD: 0\n", 631 | "q: 2.00, lambda: 1.0000, KL: 0.05202, STD: 0\n", 632 | "q: 2.00, lambda: 10.0000, KL: 0.08333, STD: 0\n", 633 | "q: 2.00, lambda: 100.0000, KL: 0.06897, STD: 0\n", 634 | "Best score: 0.04006, Best q: 1.00, Best lambda: 1.0000\t Standard Variance: 0\n", 635 | "\n" 636 | ] 637 | } 638 | ], 639 | "source": [ 640 | "q = np.arange(0.5, 2.1, 0.1)\n", 641 | "l = [0.01, 0.1, 1., 10., 100.] 
\n", 642 | "\n", 643 | "best_score, best_q, best_l = CV_Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals,\n", 644 | " CV_counties,q,l)" 645 | ] 646 | }, 647 | { 648 | "cell_type": "markdown", 649 | "metadata": { 650 | "collapsed": true 651 | }, 652 | "source": [ 653 | "Use selected parameters on the rest of the dataset" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": 38, 659 | "metadata": { 660 | "collapsed": false 661 | }, 662 | "outputs": [ 663 | { 664 | "name": "stdout", 665 | "output_type": "stream", 666 | "text": [ 667 | "Absolute error 0.0075675635082 + 0.0\n" 668 | ] 669 | } 670 | ], 671 | "source": [ 672 | "J_inferred = Local_Inference(Voters_By_County, M, J, Ethnicity_Marginals, Party_Marginals, all_counties, best_q, best_l)\n", 673 | "kl, std = KL(J, J_inferred, all_counties, save_to_file=False, compute_abs_err=True)" 674 | ] 675 | }, 676 | { 677 | "cell_type": "markdown", 678 | "metadata": {}, 679 | "source": [ 680 | "# Plot the results" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 39, 686 | "metadata": { 687 | "collapsed": false 688 | }, 689 | "outputs": [], 690 | "source": [ 691 | "diag = np.linspace(-0.1, 1.0, 100)\n", 692 | "\n", 693 | "# pickle results\n", 694 | "f = open('../Data/joints_gallup.pkl', 'rb')\n", 695 | "J_true, J = pickle.load(f)\n", 696 | "\n", 697 | "f = open('../Data/baseline.pkl', 'rb')\n", 698 | "J_baseline = pickle.load(f)\n", 699 | "\n", 700 | "j_true, j, j_baseline = [], [], []\n", 701 | "for c in all_counties:\n", 702 | " j_true.append(np.array(J_true[c]).flatten())\n", 703 | " j.append(np.array(J_inferred[c]).flatten())\n", 704 | " j_baseline.append(np.array(J_baseline[c]).flatten())\n", 705 | "\n", 706 | "j_true = np.array(j_true).flatten()\n", 707 | "j = np.array(j).flatten()\n", 708 | "j_baseline = np.array(j_baseline).flatten()" 709 | ] 710 | }, 711 | { 712 | "cell_type": "markdown", 713 | "metadata": {}, 714 | "source": [ 715 | "Plot the correlation between the ground truth for the joint distribution and the infered distribution (the closer to the $x = y$ diagonal axis, the better" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": 40, 721 | "metadata": { 722 | "collapsed": false 723 | }, 724 | "outputs": [], 725 | "source": [ 726 | "plt.figure()\n", 727 | "plt.scatter(j_true, j, alpha=0.5)\n", 728 | "plt.xlabel('Ground truth')\n", 729 | "plt.ylabel('TROT (RBF)')\n", 730 | "plt.plot(diag, diag, 'r--')\n", 731 | "\n", 732 | "plt.show()" 733 | ] 734 | }, 735 | { 736 | "cell_type": "markdown", 737 | "metadata": {}, 738 | "source": [ 739 | "Plot the distribution of the error (the more packed around the origin of the $x$-axis, the better)" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": 41, 745 | "metadata": { 746 | "collapsed": false 747 | }, 748 | "outputs": [], 749 | "source": [ 750 | "plt.figure()\n", 751 | "bins = np.arange(-.3, .6, 0.01)\n", 752 | "plt.hist(j_true - j, bins=bins, alpha=0.5, label='TROT')\n", 753 | "plt.hist(j_true - j_baseline, bins=bins, alpha=0.5, label='Florida-average')\n", 754 | "plt.legend()\n", 755 | "plt.xlabel('Difference between inference and ground truth')\n", 756 | "\n", 757 | "plt.show()" 758 | ] 759 | }, 760 | { 761 | "cell_type": "markdown", 762 | "metadata": {}, 763 | "source": [ 764 | "# Survey-based ecological inference\n", 765 | "Same pipeline, but using a cost matrix computed thanks to the 2013 Gallup survey. 
(http://www.gallup.com/poll/160373/democrats-racially-diverse-republicans-mostly-white.aspx)\n", 766 | "\n", 767 | "We assume that Gallup's Other = {Native, Other}\n", 768 | "\n", 769 | "The cost matrix M is computed as $1-p_{ij}$, where $p_{ij}$ is the proportion of people registered to party $j$ belonging to group $i$." 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 42, 775 | "metadata": { 776 | "collapsed": true 777 | }, 778 | "outputs": [], 779 | "source": [ 780 | "M_sur = np.array([\n", 781 | " [.38, .26, .35],\n", 782 | " [.29, .64, .05],\n", 783 | " [.50, .32, .13],\n", 784 | " [.46, .36, .17],\n", 785 | " [.49, .32, .18],\n", 786 | " [.49, .32, .18]\n", 787 | " ])\n", 788 | "M_sur = (1. - M_sur)" 789 | ] 790 | }, 791 | { 792 | "cell_type": "markdown", 793 | "metadata": {}, 794 | "source": [ 795 | "Once again, find the best parameters" 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "execution_count": 43, 801 | "metadata": { 802 | "collapsed": false 803 | }, 804 | "outputs": [ 805 | { 806 | "name": "stdout", 807 | "output_type": "stream", 808 | "text": [ 809 | "q: 0.50, lambda: 0.0100, KL: 0.1492, STD: 0\n", 810 | "q: 0.50, lambda: 0.1000, KL: 0.1473, STD: 0\n", 811 | "q: 0.50, lambda: 1.0000, KL: 0.1295, STD: 0\n", 812 | "q: 0.50, lambda: 10.0000, KL: 0.04164, STD: 0\n", 813 | "q: 0.50, lambda: 100.0000, KL: 0.08021, STD: 0\n", 814 | "q: 0.60, lambda: 0.0100, KL: 0.1409, STD: 0\n", 815 | "q: 0.60, lambda: 0.1000, KL: 0.1388, STD: 0\n", 816 | "q: 0.60, lambda: 1.0000, KL: 0.1189, STD: 0\n", 817 | "q: 0.60, lambda: 10.0000, KL: 0.03358, STD: 0\n", 818 | "q: 0.60, lambda: 100.0000, KL: 0.09233, STD: 0\n", 819 | "q: 0.70, lambda: 0.0100, KL: 0.1315, STD: 0\n", 820 | "q: 0.70, lambda: 0.1000, KL: 0.1291, STD: 0\n", 821 | "q: 0.70, lambda: 1.0000, KL: 0.1065, STD: 0\n", 822 | "q: 0.70, lambda: 10.0000, KL: 0.02749, STD: 0\n", 823 | "q: 0.70, lambda: 100.0000, KL: 0.1081, STD: 0\n", 824 | "q: 0.80, lambda: 0.0100, KL: 0.1212, STD: 0\n", 825 | "q: 0.80, lambda: 0.1000, KL: 0.1184, STD: 0\n", 826 | "q: 0.80, lambda: 1.0000, KL: 0.09284, STD: 0\n", 827 | "q: 0.80, lambda: 10.0000, KL: 0.0255, STD: 0\n", 828 | "q: 0.80, lambda: 100.0000, KL: 0.1252, STD: 0\n", 829 | "q: 0.90, lambda: 0.0100, KL: 0.1096, STD: 0\n", 830 | "q: 0.90, lambda: 0.1000, KL: 0.1064, STD: 0\n", 831 | "q: 0.90, lambda: 1.0000, KL: 0.07785, STD: 0\n", 832 | "q: 0.90, lambda: 10.0000, KL: 0.029, STD: 0\n", 833 | "q: 0.90, lambda: 100.0000, KL: 0.145, STD: 0\n", 834 | "q: 1.00, lambda: 0.0100, KL: 0.09918, STD: 0\n", 835 | "q: 1.00, lambda: 0.1000, KL: 0.09545, STD: 0\n", 836 | "q: 1.00, lambda: 1.0000, KL: 0.06365, STD: 0\n", 837 | "q: 1.00, lambda: 10.0000, KL: 0.03764, STD: 0\n", 838 | "q: 1.00, lambda: 100.0000, KL: 0.1643, STD: 0\n", 839 | "q: 1.10, lambda: 0.0100, KL: 0.0996, STD: 0\n", 840 | "q: 1.10, lambda: 0.1000, KL: 0.0996, STD: 0\n", 841 | "q: 1.10, lambda: 1.0000, KL: 0.07086, STD: 0\n", 842 | "q: 1.10, lambda: 10.0000, KL: 0.0241, STD: 0\n", 843 | "q: 1.10, lambda: 100.0000, KL: 0.02334, STD: 0\n", 844 | "q: 1.20, lambda: 0.0100, KL: 0.0996, STD: 0\n", 845 | "q: 1.20, lambda: 0.1000, KL: 0.0996, STD: 0\n", 846 | "q: 1.20, lambda: 1.0000, KL: 0.07716, STD: 0\n", 847 | "q: 1.20, lambda: 10.0000, KL: 0.02393, STD: 0\n", 848 | "q: 1.20, lambda: 100.0000, KL: 0.02334, STD: 0\n", 849 | "q: 1.30, lambda: 0.0100, KL: 0.0996, STD: 0\n", 850 | "q: 1.30, lambda: 0.1000, KL: 0.0996, STD: 0\n", 851 | "q: 1.30, lambda: 1.0000, KL: 0.08202, STD: 0\n", 852 | "q: 1.30, lambda: 10.0000, 
KL: 0.02393, STD: 0\n", 853 | "q: 1.30, lambda: 100.0000, KL: 0.02334, STD: 0\n", 854 | "q: 1.40, lambda: 0.0100, KL: 0.0996, STD: 0\n", 855 | "q: 1.40, lambda: 0.1000, KL: 0.0996, STD: 0\n", 856 | "q: 1.40, lambda: 1.0000, KL: 0.0996, STD: 0\n", 857 | "q: 1.40, lambda: 10.0000, KL: 0.02396, STD: 0\n", 858 | "q: 1.40, lambda: 100.0000, KL: 0.02334, STD: 0\n", 859 | "q: 1.50, lambda: 0.0100, KL: 0.0996, STD: 0\n", 860 | "q: 1.50, lambda: 0.1000, KL: 0.0996, STD: 0\n", 861 | "q: 1.50, lambda: 1.0000, KL: 0.0996, STD: 0\n", 862 | "q: 1.50, lambda: 10.0000, KL: 0.02398, STD: 0\n", 863 | "q: 1.50, lambda: 100.0000, KL: 0.02334, STD: 0\n", 864 | "q: 1.60, lambda: 0.0100, KL: 0.0996, STD: 0\n", 865 | "q: 1.60, lambda: 0.1000, KL: 0.0996, STD: 0\n", 866 | "q: 1.60, lambda: 1.0000, KL: 0.0996, STD: 0\n", 867 | "q: 1.60, lambda: 10.0000, KL: 0.02398, STD: 0\n", 868 | "q: 1.60, lambda: 100.0000, KL: 0.02333, STD: 0\n", 869 | "q: 1.70, lambda: 0.0100, KL: 0.0996, STD: 0\n", 870 | "q: 1.70, lambda: 0.1000, KL: 0.0996, STD: 0\n", 871 | "q: 1.70, lambda: 1.0000, KL: 0.0996, STD: 0\n", 872 | "q: 1.70, lambda: 10.0000, KL: 0.02395, STD: 0\n", 873 | "q: 1.70, lambda: 100.0000, KL: 0.02333, STD: 0\n", 874 | "q: 1.80, lambda: 0.0100, KL: 0.0996, STD: 0\n", 875 | "q: 1.80, lambda: 0.1000, KL: 0.0996, STD: 0\n", 876 | "q: 1.80, lambda: 1.0000, KL: 0.0996, STD: 0\n", 877 | "q: 1.80, lambda: 10.0000, KL: 0.02391, STD: 0\n", 878 | "q: 1.80, lambda: 100.0000, KL: 0.02333, STD: 0\n", 879 | "q: 1.90, lambda: 0.0100, KL: 0.0996, STD: 0\n", 880 | "q: 1.90, lambda: 0.1000, KL: 0.0996, STD: 0\n", 881 | "q: 1.90, lambda: 1.0000, KL: 0.0996, STD: 0\n", 882 | "q: 1.90, lambda: 10.0000, KL: 0.02385, STD: 0\n", 883 | "q: 1.90, lambda: 100.0000, KL: 0.02333, STD: 0\n", 884 | "q: 2.00, lambda: 0.0100, KL: 0.0996, STD: 0\n", 885 | "q: 2.00, lambda: 0.1000, KL: 0.0996, STD: 0\n", 886 | "q: 2.00, lambda: 1.0000, KL: 0.08625, STD: 0\n", 887 | "q: 2.00, lambda: 10.0000, KL: 0.0238, STD: 0\n", 888 | "q: 2.00, lambda: 100.0000, KL: 0.02333, STD: 0\n", 889 | "Best score: 0.02333, Best q: 2.00, Best lambda: 100.0000\t Standard Variance: 0\n", 890 | "\n" 891 | ] 892 | } 893 | ], 894 | "source": [ 895 | "best_score, best_q, best_l = CV_Local_Inference(Voters_By_County, M_sur, J, Ethnicity_Marginals, Party_Marginals,\n", 896 | " CV_counties,q,l)" 897 | ] 898 | }, 899 | { 900 | "cell_type": "markdown", 901 | "metadata": {}, 902 | "source": [ 903 | "Using these parameters, run the inference on the rest of the dataset" 904 | ] 905 | }, 906 | { 907 | "cell_type": "code", 908 | "execution_count": 44, 909 | "metadata": { 910 | "collapsed": false 911 | }, 912 | "outputs": [ 913 | { 914 | "name": "stdout", 915 | "output_type": "stream", 916 | "text": [ 917 | "Absolute error 0.0094930050661 + 0.0\n" 918 | ] 919 | } 920 | ], 921 | "source": [ 922 | "J_sur = Local_Inference(Voters_By_County, M_sur, J, Ethnicity_Marginals, Party_Marginals, all_counties, best_q, best_l)\n", 923 | "kl, std = KL(J, J_sur, all_counties, save_to_file=False, compute_abs_err=True)" 924 | ] 925 | }, 926 | { 927 | "cell_type": "markdown", 928 | "metadata": {}, 929 | "source": [ 930 | "Plot correlation with ground truth" 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "execution_count": 45, 936 | "metadata": { 937 | "collapsed": true 938 | }, 939 | "outputs": [], 940 | "source": [ 941 | "j_sur = []\n", 942 | "for c in all_counties:\n", 943 | " j_sur.append(np.array(J_sur[c]).flatten())\n", 944 | "\n", 945 | "j_sur = np.array(j_sur).flatten()\n", 946 | "\n", 947 | 
"plt.figure()\n", 948 | "plt.scatter(j_true, j_sur, alpha=0.5)\n", 949 | "plt.xlabel('Ground truth')\n", 950 | "plt.ylabel('TROT (survey)')\n", 951 | "plt.plot(diag, diag, 'r--')\n", 952 | "\n", 953 | "plt.show()\n", 954 | " " 955 | ] 956 | }, 957 | { 958 | "cell_type": "markdown", 959 | "metadata": {}, 960 | "source": [ 961 | "Plot error distribution (compared with Florida average)" 962 | ] 963 | }, 964 | { 965 | "cell_type": "code", 966 | "execution_count": 46, 967 | "metadata": { 968 | "collapsed": false 969 | }, 970 | "outputs": [], 971 | "source": [ 972 | "plt.figure()\n", 973 | "bins = np.arange(-.3, .6, 0.01)\n", 974 | "plt.hist(j_true - j_sur, bins=bins, alpha=0.5, label='TROT (survey)')\n", 975 | "plt.hist(j_true - j_baseline, bins=bins, alpha=0.5, label='Florida-average')\n", 976 | "plt.legend()\n", 977 | "plt.xlabel('Difference between inference and ground truth')\n", 978 | "\n", 979 | "plt.show()" 980 | ] 981 | } 982 | ], 983 | "metadata": { 984 | "anaconda-cloud": {}, 985 | "kernelspec": { 986 | "display_name": "Python [default]", 987 | "language": "python", 988 | "name": "python3" 989 | }, 990 | "language_info": { 991 | "codemirror_mode": { 992 | "name": "ipython", 993 | "version": 3 994 | }, 995 | "file_extension": ".py", 996 | "mimetype": "text/x-python", 997 | "name": "python", 998 | "nbconvert_exporter": "python", 999 | "pygments_lexer": "ipython3", 1000 | "version": "3.5.2" 1001 | } 1002 | }, 1003 | "nbformat": 4, 1004 | "nbformat_minor": 0 1005 | } 1006 | --------------------------------------------------------------------------------