├── .gitattributes ├── .gitignore ├── 00_fixed_effect.stan ├── 01_create_dyadic_data.R ├── 01_srm_stan.stan ├── 02_srm_stan_dyad.stan ├── 03_amen_stan.stan ├── 04_stan_fixed_lower_tri.stan ├── 05_minibatching_sampling.R ├── 05_minibatching_sampling.stan ├── README.md ├── amen_model_writeup.Rmd ├── amen_model_writeup.pdf ├── amen_model_writeup_cache ├── gfm │ ├── __packages │ ├── eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.RData │ ├── eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.rdb │ └── eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.rdx ├── html │ └── __packages └── latex │ └── __packages ├── amen_model_writeup_files ├── figure-gfm │ └── eda_dataviz-1.png ├── figure-html │ ├── eda_dataviz-1.png │ ├── m0_ppc-1.png │ ├── m0_ppc-2.png │ ├── m1_ppc-1.png │ ├── m1_ppc-2.png │ ├── m3_ppc-1.png │ └── m3_ppc-2.png └── figure-latex │ ├── eda_dataviz-1.pdf │ ├── m0_ppc-1.pdf │ ├── m1_ppc-1.pdf │ └── m3_ppc-1.pdf └── lower_tri_fixed_amen.R /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | srm_amen_stan.rdata 3 | *.RData 4 | *.rdx 5 | *.rdb 6 | -------------------------------------------------------------------------------- /00_fixed_effect.stan: -------------------------------------------------------------------------------- 1 | 2 | data{ 3 | int n_nodes ; 4 | int n_dyads ; 5 | int N; //total obs. 
// should be n_dyads * 2
  int sender_id[n_dyads * 2] ;
  int receiver_id[n_dyads * 2] ;
  real Y[n_dyads * 2] ;

}
parameters{
  real intercept ;
  vector[n_nodes] sender_beta;
  vector[n_nodes] receiver_beta;
}

model{
  intercept ~ normal(0, 5) ;
  sender_beta ~ normal(0, 5) ;
  receiver_beta ~ normal(0, 5) ;

  for(n in 1:N){
    Y[n] ~ normal(intercept +
      sender_beta[sender_id[n]] + receiver_beta[receiver_id[n]], 1 );
  }
}
generated quantities{
  real Y_sim[N] ;
  for(n in 1:N){
    Y_sim[n] = normal_rng(intercept +
      sender_beta[sender_id[n]] + receiver_beta[receiver_id[n]], 1) ;
  }

}
--------------------------------------------------------------------------------
/01_create_dyadic_data.R:
--------------------------------------------------------------------------------
library(amen)
library(rstan)
data(IR90s)


# A function to convert a sociomatrix (square matrix of directed pair values)
# to an edgelist with one row per ordered pair of distinct nodes.
# Assumes that rows == columns, but will check
#
# Arguments:
#   sociomatrix_to_convert: square matrix; entry [i, j] is the value sent from
#     node i to node j. The diagonal is ignored and NA entries become 0.
#
# Returns a list with:
#   edgelist:     data.frame sorted by sender then receiver, with columns
#                 sender, receiver (row/column positions in the matrix), y,
#                 dyad_id (shared by i -> j and j -> i) and sr_indicator
#                 (1 if sender < receiver, else 2).
#   lookup_table: data.frame mapping node_names to the idx used in edgelist.
#   n_nodes, n_dyads, N: number of nodes, unordered pairs and edgelist rows.
#
# Stops with an error if the matrix is not square.

matrix_to_edgelist <- function(sociomatrix_to_convert) {

  if (nrow(sociomatrix_to_convert) != ncol(sociomatrix_to_convert)) {
    stop("nrows != ncols")
  }

  n_nodes <- nrow(sociomatrix_to_convert)

  # Node ids are row/column positions, so the lookup table has to use
  # positions too (an alphabetical rank of the names only agrees with the
  # positions when the names happen to be sorted).
  node_names <- rownames(sociomatrix_to_convert)
  if (is.null(node_names)) {
    node_names <- as.character(seq_len(n_nodes))
  }
  lookup_table <- data.frame(node_names = node_names, idx = seq_len(n_nodes))

  # All ordered pairs of distinct nodes, sorted by sender then receiver
  # (expand.grid varies its first argument fastest).
  all_pairs <- expand.grid(
    receiver = seq_len(n_nodes), sender = seq_len(n_nodes))
  all_pairs <- all_pairs[all_pairs$sender != all_pairs$receiver, ]
  sender <- all_pairs$sender
  receiver <- all_pairs$receiver

  # Look every pair up directly in the matrix; missing ties are treated as 0.
  y <- sociomatrix_to_convert[cbind(sender, receiver)]
  y[is.na(y)] <- 0

  # Number the unordered pairs {lo, hi} in the order (1,2), (1,3), ..., so
  # that i -> j and j -> i always share a dyad_id. Recycling 1:n_dyads down
  # the sorted rows does not pair the two directions of a dyad.
  lo <- pmin(sender, receiver)
  hi <- pmax(sender, receiver)
  dyad_id <- (lo - 1L) * n_nodes - (lo * (lo - 1L)) %/% 2L + (hi - lo)

  edgelist <- data.frame(
    sender = sender,
    receiver = receiver,
    y = y,
    dyad_id = dyad_id,
    # for dyad list include a (1, 2) indicator for send/receive
    # if s < r -> 1, else 2
    sr_indicator = ifelse(sender < receiver, 1, 2))

  list(
    edgelist = edgelist, lookup_table = lookup_table,
    n_nodes = n_nodes, n_dyads = (n_nodes * (n_nodes - 1L)) %/% 2L,
    N = nrow(edgelist))
}

data_for_stan <- matrix_to_edgelist(IR90s$dyadvars[, , 2]) # trade data




m0_code <- stan_model(file = "00_fixed_effect.stan"
)

m0 <- vb(m0_code,
  data = list(
    N = data_for_stan$N,
    n_nodes = data_for_stan$n_nodes,
    n_dyads = data_for_stan$n_dyads,
    sender_id = data_for_stan$edgelist[, 1],
    receiver_id = data_for_stan$edgelist[, 2],
    Y = data_for_stan$edgelist[, 3]),
  seed = 123,
  # chains = 4, cores = 4,
  iter = 10000
)

m0_params <- extract(m0)
preds0 <- apply(m0_params$Y_sim, 2, mean)
plot(data_for_stan$edgelist[, 3], preds0)

m1_code <- stan_model(file = "01_srm_stan.stan"
)

m1 <- vb(m1_code,
  data = list(
    N = data_for_stan$N,
    n_nodes = data_for_stan$n_nodes,
    n_dyads = data_for_stan$n_dyads,
    sender_id = data_for_stan$edgelist[, 1],
    receiver_id = data_for_stan$edgelist[, 2],
    Y = data_for_stan$edgelist[, 3]),
  seed = 123,
  # chains = 4, cores = 4,
  iter = 10000
)

m1_params <- extract(m1)
preds <- apply(m1_params$Y_sim, 2, mean)
plot(data_for_stan$edgelist[, 3], preds)

m2_code <- stan_model(file = "02_srm_stan_dyad.stan"
)

m2 <- vb(m2_code,
  data = list(
    N = data_for_stan$N,
    n_nodes = data_for_stan$n_nodes,
    n_dyads = data_for_stan$n_dyads,
    sender_id =
data_for_stan$edgelist[, 1], 100 | receiver_id = data_for_stan$edgelist[, 2], 101 | dyad_id = data_for_stan$edgelist[, 4], 102 | send_receive = data_for_stan$edgelist[, 5], 103 | Y = data_for_stan$edgelist[, 3]), 104 | seed = 123, 105 | # chains = 4, cores = 4, 106 | iter = 10000 107 | ) 108 | 109 | 110 | m2_params <- extract(m2) 111 | preds2 <- apply(m2_params$Y_sim, 2, mean) 112 | plot(data_for_stan$edgelist[, 3], preds2) 113 | 114 | 115 | m3_code <- stan_model(file = "03_amen_stan.stan" 116 | ) 117 | 118 | m3 <- vb(m3_code, 119 | data = list( 120 | N = data_for_stan$N, 121 | n_nodes = data_for_stan$n_nodes, 122 | n_dyads = data_for_stan$n_dyads, 123 | sender_id = data_for_stan$edgelist[, 1], 124 | receiver_id = data_for_stan$edgelist[, 2], 125 | K = 10, 126 | Y = data_for_stan$edgelist[, 3]), 127 | # chains = 4, cores = 4, 128 | iter = 10000, 129 | seed = 123) 130 | 131 | m3_params <- extract(m3) 132 | preds3 <- apply(m3_params$Y_sim, 2, mean) 133 | plot(data_for_stan$edgelist[, 3], preds3) 134 | 135 | mean((data_for_stan$edgelist[, 3] - preds0)^2) 136 | mean((data_for_stan$edgelist[, 3] - preds)^2) 137 | mean((data_for_stan$edgelist[, 3] - preds2)^2) 138 | 139 | mean((data_for_stan$edgelist[, 3] - preds3)^2) 140 | save(m0, m1, m2, m3, file = "srm_amen_stan.rdata") 141 | -------------------------------------------------------------------------------- /01_srm_stan.stan: -------------------------------------------------------------------------------- 1 | 2 | data{ 3 | int n_nodes ; 4 | int n_dyads ; 5 | int N; //total obs. 
// should be n_dyads * 2
  int sender_id[n_dyads * 2] ;
  int receiver_id[n_dyads * 2] ;
  real Y[n_dyads * 2] ;

}
parameters{
  real intercept ;
  cholesky_factor_corr[2] corr_nodes; // Cholesky factor of the sender/receiver correlation matrix
  // sd of the sender and receiver effects. Declared with a lower bound
  // because the gamma prior in the model block is only defined for
  // positive values.
  vector<lower=0>[2] sigma_nodes;
  matrix[2, n_nodes] z_nodes ; // for node non-centered parameterization
}
transformed parameters{
  // Row i holds node i's effects: column 1 is used when the node is the
  // sender, column 2 when it is the receiver.
  matrix[n_nodes, 2] mean_nodes;
  mean_nodes = (diag_pre_multiply(
    sigma_nodes, corr_nodes) * z_nodes)'; // non-centered parameterization
}
model{
  intercept ~ normal(0, 5) ;
  to_vector(z_nodes) ~ normal(0, 1) ;
  corr_nodes ~ lkj_corr_cholesky(5) ;
  sigma_nodes ~ gamma(1, 1) ;

  // Likelihood: unit-variance normal around intercept + sender + receiver effect.
  for(n in 1:N){
    Y[n] ~ normal(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2], 1 );
  }
}
generated quantities{
  // Posterior predictive draw for every observation.
  real Y_sim[N] ;
  for(n in 1:N){
    Y_sim[n] = normal_rng(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2], 1) ;
  }

}
--------------------------------------------------------------------------------
/02_srm_stan_dyad.stan:
--------------------------------------------------------------------------------

data{
  int n_nodes ;
  int n_dyads ;
  int N; //total obs.
// should be n_dyads * 2
  int sender_id[n_dyads * 2] ;
  int receiver_id[n_dyads * 2] ;
  int dyad_id[n_dyads * 2] ;
  int send_receive[n_dyads * 2] ;
  real Y[n_dyads * 2] ;

}
parameters{
  real intercept ;
  cholesky_factor_corr[2] corr_nodes; // Cholesky factor of the sender/receiver effect correlation
  // Standard deviations are declared with a lower bound because their
  // gamma priors in the model block are only defined for positive values.
  vector<lower=0>[2] sigma_nodes; // sd of sender and receiver effects
  matrix[2, n_nodes] z_nodes ; // for node non-centered parameterization
  cholesky_factor_corr[2] corr_dyads; // Cholesky factor of the within-dyad correlation
  real<lower=0> sigma_dyads; // sd shared by both directions of a dyad
  matrix[2, n_dyads] z_dyads ; // for dyad non-centered parameterization
}
transformed parameters{
  // Row d holds dyad d's two directional effects, selected by send_receive.
  matrix[n_dyads,2] mean_dyads;
  // Row i holds node i's (sender, receiver) effects.
  matrix[n_nodes, 2] mean_nodes;

  mean_dyads = (diag_pre_multiply(
    rep_vector(sigma_dyads, 2), corr_dyads) * z_dyads)'; // sd *correlation
  mean_nodes = (diag_pre_multiply(
    sigma_nodes, corr_nodes) * z_nodes)'; // sd *correlation

}
model{
  intercept ~ normal(0, 5) ;
  to_vector(z_nodes) ~ normal(0, 1) ;
  to_vector(z_dyads) ~ normal(0, 1) ;
  corr_nodes ~ lkj_corr_cholesky(5) ;
  corr_dyads ~ lkj_corr_cholesky(5) ;
  sigma_nodes ~ gamma(1, 1) ;
  sigma_dyads ~ gamma(1, 1) ;

  // Likelihood: sender effect + receiver effect + directional dyad effect.
  for(n in 1:N){
    Y[n] ~ normal(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] +
      mean_dyads[dyad_id[n], send_receive[n]], 1 );
  }
}
generated quantities{
  // Posterior predictive draw for every observation.
  real Y_sim[N] ;
  for(n in 1:N){
    Y_sim[n] = normal_rng(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] +
      mean_dyads[dyad_id[n], send_receive[n]], 1) ;
  }

}
--------------------------------------------------------------------------------
/03_amen_stan.stan:
--------------------------------------------------------------------------------

data{
  int n_nodes ;
  int n_dyads ;
  int N; //total obs.
// should be n_dyads * 2
  int sender_id[n_dyads * 2] ;
  int receiver_id[n_dyads * 2] ;
  int K ; // number of latent dimensions
  real Y[n_dyads * 2] ;

}
parameters{
  real intercept ;
  cholesky_factor_corr[2] corr_nodes; // Cholesky factor of the sender/receiver effect correlation
  // Standard deviations are declared with a lower bound because their
  // gamma priors in the model block are only defined for positive values.
  vector<lower=0>[2] sigma_nodes; // sd w/in nodes
  matrix[2, n_nodes] z_nodes ; // for node non-centered parameterization
  cholesky_factor_corr[K * 2] corr_multi_effects ; // correlation matrix for multiplicative effect
  vector<lower=0>[K * 2] sigma_multi_effects ; // sd
  matrix[K * 2, n_nodes] z_multi_effects; // Multi-effect non-centered term

}
transformed parameters{
  matrix[n_nodes, 2] mean_nodes; // node parameter mean
  // Row i holds node i's latent factors: columns 1:K are used when the node
  // sends, columns (K+1):(2K) when it receives.
  matrix[n_nodes, K * 2] mean_multi_effects ; // multi-effect mean

  mean_nodes = (diag_pre_multiply(
    sigma_nodes, corr_nodes) * z_nodes)'; // sd *correlation
  mean_multi_effects = (diag_pre_multiply(
    sigma_multi_effects, corr_multi_effects) * z_multi_effects)'; // sd *correlation

}
model{
  intercept ~ normal(0, 5) ;

  //node terms
  to_vector(z_nodes) ~ normal(0, 1) ;
  corr_nodes ~ lkj_corr_cholesky(5) ;
  sigma_nodes ~ gamma(1, 1) ;

  // multi-effect terms
  to_vector(z_multi_effects) ~ normal(0, 1) ;
  corr_multi_effects ~ lkj_corr_cholesky(5) ;
  sigma_multi_effects ~ gamma(1, 1) ;

  // Likelihood: additive sender/receiver effects plus the inner product of
  // the sender's and the receiver's latent factors.
  for(n in 1:N){
    Y[n] ~ normal(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] +
      mean_multi_effects[sender_id[n], 1:K] *
      (mean_multi_effects[receiver_id[n], (K+1):(K*2)])',
      1 );
  }
}
generated quantities{
  // Posterior predictive draw for every observation.
  real Y_sim[N] ;
  for(n in 1:N){
    Y_sim[n] = normal_rng(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] +
      mean_multi_effects[sender_id[n], 1:K] *
      (mean_multi_effects[receiver_id[n], (K+1):(K*2)])', 1) ;
  }

}

-------------------------------------------------------------------------------- /04_stan_fixed_lower_tri.stan: -------------------------------------------------------------------------------- 1 | // code for an AMEN model where the lower triangle of the covariance matrices 2 | // are fixed, allowing latent components to be identified. 3 | data{ 4 | int n_nodes ; 5 | int n_dyads ; 6 | int N; //total obs. should be n_dyads * 2 7 | int sender_id[n_dyads * 2] ; 8 | int receiver_id[n_dyads * 2] ; 9 | int K ; // number of latent dimensions 10 | real Y[n_dyads * 2] ; 11 | 12 | } 13 | transformed data{ 14 | int lower_tri_size = (K * (K + 1))/2 ; 15 | } 16 | 17 | parameters{ 18 | real intercept ; 19 | cholesky_factor_cov[2] cov_nodes; // correlation matrix w/in hh 20 | matrix[2, n_nodes] z_nodes ; // for node non-centered parameterization 21 | // vector[lower_tri_size] cov_raw_values; 22 | cholesky_factor_cov[K * 2] cov_multi_effects ; // correlation matrix for multiplicative effect 23 | matrix[K * 2, n_nodes] z_multi_effects; // Multi-effect non-centered term 24 | 25 | } 26 | transformed parameters{ 27 | matrix[n_nodes, 2] mean_nodes; // node parameter mean 28 | matrix[n_nodes, K * 2] mean_multi_effects ; // multi-effect mean 29 | // vector [lower_tri_size] cov_values; 30 | 31 | // cov_values = 3 * tan(cov_raw_values); 32 | 33 | mean_nodes = (cov_nodes * z_nodes)' ; 34 | mean_multi_effects = (cov_multi_effects * z_multi_effects)'; // sd *correlation 35 | 36 | } 37 | model{ 38 | intercept ~ normal(0, 5) ; 39 | 40 | //node terms 41 | to_vector(z_nodes) ~ normal(0, 1) ; 42 | to_vector(cov_nodes) ~ student_t(3, 0, 2) ; 43 | 44 | // multi-effect terms 45 | to_vector(z_multi_effects) ~ normal(0, 1) ; 46 | to_vector(cov_multi_effects) ~ student_t(3, 0, 2) ; 47 | 48 | for(n in 1:N){ 49 | Y[n] ~ normal(intercept + 50 | mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] + 51 | mean_multi_effects[sender_id[n], 1:K] * 52 | (mean_multi_effects[receiver_id[n], (K+1):(K*2)])', 
1 );
  }
}
generated quantities{
  real Y_sim[N] ;
  for(n in 1:N){
    Y_sim[n] = normal_rng(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] +
      mean_multi_effects[sender_id[n], 1:K] *
      (mean_multi_effects[receiver_id[n], (K+1):(K*2)])', 1) ;
  }

}
--------------------------------------------------------------------------------
/05_minibatching_sampling.R:
--------------------------------------------------------------------------------
library(amen)
library(rstan)
data(IR90s)


# A function to convert a sociomatrix (square matrix of directed pair values)
# to an edgelist with one row per ordered pair of distinct nodes.
# Assumes that rows == columns, but will check
#
# Arguments:
#   sociomatrix_to_convert: square matrix; entry [i, j] is the value sent from
#     node i to node j. The diagonal is ignored and NA entries become 0.
#
# Returns a list with:
#   edgelist:     data.frame sorted by sender then receiver, with columns
#                 sender, receiver (row/column positions in the matrix), y,
#                 dyad_id (shared by i -> j and j -> i) and sr_indicator
#                 (1 if sender < receiver, else 2).
#   lookup_table: data.frame mapping node_names to the idx used in edgelist.
#   n_nodes, n_dyads, N: number of nodes, unordered pairs and edgelist rows.
#
# Stops with an error if the matrix is not square.

matrix_to_edgelist <- function(sociomatrix_to_convert) {

  if (nrow(sociomatrix_to_convert) != ncol(sociomatrix_to_convert)) {
    stop("nrows != ncols")
  }

  n_nodes <- nrow(sociomatrix_to_convert)

  # Node ids are row/column positions, so the lookup table has to use
  # positions too (an alphabetical rank of the names only agrees with the
  # positions when the names happen to be sorted).
  node_names <- rownames(sociomatrix_to_convert)
  if (is.null(node_names)) {
    node_names <- as.character(seq_len(n_nodes))
  }
  lookup_table <- data.frame(node_names = node_names, idx = seq_len(n_nodes))

  # All ordered pairs of distinct nodes, sorted by sender then receiver
  # (expand.grid varies its first argument fastest).
  all_pairs <- expand.grid(
    receiver = seq_len(n_nodes), sender = seq_len(n_nodes))
  all_pairs <- all_pairs[all_pairs$sender != all_pairs$receiver, ]
  sender <- all_pairs$sender
  receiver <- all_pairs$receiver

  # Look every pair up directly in the matrix; missing ties are treated as 0.
  y <- sociomatrix_to_convert[cbind(sender, receiver)]
  y[is.na(y)] <- 0

  # Number the unordered pairs {lo, hi} in the order (1,2), (1,3), ..., so
  # that i -> j and j -> i always share a dyad_id. Recycling 1:n_dyads down
  # the sorted rows does not pair the two directions of a dyad.
  lo <- pmin(sender, receiver)
  hi <- pmax(sender, receiver)
  dyad_id <- (lo - 1L) * n_nodes - (lo * (lo - 1L)) %/% 2L + (hi - lo)

  edgelist <- data.frame(
    sender = sender,
    receiver = receiver,
    y = y,
    dyad_id = dyad_id,
    # for dyad list include a (1, 2) indicator for send/receive
    # if s < r -> 1, else 2
    sr_indicator = ifelse(sender < receiver, 1, 2))

  list(
    edgelist = edgelist, lookup_table = lookup_table,
    n_nodes = n_nodes, n_dyads = (n_nodes * (n_nodes - 1L)) %/% 2L,
    N = nrow(edgelist))
}

data_for_stan <- matrix_to_edgelist(IR90s$dyadvars[, , 2]) # trade data


--------------------------------------------------------------------------------
/05_minibatching_sampling.stan:
--------------------------------------------------------------------------------
// TODO: go through, figure out how to make batch indexing work for sending/receiving nodes and dyads.
data{
  int n_nodes ;
  int n_dyads ;
  int N; //total obs. should be n_dyads * 2
  int sender_id[n_dyads * 2] ;
  int receiver_id[n_dyads * 2] ;
  int K ; // number of latent dimensions
  real Y[n_dyads * 2] ;
  int batch_size ;
  int dyad_idx[n_dyads] ;
  int n_batches ;
}
transformed data{
  int lower_tri_size = (K * (K + 1))/2 ;
  simplex[n_dyads] uniform = rep_vector(1.0 / n_dyads, n_dyads);
  int batch_idxs[batch_size, n_batches];
  for (n in 1:n_batches)
    for (i in 1:batch_size)
      batch_idxs[i , n] = categorical_rng(uniform) ;
}

parameters{
  real intercept ;
  cholesky_factor_cov[2] cov_nodes; // correlation matrix w/in hh
  matrix[2, n_nodes] z_nodes ; // for node non-centered parameterization
  // vector[lower_tri_size] cov_raw_values;
  cholesky_factor_cov[K * 2] cov_multi_effects ; // correlation matrix for multiplicative effect
  matrix[K * 2, n_nodes] z_multi_effects; // Multi-effect non-centered term

}
transformed parameters{
  matrix[n_nodes, 2] mean_nodes; // node parameter mean
  matrix[n_nodes, K * 2] mean_multi_effects ; // multi-effect mean
  // vector [lower_tri_size] cov_values;

  // cov_values = 3 * tan(cov_raw_values);

  mean_nodes = (cov_nodes * z_nodes)' ;
  mean_multi_effects = (cov_multi_effects * z_multi_effects)'; // sd *correlation

}
model{
  intercept ~ normal(0, 5) ;

  //node terms
  to_vector(z_nodes) ~ normal(0, 1) ;
  to_vector(cov_nodes) ~ student_t(3, 0, 2) ;

  // multi-effect terms
  to_vector(z_multi_effects) ~ normal(0, 1) ;
  to_vector(cov_multi_effects) ~ student_t(3, 0, 2) ;

  // Likelihood over the pre-drawn batches only.
  for(batch in 1:n_batches){
    for(obs in 1:batch_size){
      // Observation selected for this slot of the batch. The sender and
      // receiver lookups must use the same index as Y; they previously
      // referred to an undeclared `n`.
      // NOTE(review): batch_idxs is drawn in transformed data from a
      // categorical over n_dyads outcomes, so n only ranges over 1..n_dyads
      // while Y has n_dyads * 2 entries -- confirm whether dyads or single
      // observations are meant to be sampled.
      int n = batch_idxs[obs, batch] ;

      Y[n] ~ normal(intercept +
        mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] +
        mean_multi_effects[sender_id[n], 1:K] *
        (mean_multi_effects[receiver_id[n], (K+1):(K*2)])',
        1 );
    }
  }
}
generated quantities{
  // Posterior predictive draw for every observation, not just the batches.
  real Y_sim[N] ;
  for(n in 1:N){
    Y_sim[n] = normal_rng(intercept +
      mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] +
      mean_multi_effects[sender_id[n], 1:K] *
      (mean_multi_effects[receiver_id[n], (K+1):(K*2)])', 1) ;
  }

}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AMEN Models in Stan

This repository contains code to implement Minhas, Hoff, and Ward's AMEN model in R and stan, via rstan.

It contains the following files:

1. `00_fixed_effect.stan`: a basic fixed effects model in Stan.
2. `01_create_dyadic_data.R`: R script that converts the trade sociomatrix to a dyadic edgelist and fits the Stan models with `rstan`.
3. `01_srm_stan.stan`: A basic social relations model in Stan.
4. `02_srm_stan_dyad.stan`: A basic social relations model in Stan, with dyad effects.
5. `03_amen_stan.stan`: An implementation of the Additive and Multiplicative Effects Network model in Stan.
6. `amen_model_writeup.Rmd`: A `.Rmarkdown` writeup, explaining what's going on.
7. `amen_model_writeup.pdf`: A `.pdf` of said writeup.
14 | -------------------------------------------------------------------------------- /amen_model_writeup.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Implementing the Social Relations and AMEN model in Stan" 3 | author: "Adam M. Lauretig" 4 | date: "3/31/2020" 5 | output: pdf_document 6 | mainfont: Palatino 7 | header-includes: 8 | - \usepackage{palatino} 9 | --- 10 | 11 | # Introduction 12 | 13 | In a [recent review article](https://www.e-publications.org/ims/submission/STS/user/submissionFile/36407?confirm=150a239a), Peter Hoff discussed the Additive and Multiplicative Effects Network (AMEN) model, a Bayesian hierarchical model developed by Minhas, Hoff, and Ward for modeling network data. This model can be understood as a hierarchical model with a more structured covariance, to account for the network dependence between observations. While there is currently an [implementation in R](https://cran.r-project.org/web/packages/amen/index.html), I wanted to explore building this model up in [Stan](https://mc-stan.org/), a probabilistic programming language, and to take advantage of variational inference to fit models more quickly. 14 | 15 | These models were all fit in Stan using the `vb()` function, which meant they were fit quickly. Stan is fast, and provides a variety of diagnostic tools, though final work should probably use dynamic HMC/NUTS sampling. 16 | 17 | In the remaining sections, I gather some example data, and fit a fixed effects model, a social relations model, and an AMEN model to trade data from the 1990s. 18 | 19 | # Dataset 20 | 21 | I use a dataset included in the `amen` package to demonstrate model-building: exports from one country to another (in billions of dollars), averaged over the 1990s. I convert the sociomatrix to a real-valued edgelist, with indicators for sending and receiving countries, dyads, and the direction of the dyad. 
Plotting the data, we see the long-tailed distribution often characteristic of network data.

\clearpage

```{r, eda_dataviz, cache = TRUE, echo = TRUE, eval = TRUE, warning=FALSE, message=FALSE, size = "small"}
library(ggplot2)
library(amen)
library(rstan)


# A function to convert a sociomatrix (square matrix of directed pair values)
# to an edgelist with one row per ordered pair of distinct nodes.
# Assumes that rows == columns, but will check
#
# Arguments:
#   sociomatrix_to_convert: square matrix; entry [i, j] is the value sent from
#     node i to node j. The diagonal is ignored and NA entries become 0.
#
# Returns a list with:
#   edgelist:     data.frame sorted by sender then receiver, with columns
#                 sender, receiver (row/column positions in the matrix), y,
#                 dyad_id (shared by i -> j and j -> i) and sr_indicator
#                 (1 if sender < receiver, else 2).
#   lookup_table: data.frame mapping node_names to the idx used in edgelist.
#   n_nodes, n_dyads, N: number of nodes, unordered pairs and edgelist rows.
#
# Stops with an error if the matrix is not square.

matrix_to_edgelist <- function(sociomatrix_to_convert) {

  if (nrow(sociomatrix_to_convert) != ncol(sociomatrix_to_convert)) {
    stop("nrows != ncols")
  }

  n_nodes <- nrow(sociomatrix_to_convert)

  # Node ids are row/column positions, so the lookup table has to use
  # positions too (an alphabetical rank of the names only agrees with the
  # positions when the names happen to be sorted).
  node_names <- rownames(sociomatrix_to_convert)
  if (is.null(node_names)) {
    node_names <- as.character(seq_len(n_nodes))
  }
  lookup_table <- data.frame(node_names = node_names, idx = seq_len(n_nodes))

  # All ordered pairs of distinct nodes, sorted by sender then receiver
  # (expand.grid varies its first argument fastest).
  all_pairs <- expand.grid(
    receiver = seq_len(n_nodes), sender = seq_len(n_nodes))
  all_pairs <- all_pairs[all_pairs$sender != all_pairs$receiver, ]
  sender <- all_pairs$sender
  receiver <- all_pairs$receiver

  # Look every pair up directly in the matrix; missing ties are treated as 0.
  y <- sociomatrix_to_convert[cbind(sender, receiver)]
  y[is.na(y)] <- 0

  # Number the unordered pairs {lo, hi} in the order (1,2), (1,3), ..., so
  # that i -> j and j -> i always share a dyad_id. Recycling 1:n_dyads down
  # the sorted rows does not pair the two directions of a dyad.
  lo <- pmin(sender, receiver)
  hi <- pmax(sender, receiver)
  dyad_id <- (lo - 1L) * n_nodes - (lo * (lo - 1L)) %/% 2L + (hi - lo)

  edgelist <- data.frame(
    sender = sender,
    receiver = receiver,
    y = y,
    dyad_id = dyad_id,
    # for dyad list include a (1, 2) indicator for send/receive
    # if s < r -> 1, else 2
    sr_indicator = ifelse(sender < receiver, 1, 2))

  list(
    edgelist = edgelist, lookup_table = lookup_table,
    n_nodes = n_nodes, n_dyads = (n_nodes * (n_nodes - 1L)) %/% 2L,
    N = nrow(edgelist))
}

data_for_stan <- matrix_to_edgelist(IR90s$dyadvars[, , 2]) # trade data

y_df <- data.frame(y = data_for_stan$edgelist$y) 74 | ggplot(data = y_df, aes(x = y)) + geom_histogram() + theme_minimal() + labs(x = "Export Volume", y = "Frequency", title = "Distribution of Export Volume") 75 | 76 | ``` 77 | 78 | \clearpage 79 | 80 | # Fixed Effects Model 81 | 82 | I begin with a simple fixed-effects model, where the outcome, $y_{i, j}$ is the exports from $\text{sender}_{i}$ to $\text{receiver}_{j}$. 83 | 84 | $$y_{i,j} \sim \mathcal{N}(\alpha + \beta_{\text{sender}_{i}} + \beta_{\text{receiver}_{j}}, 1) $$ 85 | $$\alpha \sim \mathcal{N}(0, 5) $$ 86 | 87 | $$\beta_{\text{sender}_{i}} \sim \mathcal{N}(0, 5)$$ 88 | 89 | $$\beta_{\text{receiver}_{j}} \sim \mathcal{N}(0, 5)$$ 90 | \clearpage 91 | 92 | ```{stan, output.var = "m0_stan", echo = TRUE, eval = FALSE, size = "small"} 93 | 94 | data{ 95 | int n_nodes ; 96 | int n_dyads ; 97 | int N; //total obs. should be n_dyads * 2 98 | int sender_id[n_dyads * 2] ; //indexing for sender 99 | int receiver_id[n_dyads * 2] ; //indexing for receivers 100 | real Y[n_dyads * 2] ; // outcome variable 101 | 102 | } 103 | parameters{ 104 | real intercept ; 105 | vector[n_nodes] sender_beta; //fixed effects 106 | vector[n_nodes] receiver_beta; //fixed effects 107 | } 108 | 109 | model{ 110 | intercept ~ normal(0, 5) ; 111 | sender_beta ~ normal(0, 5) ; 112 | receiver_beta ~ normal(0, 5) ; 113 | 114 | for(n in 1:N){ 115 | Y[n] ~ normal(intercept + 116 | sender_beta[sender_id[n]] + receiver_beta[receiver_id[n]], 1 ); 117 | } 118 | } 119 | generated quantities{ 120 | real Y_sim[N] ; 121 | for(n in 1:N){ 122 | Y_sim[n] = normal_rng(intercept + 123 | sender_beta[sender_id[n]] + receiver_beta[receiver_id[n]], 1) ; 124 | } 125 | 126 | } 127 | 128 | ``` 129 | 130 | 131 | ```{r, load_saved_objects, echo = FALSE, message = FALSE, warning = FALSE, cache = TRUE, size = "small"} 132 | library(rstan) 133 | load("srm_amen_stan.rdata") 134 | 135 | ``` 136 | 137 | 138 | \clearpage 139 | 140 | Examining the posterior predictive 
check from the fixed effects model, we see a model with some clear fit problems, for example, in the scatterplot, the axes are different by an order of magnitude. In this model, the point predictions are the posterior means of outcomes simulated from the fitted model. 141 | 142 | ```{r, m0_ppc, echo = TRUE, cache = TRUE, message=FALSE, warning=FALSE, size = "small"} 143 | library(bayesplot) 144 | color_scheme_set("red") 145 | 146 | m0_params <- extract(m0) 147 | to_visualize <- data.frame(y = y_df$y, y_sim = colMeans(m0_params$Y_sim)) 148 | mse_0 <- mean((to_visualize$y - to_visualize$y_sim)^2) 149 | 150 | 151 | #ppc_dens_overlay(y = y_df$y, yrep = m0_params$Y_sim) 152 | ggplot(data = to_visualize, aes(x = y, y = y_sim)) + 153 | geom_point(alpha = .1) + 154 | theme_minimal() + 155 | labs(x = "Observed Y", y = "Predicted Fit", title = paste0("Mean Squared Error is ", round(mse_0, 4))) 156 | ``` 157 | 158 | # Social Relations Model 159 | 160 | To improve model fit, we can incorporate covariance between sender and receiver. In concrete terms, this means that we assume that US exports covary with US imports. To model this, we turn to the *social relations model*, where the outcome, $y_{i, j}$ is the exports from $\text{sender}_{i}$ to $\text{receiver}_{j}$, however, we add additional structure to the covariances. 
161 | 162 | $$y_{i,j} \sim \mathcal{N}(\alpha + \beta_{\text{sender}_{i}} + \beta_{\text{receiver}_{j}}, 1) $$ 163 | $$\alpha \sim \mathcal{N}(0, 5) $$ 164 | 165 | $$ \pmb{\beta}_{1\times2} \sim \mathcal{MVN}(\pmb{0}, \pmb{\Sigma})$$ 166 | $$ \pmb{\Sigma} = \pmb{\sigma} \pmb{\Omega}_{\text{cholesky}}$$ 167 | $$ \pmb{\sigma} \sim \text{Gam}(1, 1)$$ 168 | $$\pmb{\Omega}_{\text{cholesky}} \sim \mathcal{LKJ}(5)$$ 169 | 170 | \clearpage 171 | 172 | In Stan, this is written as follows, with the non-centered parameterization used for the Multivariate normal: 173 | 174 | ```{stan output.var="m1_srm", echo = TRUE, eval = FALSE, size = "small"} 175 | 176 | data{ 177 | int n_nodes ; 178 | int n_dyads ; 179 | int N; //total obs. should be n_dyads * 2 180 | int sender_id[n_dyads * 2] ; 181 | int receiver_id[n_dyads * 2] ; 182 | real Y[n_dyads * 2] ; 183 | 184 | } 185 | parameters{ 186 | real intercept ; 187 | cholesky_factor_corr[2] corr_nodes; // correlation matrix 188 | vector[2] sigma_nodes; // sd 189 | matrix[2, n_nodes] z_nodes ; // for node non-centered parameterization 190 | } 191 | transformed parameters{ 192 | matrix[n_nodes, 2] mean_nodes; 193 | mean_nodes = (diag_pre_multiply( 194 | sigma_nodes, corr_nodes) * z_nodes)'; // non-centered parameterization 195 | } 196 | model{ 197 | intercept ~ normal(0, 5) ; 198 | to_vector(z_nodes) ~ normal(0, 1) ; 199 | corr_nodes ~ lkj_corr_cholesky(5) ; 200 | sigma_nodes ~ gamma(1, 1) ; 201 | 202 | for(n in 1:N){ 203 | Y[n] ~ normal(intercept + 204 | mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2], 1 ); 205 | } 206 | } 207 | generated quantities{ 208 | real Y_sim[N] ; 209 | for(n in 1:N){ 210 | Y_sim[n] = normal_rng(intercept + 211 | mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2], 1) ; 212 | } 213 | 214 | } 215 | 216 | ``` 217 | 218 | \clearpage 219 | 220 | Examining fit from this model, we see that fit is largely the same, likely because despite allowing for covariance between sender and receiver effects, 
we are still imposing a strong additive relationship: 221 | 222 | 223 | ```{r, m1_ppc, echo = TRUE, cache = TRUE, message=FALSE, warning=FALSE, size = "small"} 224 | library(bayesplot) 225 | color_scheme_set("red") 226 | 227 | m1_params <- extract(m1) 228 | to_visualize1 <- data.frame(y = y_df$y, y_sim = colMeans(m1_params$Y_sim)) 229 | mse_1 <- mean((to_visualize1$y - to_visualize1$y_sim)^2) 230 | 231 | 232 | #ppc_dens_overlay(y = y_df$y, yrep = m1_params$Y_sim) 233 | ggplot(data = to_visualize1, aes(x = y, y = y_sim)) + 234 | geom_point(alpha = .1) + 235 | theme_minimal() + 236 | labs(x = "Observed Y", y = "Predicted Fit", title = paste0("Mean Squared Error is ", round(mse_1, 4))) 237 | ``` 238 | 239 | # Additive and Multiplicative Effects Network Model 240 | 241 | To relax the assumption of additivity in the parameters, [Peter Hoff](https://www.e-publications.org/ims/submission/STS/user/submissionFile/36407?confirm=150a239a) suggests the Additive and Multiplicative Effects Network (AMEN) Model. This adds a latent multiplicative effect to capture a higher-order network representation. This effect, for a given observation $y_{i,j}$ is modeled by the inner product $\pmb{u}_{\text{sender}_{i}} \cdot \pmb{v}_{\text{receiver}_{j}}^\top$, where $\pmb{u}$ and $\pmb{v}$ both have $1 \times K$ dimensions. These are concatenated together, and their covariance is modeled with a $2K \times 2K$ covariance matrix. 
242 | 243 | 244 | $$y_{i,j} \sim \mathcal{N}(\alpha + \beta_{\text{sender}_{i}} + \beta_{\text{receiver}_{j}} + (\pmb{u}_{\text{sender}_{i}} \cdot \pmb{v}_{\text{receiver}_{j}}^\top), 1) $$ 245 | $$\alpha \sim \mathcal{N}(0, 5) $$ 246 | 247 | $$ \pmb{\beta}_{1\times2} \sim \mathcal{MVN}(\pmb{0}, \pmb{\Sigma})$$ 248 | $$ \pmb{\Sigma} = \pmb{\sigma} \pmb{\Omega}_{\text{cholesky}}$$ 249 | $$ \pmb{\sigma} \sim \text{Gam}(1, 1)$$ 250 | $$\pmb{\Omega}_{\text{cholesky}} \sim \mathcal{LKJ}(5)$$ 251 | $$ c(\pmb{u}_{\text{sender}_{i}}, \pmb{v}_{\text{receiver}_{j}}^\top) \sim \mathcal{MVN}(\pmb{0}, \pmb{\Sigma}_{\pmb{u}, \pmb{v}})$$ 252 | $$ \pmb{\Sigma}_{\pmb{u}, \pmb{v}} = \pmb{\sigma}_{\pmb{u}, \pmb{v}} \pmb{\Omega}_{\pmb{u}, \pmb{v}_{\text{cholesky}}}$$ 253 | $$ \pmb{\sigma}_{\pmb{u}, \pmb{v}} \sim \text{Gam}(1, 1)$$ 254 | $$\pmb{\Omega}_{\pmb{u}, \pmb{v}_{\text{cholesky}}} \sim \mathcal{LKJ}(5)$$ 255 | 256 | \clearpage 257 | In stan, this is written as follows, with non-centered parameterizations for all multivariate normals: 258 | ```{stan output.var="m3_amen", echo = TRUE, eval = FALSE, size = "small"} 259 | 260 | data{ 261 | int n_nodes ; 262 | int n_dyads ; 263 | int N; //total obs. 
should be n_dyads * 2 264 | int sender_id[n_dyads * 2] ; 265 | int receiver_id[n_dyads * 2] ; 266 | int K ; // number of latent dimensions 267 | real Y[n_dyads * 2] ; 268 | 269 | } 270 | parameters{ 271 | real intercept ; 272 | cholesky_factor_corr[2] corr_nodes; // Cholesky factor of the correlation matrix for additive node effects 273 | vector<lower=0>[2] sigma_nodes; // sd of the sender/receiver node effects; positive to match the gamma prior 274 | matrix[2, n_nodes] z_nodes ; // for node non-centered parameterization 275 | cholesky_factor_corr[K * 2] corr_multi_effects ; // Cholesky factor of the correlation matrix for multiplicative effects 276 | vector<lower=0>[K * 2] sigma_multi_effects ; // sd of the 2K latent dimensions; positive to match the gamma prior 277 | matrix[K * 2, n_nodes] z_multi_effects; // Multi-effect non-centered term 278 | 279 | } 280 | transformed parameters{ 281 | matrix[n_nodes, 2] mean_nodes; // additive node effects: col 1 is used for senders, col 2 for receivers 282 | matrix[n_nodes, K * 2] mean_multi_effects ; // latent factors: cols 1:K are used for senders, (K+1):(K*2) for receivers 283 | 284 | mean_nodes = (diag_pre_multiply( 285 | sigma_nodes, corr_nodes) * z_nodes)'; // (diag(sd) * Cholesky factor * z), transposed to n_nodes x 2 286 | mean_multi_effects = (diag_pre_multiply( 287 | sigma_multi_effects, corr_multi_effects) * z_multi_effects)'; // (diag(sd) * Cholesky factor * z), transposed to n_nodes x 2K 288 | 289 | } 290 | model{ 291 | intercept ~ normal(0, 5) ; 292 | 293 | //node terms 294 | to_vector(z_nodes) ~ normal(0, 1) ; 295 | corr_nodes ~ lkj_corr_cholesky(5) ; 296 | sigma_nodes ~ gamma(1, 1) ; 297 | 298 | // multi-effect terms 299 | to_vector(z_multi_effects) ~ normal(0, 1) ; 300 | corr_multi_effects ~ lkj_corr_cholesky(5) ; 301 | sigma_multi_effects ~ gamma(1, 1) ; 302 | 303 | for(n in 1:N){ 304 | Y[n] ~ normal(intercept + 305 | mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] + 306 | mean_multi_effects[sender_id[n], 1:K] * 307 | (mean_multi_effects[receiver_id[n], (K+1):(K*2)])', 308 | 1 ); 309 | } 310 | } 311 | generated quantities{ 312 | real Y_sim[N] ; 313 | for(n in 1:N){ 314 | Y_sim[n] = normal_rng(intercept + 315 | mean_nodes[sender_id[n], 1] + mean_nodes[receiver_id[n], 2] + 316 | mean_multi_effects[sender_id[n], 1:K] * 317 | (mean_multi_effects[receiver_id[n], 
(K+1):(K*2)])', 1) ; 318 | } 319 | 320 | } 321 | 322 | ``` 323 | 324 | \clearpage 325 | 326 | ```{r, m3_ppc, echo = TRUE, cache = TRUE, message=FALSE, warning=FALSE, size = "small"} 327 | library(bayesplot) 328 | color_scheme_set("red") 329 | 330 | m3_params <- extract(m3) 331 | to_visualize3 <- data.frame(y = y_df$y, y_sim = colMeans(m3_params$Y_sim)) 332 | mse_3 <- mean((to_visualize3$y - to_visualize3$y_sim)^2) 333 | 334 | 335 | #ppc_dens_overlay(y = y_df$y, yrep = m3_params$Y_sim) 336 | ggplot(data = to_visualize3, aes(x = y, y = y_sim)) + 337 | geom_point(alpha = .1) + 338 | theme_minimal() + 339 | labs(x = "Observed Y", y = "Predicted Fit", title = paste0("Mean Squared Error is ", round(mse_3, 4))) 340 | ``` 341 | 342 | We see that the AMEN model can much more effectively model the outcome variable, with a much lower mean squared error, and that the scale of the predicted outcomes is much closer to the actual outcomes. 343 | 344 | # Discussion 345 | 346 | We have seen that the multiplicative effects model does a far better job of modeling the distribution of $\pmb{y}$ than either the fixed effects model or the social relations model. That the model fit changes so dramatically suggests the importance of taking network structure seriously in modeling. One further extension would involve incorporating covariates into the model, in order to take advantage of additional information when fitting the model. 
347 | 348 | -------------------------------------------------------------------------------- /amen_model_writeup.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup.pdf -------------------------------------------------------------------------------- /amen_model_writeup_cache/gfm/__packages: -------------------------------------------------------------------------------- 1 | base 2 | ggplot2 3 | amen 4 | -------------------------------------------------------------------------------- /amen_model_writeup_cache/gfm/eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_cache/gfm/eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.RData -------------------------------------------------------------------------------- /amen_model_writeup_cache/gfm/eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.rdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_cache/gfm/eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.rdb -------------------------------------------------------------------------------- /amen_model_writeup_cache/gfm/eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.rdx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_cache/gfm/eda_dataviz_f3893eee081c18c4e1ccdd1c4d582b70.rdx -------------------------------------------------------------------------------- /amen_model_writeup_cache/html/__packages: 
-------------------------------------------------------------------------------- 1 | base 2 | ggplot2 3 | amen 4 | StanHeaders 5 | rstan 6 | bayesplot 7 | -------------------------------------------------------------------------------- /amen_model_writeup_cache/latex/__packages: -------------------------------------------------------------------------------- 1 | base 2 | ggplot2 3 | amen 4 | StanHeaders 5 | rstan 6 | bayesplot 7 | -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-gfm/eda_dataviz-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-gfm/eda_dataviz-1.png -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-html/eda_dataviz-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-html/eda_dataviz-1.png -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-html/m0_ppc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-html/m0_ppc-1.png -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-html/m0_ppc-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-html/m0_ppc-2.png 
-------------------------------------------------------------------------------- /amen_model_writeup_files/figure-html/m1_ppc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-html/m1_ppc-1.png -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-html/m1_ppc-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-html/m1_ppc-2.png -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-html/m3_ppc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-html/m3_ppc-1.png -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-html/m3_ppc-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-html/m3_ppc-2.png -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-latex/eda_dataviz-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-latex/eda_dataviz-1.pdf -------------------------------------------------------------------------------- 
/amen_model_writeup_files/figure-latex/m0_ppc-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-latex/m0_ppc-1.pdf -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-latex/m1_ppc-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-latex/m1_ppc-1.pdf -------------------------------------------------------------------------------- /amen_model_writeup_files/figure-latex/m3_ppc-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamlauretig/AMEN_models_in_stan/1a7ce8adb9b05270fc2f9c0ed8933a4d57574e1f/amen_model_writeup_files/figure-latex/m3_ppc-1.pdf -------------------------------------------------------------------------------- /lower_tri_fixed_amen.R: -------------------------------------------------------------------------------- 1 | # fitting a stan model with an AMEN model with an identified lower triangle 2 | 3 | library(amen) 4 | library(rstan) 5 | library(data.table) 6 | data(IR90s) 7 | 8 | 9 | # Convert a square sociomatrix into a directed edgelist: one row per ordered (sender, receiver) pair, diagonal excluded 10 | # Stops with "nrows != ncols" if the matrix is not square; cells that are 0 or NA get y = 0 11 | # Returns list(edgelist, lookup_table, n_nodes, n_dyads, N) 12 | matrix_to_edgelist <- function(sociomatrix_to_convert){ 13 | 14 | if(nrow(sociomatrix_to_convert) != ncol(sociomatrix_to_convert)){ 15 | stop("nrows != ncols") 16 | } 17 | # all ordered pairs of node positions (as.numeric() of the expand.grid columns), self-pairs removed 18 | all_nodes <- expand.grid( 19 | list(rownames(sociomatrix_to_convert), colnames(sociomatrix_to_convert))) 20 | all_nodes[, 1] <- as.numeric(all_nodes[, 1]) 21 | all_nodes[, 2] <- as.numeric(all_nodes[, 2]) 22 | all_nodes <- all_nodes[all_nodes[, 1] != all_nodes[, 2], ] 23 | # lookup from node name to its row position, the same index used for sender/receiver 24 | lookup_table <- 
data.frame( 25 | node_names = rownames(sociomatrix_to_convert), 26 | idx = seq_len(nrow(sociomatrix_to_convert))) 27 | obs_values <- which(sociomatrix_to_convert != 0, arr.ind = TRUE) 28 | obs_values <- cbind(obs_values, sociomatrix_to_convert[obs_values]) 29 | # left-join the observed non-zero cells onto the full list of ordered pairs 30 | edgelist <- merge( 31 | all_nodes, obs_values, by.x = c("Var1", "Var2"), 32 | by.y = c("row", "col"), all.x = TRUE) 33 | # pairs with no non-zero observation get 0; then (i, j) and (j, i) are given the same dyad id 34 | edgelist$V3 <- ifelse(is.na(edgelist$V3), 0, edgelist$V3) 35 | edgelist <- cbind( 36 | edgelist, 37 | local({ key <- paste(pmin(edgelist[, 1], edgelist[, 2]), pmax(edgelist[, 1], edgelist[, 2])); match(key, unique(key)) })) 38 | edgelist <- cbind(edgelist, ifelse(edgelist[, 1] < edgelist[, 2], 1, 2)) 39 | colnames(edgelist) <- c("sender", "receiver", "y", "dyad_id", "sr_indicator") 40 | # for dyad list include a (1, 2) indicator for send/receive 41 | # if s < r -> 1, else 2 42 | 43 | return(list(edgelist = edgelist, lookup_table = lookup_table, 44 | n_nodes = max(edgelist[, 1]), n_dyads = max(edgelist[, 4]), 45 | N = nrow(edgelist))) 46 | 47 | } 48 | 49 | data_for_stan <- matrix_to_edgelist(IR90s$dyadvars[, , 2]) # trade data 50 | 51 | m4_code <- stan_model(file = "04_stan_fixed_lower_tri.stan" 52 | ) 53 | 54 | # NOTE(review): m4_params is not defined anywhere in this script; presumably the extract()ed draws of a fit of m4_code -- confirm 55 | latent_params <- apply(m4_params$mean_multi_effects, c(2:3), mean) 56 | latent_params_dt <- data.table(latent_params) 57 | setnames(latent_params_dt, c(paste0("sender_", seq_len(ncol(latent_params) / 2)), paste0("receiver_", seq_len(ncol(latent_params) / 2)))) 58 | latent_params_dt$country <- data_for_stan$lookup_table[, 1] 59 | latent_params_dt <- melt(latent_params_dt, id.vars = "country") 60 | latent_params_dt[, c("sr", "dim_num") := tstrsplit(variable, "_")] 61 | latent_params_dt <- dcast(latent_params_dt, country + dim_num ~ sr) 62 | ggplot(data = latent_params_dt, aes(x = sender, y = receiver)) + 63 | geom_text(aes(label = country)) + 64 | facet_wrap(~dim_num) 65 | ggplot(data = latent_params_dt, aes(x = sender, y = receiver)) + 66 | geom_text(aes(label = country)) + 67 | facet_wrap(~dim_num) + 68 | scale_x_continuous(limits = c(-4, 4)) + 
scale_y_continuous(limits = c(-4, 4)) 69 | ggplot(data = latent_params_dt, aes(x = sender, y = receiver)) + 70 | geom_text(aes(label = country)) + 71 | facet_wrap(~dim_num) + 72 | scale_x_continuous(limits = c(-1, 1)) + scale_y_continuous(limits = c(-1, 1)) 73 | --------------------------------------------------------------------------------