├── LICENSE ├── README.md ├── alternative_objectives └── README.md ├── human_experiments ├── README.md ├── expert_analysis.Rmd └── ratings_analysis.Rmd ├── requirements.txt ├── rl_experiments ├── README.md ├── data │ ├── cqa │ │ ├── dev.csv │ │ ├── test.csv │ │ └── train.csv │ ├── e-snli │ │ ├── dev.tsv │ │ ├── test.tsv │ │ └── train.tsv │ └── sm │ │ ├── dataset.jsonl │ │ ├── eval.jsonl │ │ └── train.jsonl ├── models │ └── T5ForMC.py ├── run_cqa_t5_rl_ra.sh ├── run_cqa_t5_rl_re.sh ├── run_nli_t5_rl_ra.sh ├── run_nli_t5_rl_re.sh ├── t5_rl.py ├── t5_utils.py └── utils.py └── sim_experiments ├── NLI_data_utils.py ├── QA_data_utils.py ├── README.md ├── T5-2-agent_main.py ├── causal_estimation.Rmd ├── classes.py ├── compute_sim.py ├── data ├── e-SNLI-data │ ├── dev.tsv │ ├── dev.txt │ ├── test.tsv │ ├── test.txt │ ├── train.tsv │ └── train.txt └── v1.0 │ ├── dev.csv │ ├── test.csv │ └── train.csv ├── main.py ├── models └── T5ForMC.py ├── run_tasks.py ├── training_reports └── README.md └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Peter Hase 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Leakage-Adjusted Simulatability 2 | This is the codebase for the paper: 3 | [Leakage-Adjusted Simulatability: Can Models Generate Non-Trivial Explanations of Their Behavior in Natural Language?](https://arxiv.org/abs/2010.04119) 4 | Peter Hase, Shiyue Zhang, Harry Xie, Mohit Bansal. Findings of EMNLP 2020 5 | 6 | ## Repository Structure 7 | 8 | ``` 9 | |__ sim_experiments/ --> Directory with models and experiment scripts 10 | |__ data/ --> includes eSNLI and CoS-E data used in experiments 11 | |__ models/ --> includes wrappers for T5 model for multiple choice tasks 12 | |__ training_reports/ --> directory where training reports are written during model training 13 | |__ main.py --> train task model or simulator model on SNLI or CQA 14 | |__ T5-2-agent_main.py --> train models for multi-agent experiments [SGD] in paper 15 | |__ compute_sim.py --> script for computing LAS score for model explanations 16 | |__ run_tasks.py --> script for running experimental conditions across seeds 17 | |__ *.py --> utilities for data loading, explanation sampling, etc.
18 | |__ causal_estimation.Rmd --> R markdown script used to calibrate simulator models and compute LAS with *k* bins 19 | |__ alternative_objectives/ --> Directory with additional experimental scripts 20 | |__ code coming soon 21 | |__ rl_experiments/ --> Directory with code for multi-agent [RL] experimental scripts 22 | |__ see internal README 23 | |__ human_experiments/ --> Directory with R code for analyzing human experiment data 24 | |__ ratings_analysis.Rmd --> R markdown script used to analyze human quality ratings 25 | |__ expert_analysis.Rmd --> R markdown script used to analyze expert simulation data 26 | |__ more human experiment code coming soon 27 | |__ requirements.txt 28 | 29 | ``` 30 | 31 | ## Requirements 32 | 33 | - Python 3.6 34 | - torch 1.4 35 | - transformers 2.5.1 36 | - sacrebleu 37 | - pandas 38 | - numpy 39 | 40 | ## Reproducing Experiments 41 | 42 | See READMEs in each directory for instructions on reproducing each set of experiments. 43 | -------------------------------------------------------------------------------- /alternative_objectives/README.md: -------------------------------------------------------------------------------- 1 | ## Reproducing Experiments 2 | 3 | Code and reproduction instructions coming soon for this directory. -------------------------------------------------------------------------------- /human_experiments/README.md: -------------------------------------------------------------------------------- 1 | ## Reproducing Experiments 2 | 3 | More code and reproduction instructions coming soon for this directory. -------------------------------------------------------------------------------- /human_experiments/expert_analysis.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "expert_analysis" 3 | author: "Peter Hase" 4 | output: pdf_document 5 | --- 6 | 7 | ```{r setup, include=FALSE} 8 | library(tidyverse) 9 | library(Cairo) 10 | library(readxl) 11 | ``` 12 | 13 | 14 | ```{r organize rater data for checking rater statistics} 15 | 16 | nli_human <- bind_rows(nli_human_peter, nli_human_shiyue, nli_human_harry) 17 | nli_STRa <- bind_rows(nli_STRa_peter, nli_STRa_shiyue, nli_STRa_harry) 18 | cqa_human <- bind_rows(cqa_human_peter, cqa_human_shiyue, cqa_human_harry) 19 | cqa_STRa <- bind_rows(cqa_STRa_peter, cqa_STRa_shiyue, cqa_STRa_harry) 20 | 21 | 22 | nli_human <- nli_human %>% 23 | mutate(role = ifelse(premise == 'NONE' & hypothesis == 'NONE' & human_exp != 'NONE', 'e', 24 | ifelse(premise != 'NONE' & hypothesis != 'NONE' & human_exp != 'NONE', 'xe', 25 | ifelse(premise != 'NONE' & hypothesis != 'NONE' & human_exp == 'NONE', 'x',NA))), 26 | rater = c(rep('peter',rep_num), rep('shiyue',rep_num), rep('harry',rep_num)) 27 | ) 28 | 29 | nli_human <- left_join(nli_human, nli_data %>% select(unique_key, label), by='unique_key') 30 | 31 | nli_STRa <- nli_STRa %>% 32 | mutate(role = ifelse(premise == 'NONE' & hypothesis == 'NONE' & model_exp != 'NONE', 'e', 33 | ifelse(premise != 'NONE' & hypothesis != 'NONE' & model_exp != 'NONE', 'xe', 34 | ifelse(premise != 'NONE' & hypothesis != 'NONE' & model_exp == 'NONE', 'x',NA))), 35 | rater = c(rep('peter',rep_num), rep('shiyue',rep_num), rep('harry',rep_num)) 36 | ) 37 | nli_data$model_correct <- 1*(nli_data$label==nli_data$cage_ra_label) 38 | nli_STRa <- left_join(nli_STRa, nli_data %>% select(unique_key, cage_ra_label, model_correct), by='unique_key') 39 | 40 | cqa_human <- cqa_human %>% 41 | mutate(role = ifelse(question == 'NONE' & human_exp != 
'NONE', 'e', 42 | ifelse(question != 'NONE' & human_exp != 'NONE', 'xe', 43 | ifelse(question != 'NONE' & human_exp == 'NONE', 'x',NA))), 44 | rater = c(rep('peter',rep_num), rep('shiyue',rep_num), rep('harry',rep_num)) 45 | ) 46 | 47 | cqa_human <- left_join(cqa_human, cqa_data %>% select(id, label), by='id') 48 | 49 | cqa_STRa <- cqa_STRa %>% 50 | mutate(role = ifelse(question == 'NONE' & model_exp != 'NONE', 'e', 51 | ifelse(question != 'NONE' & model_exp != 'NONE', 'xe', 52 | ifelse(question != 'NONE' & model_exp == 'NONE', 'x',NA))), 53 | rater = c(rep('peter',rep_num), rep('shiyue',rep_num), rep('harry',rep_num)) 54 | ) 55 | cqa_data$model_correct <- 1*(cqa_data$label==cqa_data$cage_ra_label) 56 | cqa_STRa <- left_join(cqa_STRa, cqa_data %>% select(id, cage_ra_label, model_correct), by='id') 57 | 58 | ``` 59 | 60 | 61 | ```{r check rater statistics} 62 | 63 | nli_human <- nli_human %>% 64 | mutate(correct = 1*(pred==label)) 65 | nli_STRa <- nli_STRa %>% 66 | mutate(correct = 1*(pred==cage_ra_label)) 67 | cqa_human <- cqa_human %>% 68 | mutate(correct = 1*(pred==label)) 69 | cqa_STRa <- cqa_STRa %>% 70 | mutate(correct = 1*(pred==cage_ra_label)) 71 | 72 | nli_human %>% 73 | group_by(rater) %>% 74 | summarize(n = n(), 75 | acc = mean(correct)) 76 | nli_STRa %>% 77 | group_by(rater) %>% 78 | summarize(n = n(), 79 | acc = mean(correct)) 80 | cqa_human %>% 81 | group_by(rater) %>% 82 | summarize(n = n(), 83 | acc = mean(correct)) 84 | cqa_STRa %>% 85 | group_by(rater) %>% 86 | summarize(n = n(), 87 | acc = mean(correct)) 88 | 89 | nli_human %>% 90 | group_by(role) %>% 91 | summarize(n = n(), 92 | acc = mean(correct)) 93 | nli_STRa %>% 94 | group_by(role) %>% 95 | summarize(n = n(), 96 | acc = mean(correct)) 97 | cqa_human %>% 98 | group_by(role) %>% 99 | summarize(n = n(), 100 | acc = mean(correct)) 101 | cqa_STRa %>% 102 | group_by(role) %>% 103 | summarize(n = n(), 104 | acc = mean(correct)) 105 | 106 | 107 | ``` 108 | 109 | 110 | ```{r create merged versions data and spread versions of the data} 111 | 112 | nli_human$model <- 'human' 113 | nli_STRa$model <- 'STRa' 114 | nli_STRa$label <- nli_STRa$cage_ra_label 115 | cqa_human$model <- 'human' 116 | cqa_STRa$model <- 'STRa' 117 | cqa_STRa$label <- cqa_STRa$cage_ra_label 118 | 119 | nli_both <- bind_rows(nli_human, nli_STRa) 120 | cqa_both <- bind_rows(cqa_human, cqa_STRa) 121 | 122 | nli_human <- left_join(nli_human, nli_data %>% 123 | select(unique_key, human_yxe, human_yx, human_ye)) %>% 124 | mutate(yxe = human_yxe, 125 | yx = human_yx, 126 | ye = human_ye, 127 | exp = human_exp) %>% 128 | select(unique_key, premise, hypothesis, exp, pred, role, rater, label, correct, model, yxe, yx, ye) 129 | cqa_human <- left_join(cqa_human, cqa_data %>% 130 | select(id, human_yxe, human_yx, human_ye, cage_ra_yxe, cage_ra_yx, cage_ra_ye)) %>% 131 | mutate(yxe = human_yxe, 132 | yx = human_yx, 133 | ye = human_ye, 134 | exp = human_exp) %>% 135 | select(id, question, choice_0, choice_1, choice_2, exp, pred, role, rater, label, correct, model, yxe, yx, ye) 136 | nli_STRa <- left_join(nli_STRa, nli_data %>% 137 | select(unique_key, cage_ra_yxe, cage_ra_yx, cage_ra_ye)) %>% 138 | mutate(yxe = cage_ra_yxe, 139 | yx = cage_ra_yx, 140 | ye = cage_ra_ye, 141 | exp = model_exp) %>% 142 | select(unique_key, premise, hypothesis, exp, pred, role, rater, label, correct, model, yxe, yx, ye) 143 | cqa_STRa <- left_join(cqa_STRa, cqa_data %>% 144 | select(id, cage_ra_yxe, cage_ra_yx, cage_ra_ye)) %>% 145 | mutate(yxe = cage_ra_yxe, 146 | yx = cage_ra_yx, 147 | ye 
= cage_ra_ye, 148 | exp = model_exp) %>% 149 | select(id, question, choice_0, choice_1, choice_2, exp, pred, role, rater, label, correct, model, yxe, yx, ye) 150 | 151 | 152 | nli_both <- left_join(nli_both, nli_data %>% 153 | select(unique_key, human_yxe, human_yx, human_ye, cage_ra_yxe, cage_ra_yx, cage_ra_ye)) %>% 154 | mutate(yxe = ifelse(model=='human',human_yxe,cage_ra_yxe), 155 | yx = ifelse(model=='human',human_yx,cage_ra_yx), 156 | ye = ifelse(model=='human',human_ye,cage_ra_ye), 157 | exp = ifelse(model=='human',human_exp,model_exp)) %>% 158 | select(unique_key, premise, hypothesis, exp, pred, role, rater, label, correct, model_correct, model, yxe, yx, ye) 159 | cqa_both <- left_join(cqa_both, cqa_data %>% 160 | select(id, human_yxe, human_yx, human_ye, cage_ra_yxe, cage_ra_yx, cage_ra_ye)) %>% 161 | mutate(yxe = ifelse(model=='human',human_yxe,cage_ra_yxe), 162 | yx = ifelse(model=='human',human_yx,cage_ra_yx), 163 | ye = ifelse(model=='human',human_ye,cage_ra_ye), 164 | exp = ifelse(model=='human',human_exp,model_exp)) %>% 165 | select(id, question, choice_0, choice_1, choice_2, exp, pred, role, rater, label, correct, model_correct, model, yxe, yx, ye) 166 | 167 | # create spread versions too. NOTE y* means simulator. xe, x, and e are human predictions 168 | nli_human_spread <- nli_human %>% 169 | select(unique_key, pred, role, yxe, yx, ye, label) %>% 170 | spread(role, pred) %>% 171 | mutate(xe_correct = (label==xe), 172 | e_correct = (label==e), 173 | x_correct = (label==x), 174 | yxe_correct = (label==yxe), 175 | ye_correct = (label==ye), 176 | yx_correct = (label==yx)) 177 | nli_stra_spread <- nli_STRa %>% 178 | select(unique_key, pred, role, yxe, yx, ye, label) %>% 179 | spread(role, pred) %>% 180 | mutate(xe_correct = (label==xe), 181 | e_correct = (label==e), 182 | x_correct = (label==x), 183 | yxe_correct = (label==yxe), 184 | ye_correct = (label==ye), 185 | yx_correct = (label==yx)) 186 | cqa_human_spread <- cqa_human %>% 187 | select(id, pred, role, yxe, yx, ye, label) %>% 188 | spread(role, pred) %>% 189 | mutate(xe_correct = (label==xe), 190 | e_correct = (label==e), 191 | x_correct = (label==x), 192 | yxe_correct = (label==yxe), 193 | ye_correct = (label==ye), 194 | yx_correct = (label==yx)) 195 | cqa_stra_spread <- cqa_STRa %>% 196 | select(id, pred, role, yxe, yx, ye, label) %>% 197 | spread(role, pred) %>% 198 | mutate(xe_correct = (label==xe), 199 | e_correct = (label==e), 200 | x_correct = (label==x), 201 | yxe_correct = (label==yxe), 202 | ye_correct = (label==ye), 203 | yx_correct = (label==yx)) 204 | 205 | nli_human_spread <- nli_human_spread %>% 206 | mutate(model_LAS = yxe_correct - yx_correct, 207 | human_LAS = xe_correct - x_correct) 208 | nli_stra_spread <- nli_stra_spread %>% 209 | mutate(model_LAS = yxe_correct - yx_correct, 210 | human_LAS = xe_correct - x_correct) 211 | cqa_human_spread <- cqa_human_spread %>% 212 | mutate(model_LAS = yxe_correct - yx_correct, 213 | human_LAS = xe_correct - x_correct) 214 | cqa_stra_spread <- cqa_stra_spread %>% 215 | mutate(model_LAS = yxe_correct - yx_correct, 216 | human_LAS = xe_correct - x_correct) 217 | 218 | 219 | nli_both_spread <- bind_rows(nli_human_spread, nli_stra_spread) 220 | cqa_both_spread <- bind_rows(cqa_human_spread, cqa_stra_spread) 221 | 222 | ``` 223 | 224 | 225 | 226 | ```{r check sampling constraints} 227 | 228 | nli_data %>% 229 | mutate(yxe_correct = (label==human_yxe), 230 | yx_correct = (label==human_yx), 231 | ye_correct = (label==human_ye)) %>% 232 | group_by(ye_correct, yxe_correct) 
%>% 233 | summarise(n=n()) 234 | cqa_data %>% 235 | mutate(yxe_correct = (label==human_yxe), 236 | yx_correct = (label==human_yx), 237 | ye_correct = (label==human_ye)) %>% 238 | group_by(ye_correct, yxe_correct) %>% 239 | summarise(n=n()) 240 | 241 | nli_human %>% 242 | mutate(yxe_correct = (label==yxe), 243 | yx_correct = (label==yx), 244 | ye_correct = (label==ye)) %>% 245 | group_by(ye_correct, yxe_correct) %>% 246 | summarise(n=n()) 247 | nli_STRa %>% 248 | mutate(yxe_correct = (label==yxe), 249 | yx_correct = (label==yx), 250 | ye_correct = (label==ye)) %>% 251 | group_by(ye_correct, yxe_correct) %>% 252 | summarise(n=n()) 253 | cqa_human %>% 254 | mutate(yxe_correct = (label==yxe), 255 | yx_correct = (label==yx), 256 | ye_correct = (label==ye)) %>% 257 | group_by(ye_correct, yxe_correct) %>% 258 | summarise(n=n()) 259 | cqa_STRa %>% 260 | mutate(yxe_correct = (label==yxe), 261 | yx_correct = (label==yx), 262 | ye_correct = (label==ye)) %>% 263 | group_by(ye_correct, yxe_correct) %>% 264 | summarise(n=n()) 265 | 266 | nli_human %>% 267 | mutate(yxe_correct = (label==yxe), 268 | yx_correct = (label==yx), 269 | ye_correct = (label==ye)) %>% 270 | group_by(ye_correct) %>% 271 | summarise(mean(yxe_correct), 272 | mean(yx_correct)) 273 | 274 | ``` 275 | 276 | ```{r add leaking variables to _both dfs} 277 | 278 | # NOTE HUMAN LEAKING ONLY VALID FOR role == 'e' in these dfs 279 | nli_both <- nli_both %>% 280 | mutate(model_leaking = 1*(ye==label), 281 | human_leaking = 1*(pred==label), 282 | model_LAS = ((yxe==label) - (yx==label))) 283 | 284 | cqa_both <- cqa_both %>% 285 | mutate(model_leaking = 1*(ye==label), 286 | human_leaking = 1*(pred==label), 287 | model_LAS = ((yxe==label) - (yx==label))) 288 | 289 | 290 | ``` 291 | 292 | 293 | ```{r check data balance} 294 | 295 | nli_both %>% 296 | group_by(model_correct) %>% 297 | summarise(n=n()) 298 | 299 | cqa_both %>% 300 | group_by(model_correct) %>% 301 | summarise(n=n()) 302 | 303 | nli_both %>% 304 | filter(role=='e') %>% 305 | group_by(human_leaking) %>% 306 | summarise(n=n()) 307 | 308 | cqa_both %>% 309 | filter(role=='e') %>% 310 | group_by(human_leaking) %>% 311 | summarise(n=n()) 312 | 313 | ``` 314 | 315 | ```{r check simulator accuracies} 316 | 317 | nli_human_spread %>% 318 | summarise(model=mean(yxe_correct), 319 | human=mean(xe_correct)) 320 | nli_stra_spread %>% 321 | summarise(model=mean(yxe_correct), 322 | human=mean(xe_correct)) 323 | cqa_human_spread %>% 324 | summarise(model=mean(yxe_correct), 325 | human=mean(xe_correct)) 326 | cqa_stra_spread %>% 327 | summarise(model=mean(yxe_correct), 328 | human=mean(xe_correct)) 329 | 330 | (78 - 61.33 + 94.66 - 76.66 + 90.66 - 76.66 + 68.66 - 66) / 4 331 | 332 | 333 | ``` 334 | 335 | 336 | 337 | ```{r check proxy variable quality} 338 | 339 | nli_both %>% 340 | filter(role=='e') %>% 341 | group_by(model, model_leaking, human_leaking) %>% 342 | summarise(n=n()) 343 | 344 | cqa_both %>% 345 | filter(role=='e') %>% 346 | group_by(model, model_leaking, human_leaking) %>% 347 | summarise(n=n()) 348 | 349 | nli_both %>% 350 | filter(role=='e', model == 'human') %>% 351 | group_by(model_leaking, human_leaking) %>% 352 | summarise(n=n()) 353 | 354 | # similar trends for human and stra, and for datasets, hence combine all of them for statistical testing 355 | 356 | (nli_table <- table(nli_both$model_leaking[nli_both$role=='e'], nli_both$human_leaking[nli_both$role=='e'])) 357 | (cqa_table <- table(cqa_both$model_leaking[cqa_both$role=='e'], cqa_both$human_leaking[cqa_both$role=='e'])) 358 
| (table <- (nli_table+cqa_table)) 359 | (leaking_var_transition_mat <- (table / c(sum(table[1,]), sum(table[2,])))) 360 | chisq.test(table) 361 | 362 | cor.test(1*nli_both_spread$ye_correct, 363 | 1*nli_both_spread$e_correct, 364 | method='spearman') # rank correlation 365 | 366 | cor.test(1*nli_both_spread$ye_correct, 367 | 1*nli_both_spread$e_correct, 368 | method='kendall') 369 | 370 | ``` 371 | 372 | 373 | ```{r LAS variable quality testing} 374 | 375 | # note we do not compare simulators overall, since sampled data is not necessarily the same for the simulators, because of the per-model balancing constraints. but for each explaining model, we always have both simulator responses, so thats where we compare 376 | 377 | nli_human_spread <- nli_human_spread %>% 378 | mutate(model_LAS = yxe_correct - yx_correct, 379 | human_LAS = xe_correct - x_correct) 380 | nli_stra_spread <- nli_stra_spread %>% 381 | mutate(model_LAS = yxe_correct - yx_correct, 382 | human_LAS = xe_correct - x_correct) 383 | cqa_human_spread <- cqa_human_spread %>% 384 | mutate(model_LAS = yxe_correct - yx_correct, 385 | human_LAS = xe_correct - x_correct) 386 | cqa_stra_spread <- cqa_stra_spread %>% 387 | mutate(model_LAS = yxe_correct - yx_correct, 388 | human_LAS = xe_correct - x_correct) 389 | 390 | # nli human 391 | leaking <- nli_human_spread$e_correct 392 | (full_table1 <- table(nli_human_spread$model_LAS, nli_human_spread$human_LAS)) 393 | cor.test(nli_human_spread$model_LAS, nli_human_spread$human_LAS, method='spearman') # rank correlation 394 | 395 | # nli stra 396 | leaking <- nli_stra_spread$e_correct 397 | (full_table2 <- table(nli_stra_spread$model_LAS, nli_stra_spread$human_LAS)) 398 | cor.test(nli_stra_spread$model_LAS, nli_stra_spread$human_LAS, method='spearman') # rank correlation 399 | 400 | # cqa human 401 | leaking <- cqa_human_spread$e_correct 402 | (full_table3 <- table(cqa_human_spread$model_LAS, cqa_human_spread$human_LAS)) 403 | cor.test(cqa_human_spread$model_LAS, cqa_human_spread$human_LAS, method='spearman') # rank correlation 404 | 405 | # cqa stra 406 | leaking <- cqa_stra_spread$e_correct 407 | (full_table4 <- table(cqa_stra_spread$model_LAS, cqa_stra_spread$human_LAS)) 408 | cor.test(cqa_stra_spread$model_LAS, cqa_stra_spread$human_LAS, method='spearman') # rank correlation 409 | 410 | # combine based on model 411 | (full_table2+full_table4) 412 | cor.test(c(cqa_stra_spread$model_LAS, nli_stra_spread$model_LAS), 413 | c(cqa_stra_spread$human_LAS, nli_stra_spread$human_LAS), 414 | method='spearman') # rank correlation 415 | 416 | (full_table1+full_table3) 417 | cor.test(c(cqa_human_spread$model_LAS, nli_human_spread$model_LAS), 418 | c(cqa_human_spread$human_LAS, nli_human_spread$human_LAS), 419 | method='spearman') # rank correlation 420 | cor.test(c(cqa_human_spread$model_LAS, nli_human_spread$model_LAS), 421 | c(cqa_human_spread$human_LAS, nli_human_spread$human_LAS), 422 | method='pearson') # rank correlation 423 | 424 | # combine all 425 | (full_table1+full_table2+full_table3+full_table4) 426 | cor.test(c(cqa_human_spread$model_LAS, nli_human_spread$model_LAS, c(cqa_stra_spread$model_LAS, nli_stra_spread$model_LAS)), 427 | c(cqa_human_spread$human_LAS, nli_human_spread$human_LAS, c(cqa_stra_spread$human_LAS, nli_stra_spread$human_LAS)), 428 | method='spearman') # rank correlation 429 | cor.test(c(cqa_human_spread$model_LAS, nli_human_spread$model_LAS, c(cqa_stra_spread$model_LAS, nli_stra_spread$model_LAS)), 430 | c(cqa_human_spread$human_LAS, nli_human_spread$human_LAS, 
c(cqa_stra_spread$human_LAS, nli_stra_spread$human_LAS)), 431 | method='pearson') # rank correlation 432 | 433 | all_table <- (full_table1+full_table2+full_table3+full_table4) 434 | round(all_table / c(sum(all_table[1,]), sum(all_table[2,]), sum(all_table[3,])),3) 435 | 436 | ``` 437 | 438 | 439 | -------------------------------------------------------------------------------- /human_experiments/ratings_analysis.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "ratings_analysis" 3 | author: "Peter Hase" 4 | output: pdf_document 5 | --- 6 | 7 | ```{r setup, include=FALSE} 8 | library(tidyverse) 9 | library(Cairo) 10 | ``` 11 | 12 | 13 | ```{r make full ratings dataframes} 14 | 15 | model_names <- c('human','re','ra','cage_re','cage_ra') 16 | 17 | qa_results <- qa_results %>% 18 | mutate(unique_id = paste(WorkerId, Input.qa_id, sep="_"), 19 | qa_id = Input.qa_id, 20 | rate1 = Answer.rate1, 21 | rate2 = Answer.rate2, 22 | rate3 = Answer.rate3, 23 | rate4 = Answer.rate4, 24 | rate5 = Answer.rate5) %>% 25 | dplyr::select(sample, WorkerId, unique_id, qa_id, rate1, rate2, rate3, rate4, rate5) 26 | 27 | qa_results <- qa_results %>% 28 | left_join(qa_key, by = 'qa_id') 29 | 30 | qa_results <- qa_results %>% 31 | mutate(human = NA, 32 | re = NA, 33 | ra = NA, 34 | cage_re = NA, 35 | cage_ra = NA) 36 | 37 | # deshuffle the explanations 38 | answer_cols <- c('rate1','rate2','rate3','rate4','rate5') 39 | for (i in 1:nrow(qa_results)){ 40 | order <- qa_results[i,c('model1','model2','model3','model4','model5')] 41 | for (j in 1:5){ 42 | qa_results[i,order[j][[1]]] <- qa_results[i,answer_cols[j][[1]]] 43 | } 44 | } 45 | 46 | qa_samples <- qa_samples %>% 47 | mutate(qa_id = id) 48 | 49 | qa_results <- qa_results %>% 50 | left_join(qa_samples, by = 'qa_id') 51 | 52 | # model level rating dist. 
53 | qa_results %>% 54 | gather('model', 'rating', model_names) %>% 55 | ggplot(aes(model, rating)) + 56 | geom_boxplot() 57 | 58 | qa_gather <- qa_results %>% 59 | gather('model', 'rating', model_names) %>% 60 | mutate(label = ifelse(model=='human',human_label, 61 | ifelse(model=='re',re_model_label, 62 | ifelse(model=='ra',ra_model_label, 63 | ifelse(model=='cage_re',cage_re_model_label, 64 | ifelse(model=='cage_ra',cage_ra_model_label, NA))))), 65 | yxe = ifelse(model=='human',human_yxe, 66 | ifelse(model=='re',re_model_yxe, 67 | ifelse(model=='ra',ra_model_yxe, 68 | ifelse(model=='cage_re',cage_re_model_yxe, 69 | ifelse(model=='cage_ra',cage_ra_model_yxe, NA))))), 70 | ye = ifelse(model=='human',human_ye, 71 | ifelse(model=='re',re_model_ye, 72 | ifelse(model=='ra',ra_model_ye, 73 | ifelse(model=='cage_re',cage_re_model_ye, 74 | ifelse(model=='cage_ra',cage_ra_model_ye, NA))))), 75 | yx = ifelse(model=='human',human_yx, 76 | ifelse(model=='re',re_model_yx, 77 | ifelse(model=='ra',ra_model_yx, 78 | ifelse(model=='cage_re',cage_re_model_yx, 79 | ifelse(model=='cage_ra',cage_ra_model_yx, NA))))), 80 | ye_prob = ifelse(model=='human',human_ye_label_prob, 81 | ifelse(model=='re',re_model_ye_label_prob, 82 | ifelse(model=='ra',ra_model_ye_label_prob, 83 | ifelse(model=='cage_re',cage_re_model_ye_label_prob, 84 | ifelse(model=='cage_ra',cage_ra_model_ye_label_prob, NA))))), 85 | LAS = yxe - yx, 86 | choice = ifelse(label==0, choice0, 87 | ifelse(label==1,choice1, 88 | ifelse(label==2,choice2,NA)))) 89 | 90 | nli_results <- nli_results %>% 91 | mutate(unique_id = paste(WorkerId, Input.qa_id, sep="_"), 92 | qa_id = Input.qa_id, 93 | rate1 = Answer.rate1, 94 | rate2 = Answer.rate2, 95 | rate3 = Answer.rate3, 96 | rate4 = Answer.rate4, 97 | rate5 = Answer.rate5) %>% 98 | dplyr::select(sample, WorkerId, unique_id, qa_id, rate1, rate2, rate3, rate4, rate5) %>% 99 | na.omit() 100 | 101 | nli_results <- nli_results %>% 102 | left_join(nli_key, by = 'qa_id') 103 | 104 | nli_results <- nli_results %>% 105 | mutate(human = NA, 106 | re = NA, 107 | ra = NA, 108 | cage_re = NA, 109 | cage_ra = NA) 110 | 111 | # deshuffle the explanations 112 | answer_cols <- c('rate1','rate2','rate3','rate4','rate5') 113 | for (i in 1:nrow(nli_results)){ 114 | order <- nli_results[i,c('model1','model2','model3','model4','model5')] 115 | for (j in 1:5){ 116 | nli_results[i,order[j][[1]]] <- nli_results[i,answer_cols[j][[1]]] 117 | } 118 | } 119 | 120 | nli_samples <- nli_samples %>% 121 | mutate(qa_id = id) 122 | 123 | nli_results <- nli_results %>% 124 | left_join(nli_samples, by = 'qa_id') 125 | 126 | # model level rating dist. 
127 | nli_results %>% 128 | gather('model', 'rating', model_names) %>% 129 | ggplot(aes(model, rating)) + 130 | geom_boxplot() 131 | 132 | nli_gather <- nli_results %>% 133 | gather('model', 'rating', model_names) %>% 134 | mutate(label = ifelse(model=='human',human_label, 135 | ifelse(model=='re',re_model_label, 136 | ifelse(model=='ra',ra_model_label, 137 | ifelse(model=='cage_re',cage_re_model_label, 138 | ifelse(model=='cage_ra',cage_ra_model_label, NA))))), 139 | yxe = ifelse(model=='human',human_yxe, 140 | ifelse(model=='re',re_model_yxe, 141 | ifelse(model=='ra',ra_model_yxe, 142 | ifelse(model=='cage_re',cage_re_model_yxe, 143 | ifelse(model=='cage_ra',cage_ra_model_yxe, NA))))), 144 | ye = ifelse(model=='human',human_ye, 145 | ifelse(model=='re',re_model_ye, 146 | ifelse(model=='ra',ra_model_ye, 147 | ifelse(model=='cage_re',cage_re_model_ye, 148 | ifelse(model=='cage_ra',cage_ra_model_ye, NA))))), 149 | yx = ifelse(model=='human',human_yx, 150 | ifelse(model=='re',re_model_yx, 151 | ifelse(model=='ra',ra_model_yx, 152 | ifelse(model=='cage_re',cage_re_model_yx, 153 | ifelse(model=='cage_ra',cage_ra_model_yx, NA))))), 154 | ye_prob = ifelse(model=='human',human_ye_label_prob, 155 | ifelse(model=='re',re_model_ye_label_prob, 156 | ifelse(model=='ra',ra_model_ye_label_prob, 157 | ifelse(model=='cage_re',cage_re_model_ye_label_prob, 158 | ifelse(model=='cage_ra',cage_ra_model_ye_label_prob, NA))))), 159 | LAS = yxe - yx, 160 | choice = ifelse(label==0, 'neutral', 161 | ifelse(label==1,'entailment', 162 | ifelse(label==2,'contradiction',NA)))) 163 | 164 | nli_gather %>% 165 | mutate(good = rating >= 4) %>% 166 | group_by(ye, LAS) %>% 167 | summarise(n = n(), 168 | prop = mean(good)) 169 | 170 | ``` 171 | 172 | 173 | ```{r calibrate probabilities} 174 | 175 | ### NOTE THIS REQUIRES qa_cal_model AND nli_cal_model FROM causal_estimation.Rmd in memory 176 | 177 | qa_probs = predict(qa_cal_model, qa_gather, type='response') 178 | binnedplot(qa_probs, qa_gather$ye) 179 | qa_gather <- qa_gather %>% 180 | mutate(ye_prob_cal = qa_probs) 181 | 182 | nli_probs = predict(nli_cal_model, nli_gather, type='response') 183 | binnedplot(nli_probs, nli_gather$ye) 184 | nli_gather <- nli_gather %>% 185 | mutate(ye_prob_cal = nli_probs) 186 | 187 | hist(qa_gather$ye_prob_cal) 188 | hist(nli_gather$ye_prob_cal) 189 | 190 | ``` 191 | 192 | 193 | 194 | ```{r QA trend analysis} 195 | 196 | qa_use = qa_gather 197 | 198 | qa_use %>% 199 | group_by(LAS) %>% 200 | summarise(n = n(), 201 | rating = mean(rating)) 202 | 203 | qa_use %>% 204 | group_by(yxe) %>% 205 | summarise(n = n(), 206 | rating = mean(rating)) 207 | 208 | qa_use %>% 209 | ggplot(aes(as.factor(LAS),rating)) + 210 | geom_boxplot() 211 | 212 | qa_use %>% 213 | group_by(ye, LAS) %>% 214 | summarise(n = n(), 215 | avg_rating = mean(rating), 216 | CI = 1.96*sd(avg_rating)/sqrt(n)) 217 | 218 | qa_use %>% 219 | group_by(LAS, model) %>% 220 | summarise(n = n(), 221 | rating = mean(rating)) %>% 222 | arrange(model) 223 | 224 | qa_use %>% 225 | group_by(LAS, yxe) %>% 226 | summarise(n = n(), 227 | rating = mean(rating)) 228 | 229 | qa_use %>% 230 | group_by(ye, LAS, yxe) %>% 231 | summarise(n = n(), 232 | rating = mean(rating)) 233 | 234 | qa_use %>% 235 | group_by(ye, yxe) %>% 236 | summarise(n = n(), 237 | rating = mean(rating)) %>% 238 | arrange(yxe) 239 | 240 | qa_use %>% 241 | group_by(ye, LAS) %>% 242 | summarise(n = n(), 243 | rating = mean(rating)) %>% 244 | arrange(LAS) 245 | 246 | qa_use %>% 247 | group_by(ye, yx) %>% 248 | summarise(n = n(), 249 
| rating = mean(rating)) %>% 250 | arrange(yx) 251 | 252 | qa_gather %>% 253 | mutate(good = rating >= 4) %>% 254 | group_by(yxe, model, qa_id) %>% 255 | summarise(n = n(), 256 | rating = mean(rating)) %>% 257 | mutate(good = 1*(rating >= 4)) %>% 258 | group_by(yxe, model) %>% 259 | summarise(n = n(), 260 | prop = mean(good)) %>% 261 | mutate(freq = n / sum(n)) %>% 262 | arrange(desc(yxe)) 263 | 264 | qa_gather %>% 265 | mutate(good = rating >= 4) %>% 266 | group_by(ye, LAS) %>% 267 | summarise(n = n(), 268 | prop = mean(good)) 269 | 270 | qa_use %>% 271 | group_by(model) %>% 272 | summarise(n = n(), 273 | rating = mean(rating)) 274 | 275 | qa_use %>% 276 | group_by(model, ye) %>% 277 | summarise( 278 | yxe = mean(yxe), 279 | yx = mean(yx), 280 | LAS = mean(yxe)-mean(yx), 281 | n=n()) %>% 282 | ungroup() %>% 283 | group_by(model) %>% 284 | summarise(LAS=mean(LAS), 285 | yxe_raw = mean(yxe), 286 | yxe_weight = weighted.mean(yxe, w=n), 287 | n= sum(n)) %>% 288 | left_join( 289 | qa_use %>% 290 | group_by(model) %>% 291 | summarise(n = n(), 292 | rating = mean(rating)) 293 | ) %>% 294 | arrange(rating) 295 | 296 | qa_use %>% 297 | mutate(LAS_exact = paste(yxe, yx)) %>% 298 | group_by(ye, LAS_exact) %>% 299 | summarise(n = n(), 300 | rating = mean(rating), 301 | LAS = mean(LAS)) 302 | 303 | ``` 304 | 305 | ```{r QA models} 306 | 307 | ratings_model <- lm(rating ~ ye + LAS + yxe, data = qa_use) 308 | summary(ratings_model) 309 | 310 | summary(lm(rating ~ ye * LAS, data = qa_use)) 311 | summary(lm(rating ~ ye * yxe, data = qa_use)) 312 | summary(lm(rating ~ ye + LAS, data = qa_use)) 313 | summary(lm(rating ~ ye + yxe, data = qa_use)) 314 | summary(lm(rating ~ ye + yx + yxe, data = qa_use)) 315 | summary(lm(rating ~ ye + yx + yxe, data = qa_use)) 316 | summary(lm(rating ~ ye_prob_cal + yx + yxe, data = qa_use)) 317 | summary(lm(rating ~ ye_prob_cal + yx + LAS, data = qa_use)) 318 | summary(lm(rating ~ ye_prob_cal + yxe + LAS, data = qa_use)) 319 | 320 | ratings_model <- lm(rating ~ ye + LAS, data = qa_use) 321 | summary(ratings_model) 322 | test_data <- qa_use %>% 323 | group_by(ye, LAS) %>% 324 | summarise(n = n(), 325 | rating = mean(rating)) %>% 326 | select(ye,LAS) %>% 327 | ungroup() 328 | 329 | (test_data <- test_data %>% 330 | mutate(preds = predict(ratings_model, test_data))) 331 | qa_use %>% 332 | group_by(ye, LAS) %>% 333 | summarise(n = n(), 334 | rating = mean(rating)) 335 | 336 | 337 | ratings_model <- lm(rating ~ ye + yxe, data = qa_use) 338 | summary(ratings_model) 339 | test_data <- qa_use %>% 340 | group_by(ye, yxe) %>% 341 | summarise(n = n(), 342 | rating = mean(rating)) %>% 343 | select(ye,yxe) %>% 344 | ungroup() 345 | 346 | (test_data <- test_data %>% 347 | mutate(preds = predict(ratings_model, test_data))) 348 | qa_use %>% 349 | group_by(ye, yxe) %>% 350 | summarise(n = n(), 351 | rating = mean(rating)) 352 | 353 | 354 | ratings_model <- lm(rating ~ ye + yxe + LAS, data = qa_use) 355 | summary(ratings_model) 356 | test_data <- qa_use %>% 357 | group_by(ye, yxe, LAS) %>% 358 | summarise(n = n(), 359 | rating = mean(rating)) %>% 360 | select(ye,yxe,LAS) %>% 361 | ungroup() 362 | 363 | (test_data <- test_data %>% 364 | mutate(preds = predict(ratings_model, test_data))) 365 | qa_use %>% 366 | group_by(ye, yxe, LAS) %>% 367 | summarise(n = n(), 368 | rating = mean(rating)) 369 | 370 | 371 | summary(lm(rating ~ ye * LAS, data = qa_use)) 372 | summary(lm(rating ~ ye * yxe, data = qa_use)) 373 | summary(lm(rating ~ ye * (LAS + yxe), data = qa_use)) 374 | summary(lm(rating ~ ye + 
(LAS + yxe), data = qa_use)) 375 | summary(lm(rating ~ ye_prob_cal + yx + yxe, data = qa_use)) 376 | summary(lm(rating ~ ye_prob_cal + yx + LAS, data = qa_use)) 377 | summary(lm(rating ~ ye_prob_cal + yxe + LAS, data = qa_use)) 378 | 379 | 380 | ``` 381 | 382 | 383 | ```{r nli trend analysis} 384 | 385 | nli_use = nli_gather 386 | 387 | nli_use %>% 388 | group_by(LAS) %>% 389 | summarise(n = n(), 390 | rating = mean(rating)) 391 | 392 | nli_use %>% 393 | group_by(yxe) %>% 394 | summarise(n = n(), 395 | rating = mean(rating)) 396 | 397 | nli_use %>% 398 | ggplot(aes(as.factor(LAS),rating)) + 399 | geom_boxplot() 400 | 401 | nli_use %>% 402 | group_by(sample, ye, LAS) %>% 403 | summarise(n = n(), 404 | rating = mean(rating)) 405 | 406 | nli_use %>% 407 | group_by(LAS, model) %>% 408 | summarise(n = n(), 409 | rating = mean(rating)) %>% 410 | arrange(model) 411 | 412 | nli_use %>% 413 | group_by(LAS, yxe) %>% 414 | summarise(n = n(), 415 | rating = mean(rating)) 416 | 417 | nli_use %>% 418 | group_by(ye, LAS, yxe) %>% 419 | summarise(n = n(), 420 | rating = mean(rating)) 421 | 422 | nli_use %>% 423 | group_by(ye, yxe) %>% 424 | summarise(n = n(), 425 | rating = mean(rating)) %>% 426 | arrange(yxe) 427 | 428 | nli_use %>% 429 | group_by(ye, LAS) %>% 430 | summarise(n = n(), 431 | rating = mean(rating)) %>% 432 | arrange(LAS) 433 | 434 | nli_use %>% 435 | group_by(ye, yx) %>% 436 | summarise(n = n(), 437 | rating = mean(rating)) %>% 438 | arrange(yx) 439 | 440 | nli_use %>% 441 | group_by(model, ye) %>% 442 | summarise( 443 | yxe = mean(yxe), 444 | yx = mean(yx), 445 | LAS = mean(yxe)-mean(yx), 446 | n=n()) %>% 447 | ungroup() %>% 448 | group_by(model) %>% 449 | summarise(LAS=mean(LAS), 450 | yxe_raw = mean(yxe), 451 | yxe_weight = weighted.mean(yxe, w=n), 452 | n= sum(n)) %>% 453 | left_join( 454 | nli_use %>% 455 | group_by(model) %>% 456 | summarise(n = n(), 457 | rating = mean(rating)) 458 | ) %>% 459 | arrange(rating) 460 | 461 | ``` 462 | 463 | ```{r nli models} 464 | 465 | ratings_model <- lm(rating ~ ye + LAS + yxe, data = nli_use) 466 | summary(ratings_model) 467 | 468 | summary(lm(rating ~ ye * LAS, data = nli_use)) 469 | summary(lm(rating ~ ye * yxe, data = nli_use)) 470 | summary(lm(rating ~ ye + LAS, data = nli_use)) 471 | summary(lm(rating ~ ye + yxe, data = nli_use)) 472 | summary(lm(rating ~ ye * LAS + yxe, data = nli_use)) 473 | summary(lm(rating ~ ye + yx + yxe, data = nli_use)) 474 | summary(lm(rating ~ ye_prob_cal + yx + yxe, data = nli_use)) 475 | summary(lm(rating ~ ye_prob_cal * LAS + yx + LAS, data = nli_use)) 476 | summary(lm(rating ~ ye_prob_cal * LAS + yxe, data = nli_use)) 477 | 478 | ratings_model <- lm(rating ~ ye + LAS, data = nli_use) 479 | summary(ratings_model) 480 | test_data <- nli_use %>% 481 | group_by(ye, LAS) %>% 482 | summarise(n = n(), 483 | rating = mean(rating)) %>% 484 | select(ye,LAS) %>% 485 | ungroup() 486 | # mutate(LAS = as.factor(LAS)) 487 | (test_data <- test_data %>% 488 | mutate(preds = predict(ratings_model, test_data))) 489 | nli_use %>% 490 | group_by(ye, LAS) %>% 491 | summarise(n = n(), 492 | rating = mean(rating)) 493 | 494 | 495 | ratings_model <- lm(rating ~ ye + yxe, data = nli_use) 496 | summary(ratings_model) 497 | test_data <- nli_use %>% 498 | group_by(ye, yxe) %>% 499 | summarise(n = n(), 500 | rating = mean(rating)) %>% 501 | select(ye,yxe) %>% 502 | ungroup() 503 | # mutate(LAS = as.factor(LAS)) 504 | (test_data <- test_data %>% 505 | mutate(preds = predict(ratings_model, test_data))) 506 | nli_use %>% 507 | group_by(ye, yxe) 
%>% 508 | summarise(n = n(), 509 | rating = mean(rating)) 510 | 511 | 512 | ratings_model <- lm(rating ~ ye + yxe + LAS, data = nli_use) 513 | summary(ratings_model) 514 | test_data <- nli_use %>% 515 | group_by(ye, yxe, LAS) %>% 516 | summarise(n = n(), 517 | rating = mean(rating)) %>% 518 | select(ye,yxe,LAS) %>% 519 | ungroup() 520 | # mutate(LAS = as.factor(LAS)) 521 | (test_data <- test_data %>% 522 | mutate(preds = predict(ratings_model, test_data))) 523 | nli_use %>% 524 | group_by(ye, yxe, LAS) %>% 525 | summarise(n = n(), 526 | rating = mean(rating)) 527 | 528 | ratings_model <- lm(rating ~ ye + yx, data = nli_use) 529 | summary(ratings_model) 530 | ratings_model <- lm(rating ~ ye + yx + yxe, data = nli_use) 531 | summary(ratings_model) 532 | ratings_model <- lm(rating ~ ye + as.factor(LAS), data = nli_use) 533 | summary(ratings_model) 534 | 535 | ratings_model <- lm(rating ~ as.factor(LAS), data = nli_use) 536 | summary(ratings_model) 537 | 538 | 539 | ``` 540 | 541 | 542 | 543 | ```{r IAA} 544 | 545 | qa_use %>% 546 | group_by(unique_id) %>% 547 | summarise(avg_rating = mean(rating), 548 | sd = sd(rating), 549 | n = n()) 550 | 551 | nli_use %>% 552 | group_by(unique_id) %>% 553 | summarise(avg_rating = mean(rating), 554 | sd = sd(rating), 555 | n = n()) 556 | 557 | ``` 558 | 559 | 560 | 561 | ```{r cqa paper tables} 562 | qa_use = qa_gather 563 | nli_use = nli_gather 564 | 565 | qa_use %>% 566 | group_by(ye, LAS) %>% 567 | summarise(n = n(), 568 | avg_rating = mean(rating), 569 | CI = 1.96*sd(rating)/sqrt(n)) 570 | 571 | qa_use %>% 572 | group_by(yx) %>% 573 | summarise(n = n(), 574 | avg_rating = round(mean(rating),2), 575 | CI = round(1.96*sd(rating)/sqrt(n),2)) 576 | 577 | 578 | (qa_bar_data <- qa_use %>% 579 | group_by(ye) %>% 580 | summarise(n = n(), 581 | avg_rating = mean(rating), 582 | CI = 1.96*sd(rating)/sqrt(n)) %>% 583 | mutate(group = 'ye', 584 | cor = ye) %>% 585 | bind_rows( 586 | qa_use %>% 587 | group_by(yxe) %>% 588 | summarise(n = n(), 589 | avg_rating = mean(rating), 590 | CI = 1.96*sd(rating)/sqrt(n)) %>% 591 | mutate(group='yxe', 592 | cor=yxe)) %>% 593 | dplyr::select(-c(ye,yxe))) %>% 594 | mutate(avg_rating = round(avg_rating,2), 595 | CI=round(CI,2)) 596 | 597 | nli_use %>% 598 | group_by(yx) %>% 599 | summarise(n = n(), 600 | avg_rating = mean(rating), 601 | CI = 1.96*sd(rating)/sqrt(n)) 602 | ``` 603 | 604 | 605 | ```{r nli paper tables} 606 | nli_use %>% 607 | group_by(ye, LAS) %>% 608 | summarise(n = n(), 609 | avg_rating = mean(rating), 610 | CI = 1.96*sd(rating)/sqrt(n)) %>% 611 | mutate(avg_rating = round(avg_rating,2), 612 | CI=round(CI,2)) 613 | 614 | nli_use %>% 615 | group_by(yx) %>% 616 | summarise(n = n(), 617 | avg_rating = mean(rating), 618 | CI = 1.96*sd(rating)/sqrt(n)) %>% 619 | mutate(avg_rating = round(avg_rating,2), 620 | CI=round(CI,2)) 621 | 622 | (nli_bar_data <- nli_use %>% 623 | group_by(ye) %>% 624 | summarise(n = n(), 625 | avg_rating = mean(rating), 626 | CI = 1.96*sd(rating)/sqrt(n)) %>% 627 | mutate(group = 'ye', 628 | cor = ye) %>% 629 | bind_rows( 630 | nli_use %>% 631 | group_by(yxe) %>% 632 | summarise(n = n(), 633 | avg_rating = mean(rating), 634 | CI = 1.96*sd(rating)/sqrt(n)) %>% 635 | mutate(group='yxe', 636 | cor=yxe)) %>% 637 | dplyr::select(-c(ye,yxe))) %>% 638 | mutate(avg_rating = round(avg_rating,2), 639 | CI=round(CI,2)) 640 | 641 | summary(lm(rating ~ ye + yx + yxe, data = qa_use)) 642 | summary(lm(rating ~ ye + yx + yxe, data = nli_use)) 643 | 644 | 645 | ``` 646 | 
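The analyses above compute per-example simulatability effects (yxe_correct - yx_correct) and group them by the explanation-only leakage indicator ye. For reference, here is a minimal Python sketch of the corresponding two-bin LAS computation; the function name `las_score` and the 0/1 numpy-array inputs are illustrative assumptions, and `compute_sim.py` in `sim_experiments/` is the repository's actual script for computing LAS.

```python
import numpy as np

def las_score(yxe_correct, yx_correct, ye_correct):
    """Two-bin leakage-adjusted simulatability (LAS) sketch.

    yxe_correct: 1 if the simulator is correct given input x and explanation e
    yx_correct:  1 if the simulator is correct given input x alone
    ye_correct:  1 if the simulator is correct given explanation e alone (leakage indicator)
    """
    yxe_correct, yx_correct, ye_correct = (np.asarray(a) for a in (yxe_correct, yx_correct, ye_correct))
    bin_scores = []
    for leaked in (0, 1):  # bin examples by whether the explanation leaks the label
        mask = ye_correct == leaked
        if mask.any():
            # per-bin effect: simulator accuracy with explanations minus accuracy without them
            bin_scores.append(yxe_correct[mask].mean() - yx_correct[mask].mean())
    return float(np.mean(bin_scores))  # unweighted average over the leakage bins
```

For example, `las_score([1, 1, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1])` returns 0.25 (a 0.5 gain in the non-leaked bin, no gain in the leaked bin).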
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4 2 | transformers==2.5.1 3 | numpy 4 | pandas 5 | sacrebleu -------------------------------------------------------------------------------- /rl_experiments/README.md: -------------------------------------------------------------------------------- 1 | ## CQA Code 2 | 3 | ------ 4 | 5 | ### Requirements 6 | 7 | torch 1.4 8 | transformers 2.5.1 9 | 10 | ### models 11 | 12 | **models/T5ForMC.py** 13 | - defines *T5ModelForMC* wrapper for *T5PreTrainedModel* 14 | - .forward computes the loss for an output sequence given an input sequence or encoder_hidden_states 15 | - .QA_forward returns a loss of shape (batch_size x num_choices) given output sequence answers of shape (batch_size x num_choices x max_seq_len). predictions are the highest-likelihood answer, and can be obtained by computing np.argmin(output_loss, axis=-1) 16 | 17 | ### experiment scripts 18 | 19 | Below are the training scripts and experiment shell scripts. 20 | 21 | **t5_rl.py** - for experiments with multi-agent reinforcement learning using the simulation metric as reward. 22 | 23 | T5-RL-reason: run_cqa_t5_rl_re.sh, run_nli_t5_rl_re.sh 24 | 25 | T5-RL-rationalize: run_cqa_t5_rl_ra.sh, run_nli_t5_rl_ra.sh -------------------------------------------------------------------------------- /rl_experiments/models/T5ForMC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import T5PreTrainedModel 4 | from transformers.modeling_t5 import T5Stack 5 | from torch.nn import CrossEntropyLoss 6 | import copy 7 | 8 | class T5ModelForMC(T5PreTrainedModel): 9 | """ 10 | Wrapper for T5PreTrainedModel to use T5 for multiple choice under a closed choice set. 11 | - adds .QA_forward method 12 | 13 | (decoder) QA_forward 14 | Input: 15 | input_ids of shape: batch_size x num_choices x max_seq_len 16 | Output: 17 | outputs[0] is loss of shape batch_size x num_choices. preds should be torch.argmin(loss, dim = -1), i.e. the choice with the lowest loss 18 | 19 | """ 20 | 21 | def __init__(self, config): 22 | super().__init__(config) 23 | self.model_dim = config.d_model 24 | 25 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 26 | 27 | encoder_config = copy.deepcopy(config) 28 | self.encoder = T5Stack(encoder_config) 29 | 30 | decoder_config = copy.deepcopy(config) 31 | decoder_config.is_decoder = True 32 | self.decoder = T5Stack(decoder_config) 33 | 34 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 35 | 36 | self.init_weights() 37 | 38 | def get_input_embeddings(self): 39 | return self.shared 40 | 41 | def set_input_embeddings(self, new_embeddings): 42 | self.shared = new_embeddings 43 | 44 | def get_output_embeddings(self): 45 | return self.lm_head 46 | 47 | 48 | def forward(self, **kwargs): 49 | # keyword arguments come in 3 flavors: encoder-specific (prefixed by 50 | # `encoder_`), decoder-specific (prefixed by `decoder_`) and those 51 | # that apply to the model as a whole. 52 | # We let the specific kwargs override the common ones in case of conflict.
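        # For example (hypothetical argument names), a call such as
        #   model(encoder_input_ids=..., encoder_attention_mask=...,
        #         decoder_input_ids=..., decoder_attention_mask=..., decoder_lm_labels=...)
        # routes each prefixed argument to the matching stack, while any un-prefixed keyword
        # argument is first copied into both the encoder and decoder kwargs below.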
53 | 54 | lm_labels = kwargs.pop("decoder_lm_labels", None) 55 | batch_loss = kwargs.pop("batch_loss", None) 56 | 57 | kwargs_common = dict( 58 | (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") 59 | ) 60 | kwargs_encoder = kwargs_common.copy() 61 | kwargs_decoder = kwargs_common.copy() 62 | kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) 63 | kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) 64 | 65 | # Encode if needed (training, first prediction pass) 66 | encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) 67 | if encoder_hidden_states is None: 68 | # Convert encoder inputs in embeddings if needed 69 | hidden_states = kwargs_encoder.pop("inputs_embeds", None) 70 | if hidden_states is None: 71 | encoder_inputs_ids = kwargs_encoder.pop("input_ids") 72 | hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings 73 | 74 | encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) 75 | encoder_hidden_states = encoder_outputs[0] 76 | else: 77 | encoder_outputs = () 78 | 79 | # Decode 80 | # Convert decoder inputs in embeddings if needed 81 | hidden_states = kwargs_decoder.pop("inputs_embeds", None) 82 | if hidden_states is None: 83 | decoder_inputs_ids = kwargs_decoder.pop("input_ids") 84 | hidden_states = self.shared(decoder_inputs_ids) 85 | 86 | kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states 87 | kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) 88 | decoder_outputs = self.decoder(hidden_states, **kwargs_decoder) 89 | 90 | sequence_output = decoder_outputs[0] 91 | # Rescale output before projecting on vocab 92 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 93 | sequence_output = sequence_output * (self.model_dim ** -0.5) 94 | lm_logits = self.lm_head(sequence_output) 95 | 96 | decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here 97 | if lm_labels is not None: 98 | shift_logits = lm_logits[..., :-1, :].contiguous() 99 | shift_labels = lm_labels[..., 1:].contiguous() 100 | if batch_loss: 101 | loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none') 102 | real_label_lengths = (shift_labels != -100).sum(dim=-1, keepdim=True) 103 | loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels) 104 | loss = loss.sum(dim=-1, keepdim=True) / real_label_lengths 105 | else: 106 | loss_fct = CrossEntropyLoss(ignore_index=-100) 107 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 108 | 109 | # if loss != loss: 110 | # print("got nan loss!") 111 | # 112 | # loss_fct_unreduce = CrossEntropyLoss(ignore_index=-100, reduction = 'none') 113 | # def nanmean(v, *args, inplace=False, **kwargs): 114 | # if not inplace: 115 | # v = v.clone() 116 | # is_nan = torch.isnan(v) 117 | # v[is_nan] = 0 118 | # return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs) 119 | # losses = loss_fct_unreduce(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 120 | # loss = nanmean(loss) 121 | 122 | decoder_outputs = ( 123 | loss, 124 | ) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 125 | 126 | return decoder_outputs + encoder_outputs 127 | 128 | 
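    # Usage sketch for .forward (hypothetical tensor names, not part of the original file):
    #   outputs = model(encoder_input_ids=src_ids,            # (batch, src_len)
    #                   encoder_attention_mask=src_mask,
    #                   decoder_input_ids=tgt_ids,             # (batch, tgt_len)
    #                   decoder_lm_labels=tgt_labels,          # (batch, tgt_len), -100 on padding
    #                   batch_loss=True)
    #   per_example_loss = outputs[0]                          # (batch, 1): mean token NLL per sequence
    # Without batch_loss, outputs[0] is instead a single scalar cross-entropy averaged over all
    # non-ignored tokens in the batch.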
129 | def QA_forward(self, **kwargs): 130 | ''' 131 | so this is basically just .forward that maintains the num_choices dimension, plus doesn't reduce the token loss into a scalar 132 | 133 | # keyword arguments come in 3 flavors: encoder-specific (prefixed by 134 | # `encoder_`), decoder-specific (prefixed by `decoder_`) and those 135 | # that apply to the model as whole. 136 | # We let the specific kwargs override the common ones in case of conflict. 137 | ''' 138 | 139 | batch_size = kwargs['decoder_input_ids'].size(0) 140 | num_choices = kwargs['decoder_input_ids'].size(1) 141 | seq_len = kwargs['decoder_input_ids'].size(2) 142 | 143 | lm_labels = kwargs.pop("decoder_lm_labels", None) 144 | 145 | # kwargs_encoder/decoder are initialized from kwargs_common, and then overwritten by any encoder_/decoder_ prefixed arguments 146 | # arguments inside of kwargs_encoder/decoder are NOT prefixed 147 | kwargs_common = dict( 148 | (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") 149 | ) 150 | kwargs_encoder = kwargs_common.copy() 151 | kwargs_decoder = kwargs_common.copy() 152 | kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) 153 | kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) 154 | 155 | # Encode if needed (training, first prediction pass) 156 | encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) 157 | if encoder_hidden_states is None: 158 | # Convert encoder inputs in embeddings if needed 159 | hidden_states = kwargs_encoder.pop("inputs_embeds", None) 160 | if hidden_states is None: 161 | encoder_inputs_ids = kwargs_encoder.pop("input_ids") 162 | hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings 163 | 164 | encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) 165 | encoder_hidden_states = encoder_outputs[0] 166 | else: 167 | encoder_outputs = () 168 | 169 | # Decode 170 | # Convert decoder inputs in embeddings if needed 171 | hidden_states = kwargs_decoder.pop("inputs_embeds", None) 172 | if hidden_states is None: 173 | decoder_inputs_ids = kwargs_decoder.pop("input_ids") 174 | hidden_states = self.shared(decoder_inputs_ids) 175 | 176 | kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states 177 | kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) 178 | 179 | # have to combine batch_size and num_choices dimensions while preserving other dimensions for call to self.decoder 180 | hidden_states = hidden_states.view(-1, hidden_states.size(-2), hidden_states.size(-1)) if hidden_states is not None else None 181 | for k, v in kwargs_decoder.items(): 182 | if v.dim() == 3: 183 | kwargs_decoder[k] = v.reshape(-1, v.size(-1)) 184 | elif v.dim() == 4: 185 | kwargs_decoder[k] = v.reshape(-1, v.size(-2), v.size(-1)) 186 | 187 | decoder_outputs = self.decoder(hidden_states, **kwargs_decoder) 188 | 189 | sequence_output = decoder_outputs[0] 190 | sequence_output = sequence_output.reshape(batch_size, num_choices, seq_len, -1) 191 | # Rescale output before projecting on vocab 192 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 193 | sequence_output = sequence_output * (self.model_dim ** -0.5) 194 | lm_logits = self.lm_head(sequence_output) 195 | 196 | decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here 
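        # Each entry of the per-choice loss computed below is the cross-entropy of one candidate
        # answer averaged over the sequence dimension, so the predicted answer for an example is
        # the choice with the lowest loss (torch.argmin over the choice dimension).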
197 | if lm_labels is not None: 198 | 199 | shift_logits = lm_logits[..., :-1, :].contiguous() 200 | shift_labels = lm_labels[..., 1:].contiguous() 201 | loss_fct = CrossEntropyLoss(ignore_index=-100, reduction = 'none') 202 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 203 | 204 | # reshape to batch_size x num choices x -1 then sum out the token dim (alternatively, could take the mean and not penalize longer answers) 205 | loss = loss.reshape(batch_size, num_choices, -1) 206 | loss = torch.mean(loss, dim=-1) 207 | 208 | decoder_outputs = ( 209 | loss, 210 | ) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 211 | 212 | return decoder_outputs + encoder_outputs 213 | 214 | 215 | -------------------------------------------------------------------------------- /rl_experiments/run_cqa_t5_rl_ra.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | for seed in 11 3 | do 4 | # generated rl explanations and predictions 5 | python t5_rl.py \ 6 | --gpu=0 \ 7 | --seed=${seed} \ 8 | --model_name="cqa_t5_rl_ra_seed${seed}" \ 9 | --train_task=True --explain_task=True --sample_task=True --train_sim=True \ 10 | --do_rl=True --rl_sampling_strategy="multinomial" --ce_loss=True \ 11 | --rationalize=True \ 12 | --select_for="sim_acc" \ 13 | --task_lr=5e-5 --sim_lr=1e-4 --alpha=0.95 --temperature=0.1 \ 14 | --train_batch_size=2 --eval_batch_size=2 --grad_accumulation_factor=6 --num_train_epoch=20 \ 15 | --dataset="cqa" --do_test=False \ 16 | --max_seq_len=110 \ 17 | --write_prediction \ 18 | --train_output_file="cqa_ra_train_rl_seed${seed}.csv" \ 19 | --eval_output_file="cqa_ra_eval_rl_seed${seed}.csv" \ 20 | --test_output_file="cqa_ra_test_rl_seed${seed}.csv" 21 | 22 | # train task model using generated explanations and predictions 23 | python t5_rl.py \ 24 | --gpu=0 \ 25 | --seed=${seed} \ 26 | --model_name="cqa_t5_ra_rl_sim_seed${seed}" \ 27 | --train_data_file="cqa_ra_train_rl_seed${seed}.csv" \ 28 | --eval_data_file="cqa_ra_eval_rl_seed${seed}.csv" \ 29 | --test_data_file="cqa_ra_test_rl_seed${seed}.csv" \ 30 | --task_base_model="t5-small" \ 31 | --train_task=True --explain_task=False --sample_task=False \ 32 | --train_sim=False --explain_sim=False \ 33 | --do_rl=False \ 34 | --select_for="task_acc" \ 35 | --task_lr=1e-4 \ 36 | --train_batch_size=12 --eval_batch_size=12 --grad_accumulation_factor=1 --num_train_epoch=10 \ 37 | --condition_on_explanation=True --explanation_to_use="t5" --label_to_use="t5" \ 38 | --dataset="cqa" --do_test=False \ 39 | --max_seq_len=110 40 | done 41 | -------------------------------------------------------------------------------- /rl_experiments/run_cqa_t5_rl_re.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | for seed in 11 3 | do 4 | # generated rl explanations and predictions 5 | python t5_rl.py \ 6 | --gpu=0 \ 7 | --seed=${seed} \ 8 | --model_name="cqa_t5_rl_re_seed${seed}" \ 9 | --train_task=True --explain_task=True --sample_task=True --train_sim=True \ 10 | --do_rl=True --rl_sampling_strategy="multinomial" --ce_loss=True \ 11 | --rationalize=False \ 12 | --select_for="sim_acc" \ 13 | --task_lr=5e-5 --sim_lr=1e-4 --alpha=0.95 --temperature=0.1 \ 14 | --train_batch_size=2 --eval_batch_size=2 --grad_accumulation_factor=6 --num_train_epoch=20 \ 15 | --dataset="cqa" --do_test=False \ 16 | --max_seq_len=110 \ 17 | --write_prediction \ 18 | --train_output_file="cqa_re_train_rl_seed${seed}.csv" \ 19 | --eval_output_file="cqa_re_eval_rl_seed${seed}.csv" \ 20 | --test_output_file="cqa_re_test_rl_seed${seed}.csv" 21 | 22 | # train task model using generated explanations and predictions 23 | python t5_rl.py \ 24 | --gpu=0 \ 25 | --seed=${seed} \ 26 | --model_name="cqa_t5_re_rl_sim_seed${seed}" \ 27 | --train_data_file="cqa_re_train_rl_seed${seed}.csv" \ 28 | --eval_data_file="cqa_re_eval_rl_seed${seed}.csv" \ 29 | --test_data_file="cqa_re_test_rl_seed${seed}.csv" \ 30 | --task_base_model="t5-small" \ 31 | --train_task=True --explain_task=False --sample_task=False \ 32 | --train_sim=False --explain_sim=False \ 33 | --do_rl=False \ 34 | --select_for="task_acc" \ 35 | --task_lr=1e-4 \ 36 | --train_batch_size=12 --eval_batch_size=12 --grad_accumulation_factor=1 --num_train_epoch=10 \ 37 | --condition_on_explanation=True --explanation_to_use="t5" --label_to_use="t5" \ 38 | --dataset="cqa" --do_test=False \ 39 | --max_seq_len=110 40 | done 41 | -------------------------------------------------------------------------------- /rl_experiments/run_nli_t5_rl_ra.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | for seed in 11 3 | do 4 | # generated rl explanations and predictions 5 | python t5_rl.py \ 6 | --seed=${seed} \ 7 | --model_name="nli_t5_rl_ra_seed${seed}" \ 8 | --train_task=True --explain_task=True --sample_task=True \ 9 | --train_sim=True --explain_sim=False \ 10 | --do_rl=True --rl_sampling_strategy="multinomial" --ce_loss=True \ 11 | --select_for="sim_acc" \ 12 | --task_lr=1e-5 --sim_lr=1e-4 --alpha=0.9 --temperature=0.1 \ 13 | --train_batch_size=4 --eval_batch_size=4 --grad_accumulation_factor=3 --num_train_epoch=10 \ 14 | --dataset="nli" --do_test=True \ 15 | --train_data_file="train.tsv" \ 16 | --eval_data_file="dev.tsv" \ 17 | --test_data_file="test.tsv" \ 18 | --max_seq_len=110 \ 19 | --write_prediction \ 20 | --train_output_file="train_rl_ra_seed${seed}.tsv" \ 21 | --eval_output_file="eval_rl_ra_seed${seed}.tsv" \ 22 | --test_output_file="test_rl_ra_seed${seed}.tsv" 23 | 24 | # train task model using generated explanations and predictions 25 | python t5_rl.py \ 26 | --seed=${seed} \ 27 | --model_name="nli_t5_rl_ra_sim_seed${seed}" \ 28 | --task_base_model="t5-small" \ 29 | --train_data_file="train_rl_ra_seed${seed}.tsv" \ 30 | --eval_data_file="eval_rl_ra_seed${seed}.tsv" \ 31 | --test_data_file="test_rl_ra_seed${seed}.tsv" \ 32 | --train_task=True --explain_task=False --sample_task=False \ 33 | --train_sim=False --explain_sim=False \ 34 | --do_rl=False \ 35 | --select_for="task_acc" \ 36 | --task_lr=1e-4 \ 37 | --train_batch_size=12 --eval_batch_size=12 --grad_accumulation_factor=1 --num_train_epoch=10 \ 38 | --condition_on_explanation=True --explanation_to_use="t5" --label_to_use="t5" \ 39 | --dataset="nli" --do_test=True \ 40 | --max_seq_len=110 41 | done 42 | -------------------------------------------------------------------------------- /rl_experiments/run_nli_t5_rl_re.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | for seed in 11 3 | do 4 | # generated rl explanations and predictions 5 | python t5_rl.py \ 6 | --seed=${seed} \ 7 | --model_name="nli_t5_rl_re_seed${seed}" \ 8 | --train_task=True --explain_task=True --sample_task=True \ 9 | --train_sim=True --explain_sim=False \ 10 | --do_rl=True --rl_sampling_strategy="multinomial" --ce_loss=True \ 11 | --select_for="sim_acc" \ 12 | --task_lr=1e-5 --sim_lr=1e-4 --alpha=0.9 --temperature=0.1 \ 13 | --train_batch_size=4 --eval_batch_size=4 --grad_accumulation_factor=3 --num_train_epoch=10 \ 14 | --dataset="nli" --do_test=True \ 15 | --train_data_file="train.tsv" \ 16 | --eval_data_file="dev.tsv" \ 17 | --test_data_file="test.tsv" \ 18 | --max_seq_len=110 \ 19 | --write_prediction \ 20 | --train_output_file="train_rl_re_seed${seed}.tsv" \ 21 | --eval_output_file="eval_rl_re_seed${seed}.tsv" \ 22 | --test_output_file="test_rl_re_seed${seed}.tsv" 23 | 24 | # train task model using generated explanations and predictions 25 | python t5_rl.py \ 26 | --seed=${seed} \ 27 | --model_name="nli_t5_rl_re_sim_seed${seed}" \ 28 | --task_base_model="t5-small" \ 29 | --train_data_file="train_rl_re_seed${seed}.tsv" \ 30 | --eval_data_file="eval_rl_re_seed${seed}.tsv" \ 31 | --test_data_file="test_rl_re_seed${seed}.tsv" \ 32 | --train_task=True --explain_task=False --sample_task=False \ 33 | --train_sim=False --explain_sim=False \ 34 | --do_rl=False \ 35 | --select_for="task_acc" \ 36 | --task_lr=1e-4 \ 37 | --train_batch_size=12 --eval_batch_size=12 --grad_accumulation_factor=1 --num_train_epoch=10 \ 38 | --condition_on_explanation=True --explanation_to_use="t5" --label_to_use="t5" \ 39 | --dataset="nli" --do_test=True \ 40 | --max_seq_len=110 41 | done -------------------------------------------------------------------------------- /rl_experiments/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import jsonlines 5 | import csv 6 | from sacrebleu import corpus_bleu 7 | 8 | 9 | class SMExample(object): 10 | 11 | def __init__(self, input_dict): 12 | self.sm_id = input_dict['id'] 13 | self.statements = [input_dict['sentence0'], input_dict['sentence1']] 14 | self.statement_label = 1 - int(input_dict['false']) 15 | self.explanations = [input_dict['A'], input_dict['B'], input_dict['C']] 16 | self.explanation_label = ['A', 'B', 'C'].index(input_dict['reason']) 17 | self.human_explanation = self.explanations[self.explanation_label] 18 | self.input_dict = input_dict 19 | 20 | 21 | class CQAExample: 22 | 23 | def __init__(self, input_dict): 24 | self.cqa_id = input_dict['id'] 25 | self.question = input_dict['question'] 26 | self.choices = [input_dict['choice_0'], input_dict['choice_1'], input_dict['choice_2']] 27 | self.label = int(input_dict.get('label', -1)) 28 | self.human_explanation = input_dict.get('human_expl_open-ended', '') 29 | self.input_dict = input_dict 30 | 31 | 32 | class NLIExample(object): 33 | 34 | def __init__(self, input_dict): 35 | self.nli_id = input_dict['unique_key'] 36 | self.premise = input_dict['premise'] 37 | self.hypothesis = input_dict['hypothesis'] 38 | self.human_explanation = input_dict['explanation1'] if 'explanation1' in input_dict else input_dict['explanation'] 39 | self.choices = ['neutral', 'entailment', 'contradiction'] 40 | self.label = int(input_dict['label']) 41 | self.input_dict = input_dict 42 | 43 | 44 | class Report(): 45 | """Report stores evaluation results during the training process as text files.""" 46 | 47 | def __init__(self, args, 
file_path, score_names): 48 | self.fn = file_path 49 | self.args = args 50 | self.header = score_names 51 | self.text = '' 52 | 53 | # write input arguments at the top 54 | self.text += 'Input: python %s %s \n\n' % \ 55 | (sys.argv[0], 56 | ' '.join([arg for arg in sys.argv[1:]])) 57 | 58 | # make header 59 | header_str = '%5s |' % 'epoch' 60 | for n, score_name in enumerate(self.header): 61 | header_str += ' %20s ' % score_name 62 | if n < len(score_names) - 1: header_str += '|' 63 | 64 | # write header 65 | self.blank_line = '-' * len(header_str) 66 | self.text += self.blank_line + \ 67 | f"\nTraining report for model: {args.model_name}" + \ 68 | '\n' + self.blank_line + \ 69 | '\n' 70 | self.text += header_str 71 | 72 | def write_epoch_scores(self, epoch, scores): 73 | # write scores 74 | self.text += '\n%5s |' % ('%d' % epoch) 75 | for idx, column_name in enumerate(self.header): 76 | self.text += ' %20s ' % ('%1.5f' % scores[column_name]) 77 | if idx < len(scores) - 1: self.text += '|' 78 | self.__save() 79 | 80 | def write_final_score(self, args, final_score_str): 81 | self.text += '\n' + self.blank_line 82 | self.text += '\n%s' % final_score_str 83 | self.text += '\n' + self.blank_line + '\n' 84 | 85 | self.text += '\n' 86 | self.write_all_arguments(args) 87 | 88 | self.__save() 89 | 90 | def write_msg(self, msg): 91 | self.text += self.blank_line 92 | self.text += msg 93 | self.__save() 94 | 95 | def write_all_arguments(self, args): 96 | self.text += "\nAll arguments:\n" 97 | self.text += '\n'.join(['\t' + hp for hp in str(args).replace('Namespace(', '').replace(')', '').split(', ')]) 98 | self.__save() 99 | 100 | def full_print(self): 101 | print('\n' + self.text + '\n') 102 | 103 | def __save(self): 104 | if self.fn is not None: 105 | with open(self.fn, mode='w') as text_file: 106 | text_file.write(self.text) 107 | 108 | 109 | def print_epoch_scores(epoch, scores): 110 | epoch_text = ' %5s |' % 'epoch' 111 | for n, score_name in enumerate(scores.keys()): 112 | epoch_text += ' %20s ' % score_name 113 | if n < len(scores) - 1: epoch_text += '|' 114 | epoch_text += '\n %5s |' % ('%d' % epoch) 115 | for n, score in enumerate(scores.values()): 116 | epoch_text += ' %20s ' % ('%1.5f' % score) 117 | if n < len(scores) - 1: epoch_text += '|' 118 | print(epoch_text + '\n') 119 | 120 | 121 | def read_sm_examples(input_filepath): 122 | examples = [] 123 | with jsonlines.open(input_filepath, 'r') as reader: 124 | for line in reader: 125 | examples.append(SMExample(line)) 126 | return examples 127 | 128 | 129 | def read_cqa_examples(input_filepath): 130 | examples = [] 131 | with open(input_filepath, newline='') as csv_file: 132 | csv_reader = csv.DictReader(csv_file) 133 | for row in csv_reader: 134 | examples.append(CQAExample(row)) 135 | return examples 136 | 137 | 138 | def read_nli_examples(input_filepath): 139 | examples = [] 140 | with open(input_filepath, newline='') as csv_file: 141 | csv_reader = csv.DictReader(csv_file, delimiter='\t') 142 | for row in csv_reader: 143 | examples.append(NLIExample(row)) 144 | return examples 145 | 146 | 147 | def detok_batch(tokenizer, x, ignore_tokens=None, eos_token=None): 148 | ''' 149 | - convert x into strings using tokenizer 150 | - x is either tensor of dim 2 or dim 3 or a .tolist() of such a tensor 151 | - stop decoding if eos_token hit, if eos_token provided 152 | - skip over tokens in ignore_tokens 153 | ''' 154 | if ignore_tokens is not None: 155 | ignore_tokens_idx = tokenizer.convert_tokens_to_ids(ignore_tokens) 156 | ignore_tokens_idx 
+= [-100, -1] 157 | else: 158 | ignore_tokens = [] 159 | ignore_tokens_idx = [-100, -1] 160 | 161 | # if tokenizer.pad_token_id is None: 162 | ignore_tokens_idx += [0] 163 | if not isinstance(x, list): 164 | x = x.tolist() 165 | dim = 3 if isinstance(x[0][0], list) else 2 166 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) 167 | texts = [] 168 | for i in range(len(x)): 169 | if dim == 2: 170 | current_idx = [] 171 | for j in range(len(x[i])): 172 | current_id = x[i][j] 173 | if current_id == eos_token_id: 174 | break 175 | elif current_id not in ignore_tokens_idx: 176 | current_idx.append(current_id) 177 | decoded_sequence = tokenizer.decode(current_idx) 178 | 179 | # check if any ignore_tokens are in decoded_sequence. 180 | # this is happening for some reason. many token_ids lead to [UNK], but [UNK] maps to id=100 181 | if any([ignore_token in decoded_sequence for ignore_token in ignore_tokens]): 182 | decoded_sequence = ' '.join([token for token in decoded_sequence.split() if token not in ignore_tokens]) 183 | 184 | # APPEND 185 | texts.append(decoded_sequence) 186 | 187 | elif dim == 3: 188 | decoded_sequences = [] 189 | for j in range(len(x[i])): 190 | current_idx = [] 191 | for k in range(len(x[i][j])): 192 | current_id = x[i][j][k] 193 | if current_id == eos_token_id: 194 | break 195 | elif current_id not in ignore_tokens_idx: 196 | current_idx.append(current_id) 197 | decoded_sequence = tokenizer.decode(current_idx) 198 | 199 | # check if any ignore_tokens are in decoded_sequence. 200 | # this is happening for some reason. many token_ids lead to [UNK], but [UNK] maps to id=100 201 | if any([ignore_token in decoded_sequence for ignore_token in ignore_tokens]): 202 | decoded_sequence = ' '.join( 203 | [token for token in decoded_sequence.split() if token not in ignore_tokens]) 204 | 205 | # APPEND single decoding 206 | decoded_sequences.append(decoded_sequence) 207 | 208 | # APPEND list of n decodings 209 | texts.append(decoded_sequences) 210 | 211 | return texts 212 | 213 | 214 | def str2bool(v): 215 | # used for boolean argparse values 216 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 217 | return True 218 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 219 | return False 220 | else: 221 | raise argparse.ArgumentTypeError('Boolean value expected.') 222 | 223 | 224 | def truncate_seq_pair(tokens_a, tokens_b, max_length): 225 | """Truncates a sequence pair in place to the maximum length.""" 226 | # This is a simple heuristic which will always truncate the longer sequence 227 | # one token at a time. This makes more sense than truncating an equal percent 228 | # of tokens from each, since if one sequence is very short then each token 229 | # that's truncated likely contains more information than a longer sequence. 
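# For example, with max_length=5, tokens_a of length 4 and tokens_b of length 3 are trimmed
# to lengths 3 and 2 (when the lengths are equal, the strict > check fails, so tokens_b is popped).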
230 | while True: 231 | total_length = len(tokens_a) + len(tokens_b) 232 | if total_length <= max_length: 233 | break 234 | if len(tokens_a) > len(tokens_b): 235 | tokens_a.pop() 236 | else: 237 | tokens_b.pop() 238 | 239 | 240 | def write_prediction_to_sm_file(pred_dict, data_filepath, output_filepath): 241 | print(f'\nUpdating {data_filepath} with model predictions...') 242 | # read dataset 243 | examples = [] 244 | with jsonlines.open(data_filepath, 'r') as reader: 245 | for line in reader: 246 | examples.append(line) 247 | 248 | # add or replace generated explanation to dictionary 249 | for column_name, predictions in pred_dict.items(): 250 | if len(examples) != len(predictions): 251 | print('Warning: number of predictions not equal to number of input examples.') 252 | min_len = min(len(examples), len(predictions)) 253 | examples = examples[:min_len] 254 | predictions = predictions[:min_len] 255 | for i, example in enumerate(examples): 256 | example[column_name] = predictions[i] 257 | 258 | with jsonlines.open(output_filepath, 'w') as writer: 259 | for line in examples: 260 | writer.write(line) 261 | print(f'Predictions written to {output_filepath} under columns {pred_dict.keys()}.') 262 | 263 | 264 | def write_prediction_to_cqa_file(pred_dict, data_filepath, output_filepath): 265 | print(f'\nUpdating {output_filepath} with model predictions...') 266 | # read dataset 267 | examples = [] 268 | with open(data_filepath, newline='') as csv_file: 269 | csv_reader = csv.DictReader(csv_file) 270 | for row in csv_reader: 271 | examples.append(row) 272 | 273 | for column_name, predictions in pred_dict.items(): 274 | if len(examples) != len(predictions): 275 | print('Warning: number of predictions not equal to number of input examples.') 276 | min_len = min(len(examples), len(predictions)) 277 | examples = examples[:min_len] 278 | predictions = predictions[:min_len] 279 | for i, example in enumerate(examples): 280 | example[column_name] = predictions[i] 281 | 282 | # write to csv file 283 | with open(output_filepath, 'w', newline='') as csvfile: 284 | fieldnames = examples[0].keys() 285 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 286 | writer.writeheader() 287 | for example in examples: 288 | writer.writerow(example) 289 | print(f'Predictions written to {output_filepath} under columns {pred_dict.keys()}.') 290 | 291 | 292 | def write_prediction_to_nli_file(pred_dict, data_filepath, output_filepath): 293 | print(f'\nUpdating {output_filepath} with model predictions...') 294 | # read dataset 295 | examples = [] 296 | with open(data_filepath, newline='') as csv_file: 297 | csv_reader = csv.DictReader(csv_file, delimiter='\t') 298 | for row in csv_reader: 299 | examples.append(row) 300 | 301 | for column_name, predictions in pred_dict.items(): 302 | if len(examples) != len(predictions): 303 | print('Warning: number of predictions not equal to number of input examples.') 304 | min_len = min(len(examples), len(predictions)) 305 | examples = examples[:min_len] 306 | predictions = predictions[:min_len] 307 | for i, example in enumerate(examples): 308 | example[column_name] = predictions[i] 309 | 310 | # write to csv file 311 | with open(output_filepath, 'w', newline='') as csvfile: 312 | fieldnames = examples[0].keys() 313 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter='\t') 314 | writer.writeheader() 315 | for example in examples: 316 | writer.writerow(example) 317 | print(f'Predictions written to {output_filepath} under columns {pred_dict.keys()}.') 318 | 319 | 320 | 
def compute_bleu(outputs, targets): 321 | # see https://github.com/mjpost/sacreBLEU 322 | targets = [[t[i] for t in targets] for i in range(len(targets[0]))] 323 | return corpus_bleu(outputs, targets, lowercase=True).score 324 | -------------------------------------------------------------------------------- /sim_experiments/NLI_data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import argparse 4 | import logging 5 | import json 6 | import time 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import pandas as pd 11 | import utils 12 | from utils import isNaN 13 | 14 | class NLIExample(object): 15 | '''used for training models with CQA data''' 16 | def __init__(self, 17 | idx, 18 | premise, 19 | hypothesis, 20 | label, 21 | choices, 22 | explanation): 23 | self.idx = idx 24 | self.premise = premise 25 | self.hypothesis = hypothesis 26 | self.explanation = explanation 27 | self.choices = choices 28 | self.label = label 29 | self.explanation_list = [explanation] \ 30 | if not isinstance(explanation, list) \ 31 | else \ 32 | explanation 33 | 34 | def read_NLI(args, input_file, explanations_to_use, version, 35 | labels_to_use = 'label', filter_explanations = None): 36 | 37 | label_map = {0: "neutral", 1: "entailment", 2: "contradiction"} 38 | is_train = 'train' in input_file 39 | exp_cols = ['explanation%d' % d for d in range(1,4)] if not is_train else ['explanation'] 40 | df = pd.read_csv(input_file, delimiter = '\t') 41 | n = len(df) if not args.small_data else args.small_size 42 | num_choices = 3 43 | multi_exp = (args.condition_on_explanations and 'multi' in explanations_to_use and args.multi_explanation) # ST-Ra 44 | # simulate_rationalized is used to pull out the predicted explanation when simulating a ST-Ra model 45 | simulate_rationalized = (args.condition_on_explanations and not args.multi_explanation and 'st.ra' in (labels_to_use.lower() if isinstance(labels_to_use, str) else '' )) 46 | ids = df['unique_key'] 47 | premises = df['premise'] 48 | hypotheses = df['hypothesis'] 49 | print("using labels: %s" % labels_to_use) 50 | labels = df[labels_to_use] 51 | 52 | if explanations_to_use == 'None': 53 | explanations = [''] * n 54 | else: 55 | exp_cols = explanations_to_use 56 | try: 57 | explanations = df[exp_cols] 58 | print(f"getting explanations from {explanations_to_use}") 59 | except: 60 | if explanations_to_use == 'ground_truth': 61 | exp_cols = 'explanation' if is_train else 'explanation1' 62 | elif explanations_to_use == 'oracle': 63 | exp_cols = 'explanation' if is_train else 'explanation1' 64 | elif explanations_to_use == 'gpt2': 65 | exp_cols = 'gpt2-single-exp' 66 | elif explanations_to_use == 'multi_gpt2': 67 | exp_cols = [f'gpt2-multi-exp-{i}' for i in range(num_choices)] 68 | elif explanations_to_use == 't5': 69 | exp_cols = 't5-single-exp' 70 | elif explanations_to_use == 'multi_t5': 71 | exp_cols = [f't5-multi-exp-{i}' for i in range(num_choices)] 72 | elif explanations_to_use == 'MT_t5': 73 | exp_cols = 't5-MT-single-exp' 74 | elif explanations_to_use == 'MT_multi_t5': 75 | exp_cols = [f't5-MT-multi-exp-{i}' for i in range(num_choices)] 76 | elif explanations_to_use == 'MT_multi_t5_pred': 77 | exp_cols = 't5-MT-multi-exp-pred' 78 | elif explanations_to_use == 'bert_cage': 79 | exp_cols = 'bert-cage-single-exp' 80 | elif explanations_to_use == 'bert_multi_cage': 81 | exp_cols = [f'bert-cage-multi-exp-{i}' for i in range(num_choices)] 82 | elif explanations_to_use 
== 't5-agent-re': 83 | exp_cols = 't5-agent-re-exp' 84 | elif explanations_to_use == 't5-agent-ra': 85 | exp_cols = 't5-agent-ra-exp' 86 | # ST (task or simulation) 87 | elif 'multi-exp' in explanations_to_use and 'MT' not in explanations_to_use: 88 | exp_cols = [f't5-multi-exp-{i}-seed{args.seed}' for i in range(num_choices)] 89 | # MT (simulation) 90 | elif 'multi-exp' in explanations_to_use and 'MT' in explanations_to_use: 91 | exp_cols = [f't5-MT-multi-exp-pred-seed{args.seed}' for i in range(num_choices)] 92 | print(f"getting explanations from {exp_cols}") 93 | explanations = df[exp_cols] 94 | 95 | # pick out the predicted explanations, according to the task model's prediction 96 | if simulate_rationalized: 97 | print("picking out predicted explanation") 98 | explanations = [explanations.loc[i,exp_cols[label]] for i, label in enumerate(labels)] 99 | 100 | examples = [NLIExample(idx = ids[i], 101 | premise = premises[i], 102 | hypothesis = hypotheses[i], 103 | explanation = explanations.iloc[i].tolist() if multi_exp else explanations[i], 104 | choices = [v for v in label_map.values()], 105 | label = labels[i]) 106 | for i in range(n)] 107 | 108 | return examples 109 | 110 | def get_tensors_for_bert(args, examples, tokenizer, max_seq_length : int, condition_on_explanations : bool, multi_explanation : bool, 111 | spliced_explanation_len = None, explanations_only = False): 112 | """ 113 | Converts a list of examples into features for use with T5. 114 | ref_answer -- reference the answer in the explanation output ids, or the distractor if False 115 | Returns: list of tensors 116 | """ 117 | input_padding_id = tokenizer.pad_token_id 118 | label_padding_id = -100 119 | eos_token_id = tokenizer.eos_token_id 120 | explanation_prefix_ids = tokenizer.encode("explain:", add_special_tokens = False) 121 | return_data = [] 122 | start = time.time() 123 | for example_index, example in enumerate(examples): 124 | if example_index > args.small_size and args.small_data: 125 | break 126 | # per-question variables 127 | premise = example.premise 128 | hypothesis = example.hypothesis 129 | choice_label = example.label 130 | answer_str = example.choices[choice_label] 131 | explanation_str = example.explanation 132 | task_input_ids_list = [] 133 | # first screen for length. want to keep input formatting as is due to tokenization differences with spacing before words (rather than adding all the ids) 134 | input_str = f"{tokenizer.cls_token} {premise} {tokenizer.sep_token} {hypothesis} {tokenizer.sep_token}" 135 | if spliced_explanation_len is not None: 136 | cap_length = max_seq_length-spliced_explanation_len 137 | else: 138 | cap_length = max_seq_length 139 | 140 | if explanations_only: 141 | premise = "" 142 | hypothesis = "" 143 | 144 | init_input_ids = tokenizer.encode(input_str) 145 | if len(init_input_ids) > (cap_length): 146 | over_by = len(init_input_ids) - cap_length 147 | premise_tokens = tokenizer.encode(premise) 148 | keep_up_to = len(premise_tokens) - over_by - 2 # leaves buffer 149 | new_premise_tokens = premise_tokens[:keep_up_to] 150 | premise = tokenizer.decode(new_premise_tokens) + '.' 
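# (when the encoded input exceeds cap_length, only the premise is shortened, by the overflow
# amount plus a small buffer; the hypothesis is left as is at this step)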
151 | 152 | # get string formats 153 | input_str = f"{tokenizer.cls_token} {premise} {tokenizer.sep_token} {hypothesis} {tokenizer.sep_token}" 154 | if condition_on_explanations: 155 | input_str += f" My commonsense tells me {explanation_str} {tokenizer.sep_token}" 156 | 157 | explanation_context_str = f"My commonsense tells me that" 158 | explanation_context_ids = tokenizer.encode(explanation_context_str, add_special_tokens = False) 159 | explanation_only_ids = tokenizer.encode(example.explanation, add_special_tokens = False) 160 | explanation_len = len(explanation_context_ids) + len(explanation_only_ids) 161 | 162 | # get token_ids 163 | _input_ids = tokenizer.encode(input_str, add_special_tokens = False) 164 | task_input_ids = _input_ids 165 | 166 | # truncate to fit in max_seq_length 167 | _truncate_seq_pair(task_input_ids, [], max_seq_length) 168 | 169 | # pad up to the max sequence len. NOTE input_padding_id goes on inputs to either the encoder or decoder. label_padding_id goes on lm_labels for decode 170 | padding = [input_padding_id] * (max_seq_length - len(task_input_ids)) 171 | task_input_ids += padding 172 | 173 | # make into tensors and accumulate 174 | task_input_ids = torch.tensor(task_input_ids if len(task_input_ids_list) < 1 else task_input_ids_list, dtype = torch.long) 175 | task_input_masks = (task_input_ids!=input_padding_id).float() 176 | task_choice_label = torch.tensor(choice_label, dtype = torch.long) 177 | explanation_len = torch.tensor(explanation_len).long() 178 | # cross-compatability with number of items in t5_split below... 179 | data_point = [task_input_ids, task_input_masks, task_input_ids, task_input_ids, task_input_ids, task_input_ids, task_input_ids, task_input_ids, task_choice_label] 180 | data_point += [task_choice_label] * 7 + [explanation_len] 181 | return_data.append(data_point) 182 | print("loading data took %.2f seconds" % (time.time() - start)) 183 | # now reshape list of lists of tensors to list of tensors 184 | n_cols = len(return_data[0]) 185 | return_data = [torch.stack([data_point[j] for data_point in return_data], dim=0) for j in range(n_cols)] 186 | return return_data 187 | 188 | def get_tensors_for_T5_split(args, examples, tokenizer, max_seq_length : int, condition_on_explanations : bool, multi_explanation : bool, 189 | spliced_explanation_len = None, explanations_only = False): 190 | """ 191 | Converts a list of examples into features for use with T5. 192 | 193 | ref_answer -- reference the answer in the explanation output ids, or the distractor if False 194 | 195 | Returns: list of tensors 196 | """ 197 | 198 | input_padding_id = tokenizer.pad_token_id 199 | label_padding_id = -100 200 | eos_token_id = tokenizer.eos_token_id 201 | explanation_prefix_ids = tokenizer.encode("explain:", add_special_tokens = False) 202 | 203 | return_data = [] 204 | 205 | start = time.time() 206 | 207 | for example_index, example in enumerate(examples): 208 | 209 | # per-question variables 210 | premise = example.premise 211 | hypothesis = example.hypothesis 212 | choice_label = example.label 213 | answer_str = example.choices[choice_label] 214 | explanation_str = example.explanation 215 | if isNaN(explanation_str): 216 | print("got nan explanation") 217 | example.explanation = '__' 218 | 219 | task_input_ids_list = [] 220 | task_output_ids_list = [] 221 | task_output_labels_list = [] 222 | explanation_context_ids_list = [] 223 | 224 | # first screen for length. 
want to keep input formatting as is due to tokenization differences with spacing before words (rather than adding all the ids) 225 | input_str = f"nli premise: [CLS] {premise} [SEP] hypothesis: {hypothesis} [SEP]" 226 | if spliced_explanation_len is not None: 227 | cap_length = max_seq_length-spliced_explanation_len 228 | else: 229 | cap_length = max_seq_length 230 | 231 | init_input_ids = tokenizer.encode(input_str) 232 | if len(init_input_ids) > (cap_length): 233 | over_by = len(init_input_ids) - cap_length 234 | premise_tokens = tokenizer.encode(premise) 235 | keep_up_to = len(premise_tokens) - over_by - 2 # leaves buffer 236 | new_premise_tokens = premise_tokens[:keep_up_to] 237 | premise = tokenizer.decode(new_premise_tokens) + '.' 238 | # print() 239 | # print("old premise: ", tokenizer.decode(premise_tokens)) 240 | # print("new premise: ", premise) 241 | 242 | # in explanations only, remove the task input 243 | if explanations_only: 244 | premise = "" 245 | hypothesis = "" 246 | 247 | # get string formats 248 | input_str = f"nli premise: [CLS] {premise} [SEP] hypothesis: {hypothesis} [SEP]" 249 | if condition_on_explanations and not multi_explanation: 250 | input_str += f" My commonsense tells me {explanation_str}" 251 | elif condition_on_explanations and multi_explanation: 252 | # make task_input_ids in answer loop below 253 | input_str = "" 254 | task_answer_str = f"answer {answer_str}" # want the prefix to be just a single token id 255 | if multi_explanation: 256 | explanation_output_str = f"The answer is '{answer_str}' because {explanation_str}" 257 | elif not multi_explanation: 258 | explanation_output_str = f"My commonsense tells me that {explanation_str}" 259 | 260 | # get token_ids 261 | _input_ids = tokenizer.encode(input_str, add_special_tokens = False) 262 | task_input_ids = _input_ids 263 | explanation_input_ids = explanation_prefix_ids + _input_ids 264 | explanation_only_ids = tokenizer.encode(example.explanation, add_special_tokens = False) 265 | _task_answer_ids = tokenizer.encode(task_answer_str, add_special_tokens = False) 266 | _explanation_output_ids = tokenizer.encode(explanation_output_str, add_special_tokens = False) + [eos_token_id] 267 | 268 | # truncate to fit in max_seq_length 269 | _truncate_seq_pair(task_input_ids, [], max_seq_length) 270 | _truncate_seq_pair(explanation_input_ids, [], max_seq_length) 271 | _truncate_seq_pair(_explanation_output_ids, [], max_seq_length) 272 | _truncate_seq_pair(explanation_only_ids, [], max_seq_length) 273 | 274 | for choice_index, choice in enumerate(example.choices): 275 | 276 | # make multiple inputs, for this condition 277 | if condition_on_explanations and multi_explanation: 278 | if len(example.explanation_list) > 1: 279 | explanation_str = example.explanation_list[choice_index] 280 | else: 281 | explanation_str = '' 282 | explanation_output_str = f"The answer is '{choice}' because {explanation_str}" 283 | task_input_str = f"nli premise: [CLS] {premise} [SEP] hypothesis: {hypothesis} [SEP] {explanation_output_str}" 284 | task_input_ids = tokenizer.encode(task_input_str, add_special_tokens = False) 285 | _truncate_seq_pair(task_input_ids, [], max_seq_length) 286 | ids_padding = [input_padding_id] * (max_seq_length - len(task_input_ids)) 287 | task_input_ids += ids_padding 288 | task_input_ids_list.append(task_input_ids) 289 | 290 | task_output_str = f"answer {choice}" 291 | _task_output_ids = tokenizer.encode(task_output_str, add_special_tokens = False) 292 | ids_padding = [input_padding_id] * (max_seq_length - 
len(_task_output_ids)) 293 | labels_padding = [label_padding_id] * (max_seq_length - len(_task_output_ids)) 294 | task_output_ids = _task_output_ids + ids_padding 295 | task_output_labels = _task_output_ids + labels_padding 296 | task_output_ids_list.append(task_output_ids) 297 | task_output_labels_list.append(task_output_labels) 298 | 299 | # make context str(s) 300 | if multi_explanation: 301 | explanation_context_str = f"The answer is '{choice}' because" 302 | elif not multi_explanation: 303 | explanation_context_str = f"My commonsense tells me that" 304 | explanation_context_ids = tokenizer.encode(explanation_context_str, add_special_tokens = False) 305 | if choice == answer_str: 306 | context_len = len(explanation_context_ids) 307 | explanation_context_ids += [input_padding_id] * (max_seq_length - len(explanation_context_ids)) 308 | _truncate_seq_pair(explanation_context_ids, [], max_seq_length) 309 | explanation_context_ids_list.append(explanation_context_ids) 310 | 311 | # pad up to the max sequence len. NOTE input_padding_id goes on inputs to either the encoder or decoder. label_padding_id goes on lm_labels for decode 312 | padding = [input_padding_id] * (max_seq_length - len(task_input_ids)) 313 | task_input_ids += padding 314 | padding = [input_padding_id] * (max_seq_length - len(explanation_input_ids)) 315 | explanation_input_ids += padding 316 | padding = [input_padding_id] * (max_seq_length - len(explanation_only_ids)) 317 | explanation_only_ids += padding 318 | 319 | # store explanation_len for dropout/masking purposes 320 | explanation_len = len([e for e in explanation_context_ids if e != input_padding_id]) + len([e for e in explanation_only_ids if e != input_padding_id]) 321 | 322 | ids_padding = [input_padding_id] * (max_seq_length - len(_task_answer_ids)) 323 | labels_padding = [label_padding_id] * (max_seq_length - len(_task_answer_ids)) 324 | task_answer_ids = _task_answer_ids + ids_padding 325 | task_answer_labels = _task_answer_ids + labels_padding 326 | 327 | ids_padding = [input_padding_id] * (max_seq_length - len(_explanation_output_ids)) 328 | labels_padding = [label_padding_id] * (max_seq_length - len(_explanation_output_ids)) 329 | explanation_output_ids = _explanation_output_ids + ids_padding 330 | explanation_output_labels = _explanation_output_ids + labels_padding 331 | explanation_output_labels[:context_len] = [label_padding_id]*context_len # no LM loss on the explanation_context_str 332 | 333 | # make into tensors and accumulate 334 | task_input_ids = torch.tensor(task_input_ids if len(task_input_ids_list) < 1 else task_input_ids_list, dtype = torch.long) 335 | task_input_masks = (task_input_ids!=input_padding_id).float() 336 | task_answer_ids = torch.tensor(task_answer_ids, dtype = torch.long) 337 | task_answer_masks = (task_answer_ids!=input_padding_id).float() 338 | task_answer_labels = torch.tensor(task_answer_labels, dtype = torch.long) 339 | task_output_ids = torch.tensor(task_output_ids_list, dtype = torch.long) 340 | task_output_masks = (task_output_ids!=input_padding_id).float() 341 | task_output_labels = torch.tensor(task_output_labels_list, dtype = torch.long) 342 | explanation_input_ids = torch.tensor(explanation_input_ids, dtype = torch.long) 343 | explanation_input_masks = (explanation_input_ids!=input_padding_id).float() 344 | explanation_output_ids = torch.tensor(explanation_output_ids, dtype = torch.long) 345 | explanation_output_masks = (explanation_output_ids!=input_padding_id).float() 346 | explanation_output_labels = 
torch.tensor(explanation_output_labels, dtype = torch.long) 347 | explanation_context_ids = torch.tensor(explanation_context_ids_list, dtype = torch.long) 348 | task_choice_label = torch.tensor(choice_label, dtype = torch.long) 349 | explanation_only_ids = torch.tensor(explanation_only_ids, dtype = torch.long) 350 | explanation_len = torch.tensor(explanation_len).long() 351 | 352 | data_point = [task_input_ids, task_input_masks, 353 | task_answer_ids, task_answer_masks, task_answer_labels, 354 | task_output_ids, task_output_masks, task_output_labels, task_choice_label, 355 | explanation_input_ids, explanation_input_masks, 356 | explanation_output_ids, explanation_output_masks, explanation_output_labels, 357 | explanation_context_ids, explanation_only_ids, explanation_len] 358 | return_data.append(data_point) 359 | 360 | print("making data into tensors took %.2f seconds" % (time.time() - start)) 361 | 362 | # now reshape list of lists of tensors to list of tensors 363 | n_cols = len(return_data[0]) 364 | return_data = [torch.stack([data_point[j] for data_point in return_data], dim=0) for j in range(n_cols)] 365 | 366 | return return_data 367 | 368 | 369 | 370 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 371 | """Truncates a sequence pair in place to the maximum length.""" 372 | 373 | # This is a simple heuristic which will always truncate the longer sequence 374 | # one token at a time. This makes more sense than truncating an equal percent 375 | # of tokens from each, since if one sequence is very short then each token 376 | # that's truncated likely contains more information than a longer sequence. 377 | while True: 378 | total_length = len(tokens_a) + len(tokens_b) 379 | if total_length <= max_length: 380 | break 381 | if len(tokens_a) > len(tokens_b): 382 | tokens_a.pop() 383 | else: 384 | tokens_b.pop() -------------------------------------------------------------------------------- /sim_experiments/QA_data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import argparse 4 | import logging 5 | import json 6 | import time 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import pandas as pd 11 | from utils import removeNonAscii, isNaN 12 | 13 | 14 | class CQAExample(object): 15 | '''used for training models with CQA data''' 16 | def __init__(self, 17 | cqa_id, 18 | question, 19 | explanation, 20 | choices, 21 | label = None): 22 | self.cqa_id = cqa_id 23 | self.question = question 24 | self.explanation = explanation 25 | self.label = int(label) 26 | self.choices = choices 27 | self.version = '1.0' if len(choices) == 3 else '1.1' 28 | self.choices_str = f'The choices are {self.choices[0]}, {self.choices[1]}, and {self.choices[2]}.' \ 29 | if self.version == '1.0' \ 30 | else \ 31 | f'The choices are {self.choices[0]}, {self.choices[1]}, {self.choices[2]}, {self.choices[3]}, and {self.choices[4]}.' 
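# choices_str verbalizes the answer options ("The choices are ...") so they can be appended
# to the question text when model inputs are built from this example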
32 | self.explanation_list = [explanation] \ 33 | if not isinstance(explanation, list) \ 34 | else \ 35 | explanation 36 | 37 | 38 | def __str__(self): 39 | return self.__repr__() 40 | 41 | def __repr__(self): 42 | 43 | list_ = [f"question: {self.question}"] + \ 44 | [f"choice {d}: {exp}" for d,exp in enumerate(self.choices)] + \ 45 | [f"explanation: {self.explanation}"] 46 | 47 | if self.label is not None: 48 | list_.append(f"label: {self.label}") 49 | 50 | return "\n".join(list_) 51 | 52 | 53 | 54 | def read_CQA(args, input_file, explanations_to_use, version, 55 | labels_to_use = 'label', filter_explanations = None): 56 | 57 | df = pd.read_csv(input_file) 58 | df = df.applymap(removeNonAscii) 59 | n = len(df) if not args.small_data else args.small_size 60 | num_choices = 3 if version == '1.0' else 5 61 | multi_exp = (args.condition_on_explanations and 'multi' in explanations_to_use and args.multi_explanation) 62 | # simulate_rationalized is used to pull out the predicted explanation when simulating a CAGE-Ra model 63 | simulate_rationalized = (args.condition_on_explanations and not args.multi_explanation and 'st.ra' in (labels_to_use.lower() if isinstance(labels_to_use, str) else '' )) 64 | 65 | # if test data, make sure explanations_to_use isn't ground_truth or oracle 66 | if 'test' in input_file and (explanations_to_use == 'ground_truth' or explanations_to_use == 'oracle'): 67 | explanations_to_use = 'None' 68 | 69 | ids = df['id'] 70 | questions = df['question'] 71 | choice_cols = [f'choice_{i}' for i in range(num_choices)] 72 | choices = df[choice_cols] 73 | labels = df[labels_to_use] if labels_to_use is not None else [0] * n 74 | print("using labels: %s" % labels_to_use) 75 | 76 | if explanations_to_use == 'None': 77 | explanations = [''] * n 78 | else: 79 | exp_cols = explanations_to_use 80 | try: 81 | explanations = df[exp_cols] 82 | print(f"getting explanations from {explanations_to_use}") 83 | except: 84 | if explanations_to_use == 'ground_truth': 85 | exp_cols = 'human_exp' if 'human_exp' in df.columns else 'human_expl_open-ended' 86 | elif explanations_to_use == 'oracle': 87 | exp_cols = 'human_exp' if 'human_exp' in df.columns else 'human_expl_open-ended' 88 | elif explanations_to_use == 'gpt': 89 | exp_cols = 'gpt' 90 | elif explanations_to_use == 'gpt2': 91 | exp_cols = 'gpt2' 92 | elif explanations_to_use == 'multi_gpt2': 93 | exp_cols = [f'gpt2_exps_{i}' for i in range(num_choices)] 94 | elif explanations_to_use == 't5': 95 | exp_cols = 't5-single-exp' 96 | elif explanations_to_use == 'MT_t5': 97 | exp_cols = 't5-MT-single-exp' 98 | elif explanations_to_use == 'multi_t5': 99 | exp_cols = [f't5-multi-exp-{i}' for i in range(num_choices)] 100 | elif explanations_to_use == 'MT_multi_t5': 101 | exp_cols = [f't5-MT-multi-exp-{i}' for i in range(num_choices)] 102 | elif explanations_to_use == 'MT_multi_t5_pred': 103 | exp_cols = 't5-MT-multi-exp-pred' 104 | elif explanations_to_use == 'bert_cage': 105 | exp_cols = 'bert-cage-single-exp' 106 | elif explanations_to_use == 'bert_multi_cage': 107 | exp_cols = [f'bert-cage-multi-exp-{i}' for i in range(num_choices)] 108 | elif explanations_to_use == 't5-agent-re': 109 | exp_cols = 't5-agent-re-exp' 110 | elif explanations_to_use == 't5-agent-ra': 111 | exp_cols = 't5-agent-ra-exp' 112 | # ST (task or simulation) 113 | elif 'multi-exp' in explanations_to_use and 'MT' not in explanations_to_use: 114 | exp_cols = [f't5-multi-exp-{i}-seed{args.seed}' for i in range(num_choices)] 115 | # MT (simulation) 116 | elif 'multi-exp' in 
explanations_to_use and 'MT' in explanations_to_use: 117 | exp_cols = [f't5-MT-multi-exp-pred-seed{args.seed}' for i in range(num_choices)] 118 | print(f"getting explanations from {exp_cols}") 119 | explanations = df[exp_cols] 120 | 121 | # pick out the predicted explanations, according to the task model's prediction 122 | if simulate_rationalized: 123 | print("picking out predicted explanations") 124 | explanations = [explanations.loc[i,exp_cols[label]] for i, label in enumerate(labels)] 125 | 126 | examples = [CQAExample(cqa_id = ids[i], 127 | question = questions[i], 128 | choices = choices.iloc[i].tolist(), 129 | explanation = explanations.iloc[i].tolist() if multi_exp else explanations[i], 130 | label = labels[i]) 131 | for i in range(n)] 132 | 133 | # filter pre-specified bad explanations (e.g. bad explanations in v1.1 data). see https://github.com/salesforce/cos-e/issues/2 134 | if filter_explanations is not None: 135 | examples = [ex for ex in examples if not ex.explanation in filter_explanations] 136 | 137 | return examples 138 | 139 | 140 | def get_tensors_for_T5_split(args, examples, tokenizer, max_seq_length : int, condition_on_explanations : bool, multi_explanation : bool, 141 | spliced_explanation_len = None, explanations_only = False): 142 | """ 143 | Converts a list of CQAExamples into features for use with T5. 144 | 145 | Spliced explanation len is used in 2-agent setup, where input_ids are spliced into with sampled explanations from a model. (need to leave enough room for this) 146 | 147 | Format: 148 | Sequence 1: "[task/explain]: What is the answer to this question? The choices are choice0, choice1, choice2." 149 | Task Sequence 2: "The answer is: {answer}" 150 | Exp. Sequence 2: "The answer is {choice} because {explanation}" 151 | 152 | Note: 153 | tensor_ids serves as input_ids to model.forward 154 | tensors_labels serves as lm_labels to model.forward 155 | 156 | Returns: list of tensors 157 | 158 | """ 159 | input_padding_id = tokenizer.pad_token_id 160 | label_padding_id = -100 161 | eos_token_id = tokenizer.eos_token_id 162 | task_prefix_ids = tokenizer.encode("task:", add_special_tokens = False) 163 | explanation_prefix_ids = tokenizer.encode("explain:", add_special_tokens = False) 164 | 165 | return_data = [] 166 | 167 | for example_index, example in enumerate(examples): 168 | 169 | # per-question variables 170 | question_str = example.question 171 | choices_str = example.choices_str 172 | answer_str = example.choices[example.label] 173 | explanation_str = example.explanation 174 | if isNaN(explanation_str): 175 | print("got nan explanation") 176 | example.explanation = '__' 177 | choice_label = example.label 178 | task_input_ids_list = [] 179 | task_output_ids_list = [] 180 | task_output_labels_list = [] 181 | explanation_context_ids_list = [] 182 | 183 | # first screen for length. 
want to keep input formatting as is due to tokenization differences with spacing before words (rather than adding all the ids) 184 | input_str = f"[CLS] {question_str} {choices_str} [SEP]" 185 | if spliced_explanation_len is not None: 186 | cap_length = max_seq_length-len(task_prefix_ids)-spliced_explanation_len 187 | else: 188 | cap_length = max_seq_length-len(task_prefix_ids) 189 | 190 | init_input_ids = tokenizer.encode(input_str) 191 | if len(init_input_ids) > (cap_length): 192 | over_by = len(init_input_ids) - cap_length 193 | question_tokens = tokenizer.encode(question_str) 194 | keep_up_to = len(question_tokens) - over_by - 1 # leaves buffer question mark below 195 | new_question_tokens = question_tokens[:keep_up_to] 196 | question_str = tokenizer.decode(new_question_tokens) + '?' 197 | # print("Trimmed a question by %d tokens" % (len(question_tokens) - len(new_question_tokens))) 198 | # print("OLD:", tokenizer.decode(question_tokens)) 199 | # print("NEW:", question_str) 200 | # print() 201 | 202 | # in explanations only, remove the question 203 | if explanations_only: 204 | question_str = "" 205 | 206 | # get string formats 207 | if not condition_on_explanations: 208 | input_str = f"[CLS] {question_str} {choices_str} [SEP]" 209 | if condition_on_explanations and not multi_explanation: 210 | input_str = f"[CLS] {question_str} {choices_str} [SEP] My commonsense tells me {explanation_str}" 211 | elif condition_on_explanations and multi_explanation: 212 | # make task_input_ids in answer loop below 213 | input_str = "" 214 | task_answer_str = f"The answer is: {answer_str}" 215 | explanation_output_str = f"The answer is {answer_str} because {explanation_str}" \ 216 | if multi_explanation \ 217 | else \ 218 | f"My commonsense tells me that {explanation_str}" 219 | 220 | # get token_ids 221 | _input_ids = tokenizer.encode(input_str, add_special_tokens = False) 222 | task_input_ids = task_prefix_ids + _input_ids 223 | explanation_input_ids = explanation_prefix_ids + _input_ids 224 | explanation_only_ids = tokenizer.encode(example.explanation, add_special_tokens = False) 225 | _task_answer_ids = tokenizer.encode(task_answer_str, add_special_tokens = False) 226 | _explanation_output_ids = tokenizer.encode(explanation_output_str, add_special_tokens = False) + [eos_token_id] 227 | 228 | _truncate_seq_pair(task_input_ids, [], max_seq_length) 229 | _truncate_seq_pair(explanation_input_ids, [], max_seq_length) 230 | _truncate_seq_pair(_explanation_output_ids, [], max_seq_length) 231 | _truncate_seq_pair(explanation_only_ids, [], max_seq_length) 232 | 233 | for choice_index, choice in enumerate(example.choices): 234 | 235 | if condition_on_explanations and multi_explanation: 236 | if len(example.explanation_list) > 1: 237 | explanation_str = example.explanation_list[choice_index] 238 | else: 239 | explanation_str = '' 240 | task_input_str = f"[CLS] {question_str} {choices_str} [SEP] The answer is {choice} because {explanation_str}" 241 | task_input_ids = task_prefix_ids + tokenizer.encode(task_input_str, add_special_tokens = False) 242 | _truncate_seq_pair(task_input_ids, [], max_seq_length) 243 | ids_padding = [input_padding_id] * (max_seq_length - len(task_input_ids)) 244 | task_input_ids += ids_padding 245 | task_input_ids_list.append(task_input_ids) 246 | 247 | task_output_str = f"The answer is: {choice}" 248 | _task_output_ids = tokenizer.encode(task_output_str, add_special_tokens = False) 249 | ids_padding = [input_padding_id] * (max_seq_length - len(_task_output_ids)) 250 | labels_padding = 
[label_padding_id] * (max_seq_length - len(_task_output_ids)) 251 | task_output_ids = _task_output_ids + ids_padding 252 | task_output_labels = _task_output_ids + labels_padding 253 | task_output_ids_list.append(task_output_ids) 254 | task_output_labels_list.append(task_output_labels) 255 | 256 | explanation_context_str = f"The answer is {choice} because" \ 257 | if multi_explanation \ 258 | else \ 259 | f"My commonsense tells me that" 260 | explanation_context_ids = tokenizer.encode(explanation_context_str, add_special_tokens = False) 261 | if choice == answer_str: 262 | context_len = len(explanation_context_ids) 263 | explanation_context_ids += [input_padding_id] * (max_seq_length - len(explanation_context_ids)) 264 | _truncate_seq_pair(explanation_context_ids, [], max_seq_length) 265 | explanation_context_ids_list.append(explanation_context_ids) 266 | 267 | # pad up to the max sequence len. NOTE input_padding_id goes on inputs to either the encoder or decoder. label_padding_id goes on lm_labels for decode 268 | padding = [input_padding_id] * (max_seq_length - len(task_input_ids)) 269 | task_input_ids += padding 270 | padding = [input_padding_id] * (max_seq_length - len(explanation_input_ids)) 271 | explanation_input_ids += padding 272 | padding = [input_padding_id] * (max_seq_length - len(explanation_only_ids)) 273 | explanation_only_ids += padding 274 | 275 | # store explanation_len for dropout/masking purposes 276 | explanation_len = len([e for e in explanation_context_ids if e != input_padding_id]) + len([e for e in explanation_only_ids if e != input_padding_id]) 277 | 278 | ids_padding = [input_padding_id] * (max_seq_length - len(_task_answer_ids)) 279 | labels_padding = [label_padding_id] * (max_seq_length - len(_task_answer_ids)) 280 | task_answer_ids = _task_answer_ids + ids_padding 281 | task_answer_labels = _task_answer_ids + labels_padding 282 | 283 | ids_padding = [input_padding_id] * (max_seq_length - len(_explanation_output_ids)) 284 | labels_padding = [label_padding_id] * (max_seq_length - len(_explanation_output_ids)) 285 | explanation_output_ids = _explanation_output_ids + ids_padding 286 | explanation_output_labels = _explanation_output_ids + labels_padding 287 | explanation_output_labels[:context_len] = [label_padding_id]*context_len # no LM loss on the explanation_context_str 288 | 289 | # make into tensors and accumulate 290 | task_input_ids = torch.tensor(task_input_ids if len(task_input_ids_list) < 1 else task_input_ids_list, dtype = torch.long) 291 | task_input_masks = (task_input_ids!=input_padding_id).float() 292 | task_answer_ids = torch.tensor(task_answer_ids, dtype = torch.long) 293 | task_answer_masks = (task_answer_ids!=input_padding_id).float() 294 | task_answer_labels = torch.tensor(task_answer_labels, dtype = torch.long) 295 | task_output_ids = torch.tensor(task_output_ids_list, dtype = torch.long) 296 | task_output_masks = (task_output_ids!=input_padding_id).float() 297 | task_output_labels = torch.tensor(task_output_labels_list, dtype = torch.long) 298 | explanation_input_ids = torch.tensor(explanation_input_ids, dtype = torch.long) 299 | explanation_input_masks = (explanation_input_ids!=input_padding_id).float() 300 | explanation_output_ids = torch.tensor(explanation_output_ids, dtype = torch.long) 301 | explanation_output_masks = (explanation_output_ids!=input_padding_id).float() 302 | explanation_output_labels = torch.tensor(explanation_output_labels, dtype = torch.long) 303 | explanation_context_ids = torch.tensor(explanation_context_ids_list, 
dtype = torch.long) 304 | task_choice_label = torch.tensor(choice_label, dtype = torch.long) 305 | explanation_only_ids = torch.tensor(explanation_only_ids, dtype = torch.long) 306 | explanation_len = torch.tensor(explanation_len).long() 307 | 308 | data_point = [task_input_ids, task_input_masks, 309 | task_answer_ids, task_answer_masks, task_answer_labels, 310 | task_output_ids, task_output_masks, task_output_labels, task_choice_label, 311 | explanation_input_ids, explanation_input_masks, 312 | explanation_output_ids, explanation_output_masks, explanation_output_labels, 313 | explanation_context_ids, explanation_only_ids, explanation_len] 314 | return_data.append(data_point) 315 | 316 | # now reshape list of lists of tensors to list of tensors 317 | n_cols = len(return_data[0]) 318 | return_data = [torch.stack([data_point[j] for data_point in return_data], dim=0) for j in range(n_cols)] 319 | 320 | return return_data 321 | 322 | 323 | 324 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 325 | """Truncates a sequence pair in place to the maximum length.""" 326 | 327 | # This is a simple heuristic which will always truncate the longer sequence 328 | # one token at a time. This makes more sense than truncating an equal percent 329 | # of tokens from each, since if one sequence is very short then each token 330 | # that's truncated likely contains more information than a longer sequence. 331 | while True: 332 | total_length = len(tokens_a) + len(tokens_b) 333 | if total_length <= max_length: 334 | break 335 | if len(tokens_a) > len(tokens_b): 336 | tokens_a.pop() 337 | else: 338 | tokens_b.pop() -------------------------------------------------------------------------------- /sim_experiments/README.md: -------------------------------------------------------------------------------- 1 | ## Reproducing Experiments 2 | 3 | Training task models and simulators is done with the `main.py` script, and running particular experiments can be done with `run_tasks.py` in the manner specified below. We give commands corresponding to the four graphical models we evaluate, simulators for each of these models, and the two-agent experiments with SGD optimization. 4 | 5 | **Note** that for all modeling experiments below, `gpu`, `save_dir`, and `cache_dir` must be provided as arguments to argparse (we recommend creating `save_dir` and `cache_dir` in the same directory). `-b` and `-g` refer to train batch size and gradient accumulation factors, respectively (effective batch size is their product). 6 | 7 | Lastly, simply replace all instances of "QA" with "NLI" to run the analogous experiments for e-SNLI instead of CoS-E (and adjust the effective batch size to 36).
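For example, `-b 4 -g 3` gives an effective batch size of 4 x 3 = 12; one way to reach the effective batch size of 36 for the NLI experiments (assuming it fits in memory, and not necessarily the exact setting used in the paper) is `-b 4 -g 9` or `-b 6 -g 6`.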
8 | 9 | *Task Model*: 10 | `python run_tasks.py --gpu gpu -e QA.task -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 11 | 12 | *Human Simulator*: 13 | `python run_tasks.py --gpu gpu -e QA.SIM.human -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 14 | 15 | *MT-Re*: 16 | Task model: `python run_tasks.py --gpu gpu -e QA.CLM.reason.MT -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 17 | Simulator: `python run_tasks.py --gpu gpu -e QA.SIM.MT.RE -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 18 | 19 | *MT-Ra*: 20 | Task model: `python run_tasks.py --gpu gpu -e QA.CLM.rationalize.MT -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 21 | Simulator: `python run_tasks.py --gpu gpu -e QA.SIM.MT.RA -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 22 | 23 | *ST-Re*: 24 | Generator: `python run_tasks.py --gpu gpu -e QA.CLM.reason -b 6 -g 6 --save_dir save_dir --cache_dir cache_dir ` 25 | Task model: `python run_tasks.py --gpu gpu -e QA.ST.RE -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir ` 26 | Simulator: `python run_tasks.py --gpu gpu -e QA.SIM.ST.RE -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 27 | 28 | *ST-Ra*: 29 | Generator: `python run_tasks.py --gpu gpu -e QA.CLM.rationalize -b 6 -g 6 --save_dir save_dir --cache_dir cache_dir ` 30 | Task model: `python run_tasks.py --gpu gpu -e QA.ST.RA -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir` 31 | Simulator: `python run_tasks.py --gpu gpu -e QA.SIM.ST.RA -b 4 -g 3 --save_dir save_dir --cache_dir cache_dir ` 32 | 33 | Note that in our two-agent experiments we initialize the models with pretrained task models and simulators (using `--task_prefinetuned_name` and `sim_prefinetuned_name` args). 34 | 35 | *Two-agent CQA Reasoning*: 36 | `python T5-2-agent_main.py --gpu gpu --model_name pass.reason --task_pretrained_name t5-base --human_exp_coef .15 --task_coef .35 --suppress_coef .2 --X_coef .5 --train_batch_size 1 --grad_accumulation_factor 12 --multi_explanation false --save_dir save_dir --cache_dir cache_dir --data_dir data/v1.0` 37 | 38 | *Two-agent CQA Rationalizing*: 39 | `python T5-2-agent_main.py --gpu gpu --model_name pass.rationalize --task_pretrained_name t5-base --human_exp_coef .15 --task_coef .35 --suppress_coef .2 --X_coef .5 --train_batch_size 1 --grad_accumulation_factor 12 --multi_explanation true --save_dir save_dir --cache_dir cache_dir --data_dir data/v1.0` 40 | 41 | *Two-agent NLI Reasoning*: 42 | `python T5-2-agent_main.py --gpu gpu --model_name pass.reason --task_pretrained_name t5-base --human_exp_coef .15 --task_coef .35 --suppress_coef .2 --X_coef .4 --E_coef .2 --train_batch_size 1 --grad_accumulation_factor 12 --multi_explanation false --save_dir save_dir --cache_dir cache_dir --data_dir data/e-SNLI-data` 43 | 44 | *Two-agent NLI Rationalizing*: 45 | `python T5-2-agent_main.py --gpu gpu --model_name pass.rationalize --task_pretrained_name t5-base --human_exp_coef .15 --task_coef .35 --suppress_coef .2 --X_coef .4 --E_coef .2 --train_batch_size 1 --grad_accumulation_factor 12 --multi_explanation true --save_dir save_dir --cache_dir cache_dir --data_dir data/e-SNLI-data` 46 | 47 | 48 | ## Computing LAS 49 | 50 | We compute LAS scores with the `compute_sim.py` script. Here, `gpu` and `base_dir` must be provided as arguments. `base_dir` should include a `saved_models` and `cached_models` directories. For each condition, LAS computation is possible after running the respective experiments from above. 
Note `split_name` is `dev` for CQA to compare with human-provided explanations, but should be `test` for SNLI. 51 | 52 | *Human Simulator*: 53 | `python compute_sim.py --model_name sim.human --explanations_to_use ground_truth --gpu gpu --split_name dev --data QA --seed seed --bootstrap` 54 | 55 | *MT-Re*: 56 | `python compute_sim.py --model_name sim.MT.RE --explanations_to_use t5-MT-single-exp-seed21 --gpu gpu --split_name dev --data QA --seed 21 --bootstrap --labels_to_use preds_QA_t5-base_MT.RE_seed21` 57 | 58 | *MT-Ra*: 59 | `python compute_sim.py --model_name sim.MT.RA --explanations_to_use t5-MT-multi-exp-pred-seed21 --gpu gpu --split_name dev --data QA --seed 21 --bootstrap -s aws --overwrite --labels_to_use preds_QA_t5-base_MT.RA_seed21` 60 | *ST-Re*: 61 | `python compute_sim.py --model_name sim.ST.RE --explanations_to_use t5-single-exp-seed21 --gpu gpu --split_name dev --data QA --seed 21 --bootstrap --labels_to_use preds_QA_t5-base_ST.RE_seed21` 62 | 63 | *ST-Ra*: 64 | `python compute_sim.py --model_name sim.ST.RA --explanations_to_use t5-multi-exp-seed21 --gpu gpu --split_name dev --data QA --seed 21 --bootstrap --labels_to_use preds_QA_t5-base_ST.RA_seed21` 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /sim_experiments/causal_estimation.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "causal estimation" 3 | author: "Peter Hase" 4 | output: pdf_document 5 | --- 6 | 7 | ```{r setup, include=FALSE} 8 | library(tidyverse) 9 | library(arm) 10 | library(extrafont) 11 | ``` 12 | 13 | ```{r read data, warning=FALSE, message=FALSE} 14 | qa_path <- "cqa-dev-qual.csv" 15 | nli_path <- "nli-test-qual.tsv" 16 | qa_data <- read_csv(qa_path) 17 | nli_data <- read_tsv(nli_path) 18 | 19 | ``` 20 | 21 | 22 | ```{r NLI calibration} 23 | 24 | nli_data <- nli_data %>% 25 | mutate(re_ye_correct = 1*(re_ye==re_label), 26 | ra_ye_correct = 1*(ra_ye==ra_label), 27 | cage_re_ye_correct = 1*(cage_re_ye==cage_re_label), 28 | cage_ra_ye_correct = 1*(cage_ra_ye==cage_ra_label), 29 | human_ye_correct = 1*(human_ye==label), 30 | human_LAS=1*(human_yxe==label)-(human_yx==label), 31 | re_LAS=1*(re_yxe==re_label)-(re_yx==re_label), 32 | ra_LAS=1*(ra_yxe==ra_label)-(ra_yx==ra_label), 33 | cage_re_LAS=1*(cage_re_yxe==cage_re_label)-(cage_re_yx==cage_re_label), 34 | cage_ra_LAS=1*(cage_ra_yxe==cage_ra_label)-(cage_ra_yx==cage_ra_label)) 35 | 36 | yxe <- c(nli_data$human_yxe, nli_data$re_yxe, nli_data$ra_yxe, nli_data$cage_re_yxe, nli_data$cage_ra_yxe) 37 | label <- c(nli_data$label, nli_data$re_label, nli_data$ra_label, nli_data$cage_re_label, nli_data$cage_ra_label) 38 | ye <- c(nli_data$human_ye, nli_data$re_ye, nli_data$ra_ye, nli_data$cage_re_ye, nli_data$cage_ra_ye) 39 | yx <- c(nli_data$human_yx, nli_data$re_yx, nli_data$ra_yx, nli_data$cage_re_yx, nli_data$cage_ra_yx) 40 | ye_probs <- c(nli_data$human_ye_probs, nli_data$re_ye_probs, nli_data$ra_ye_probs, nli_data$cage_re_ye_probs, nli_data$cage_ra_ye_probs) 41 | ye_correct <- 1*(label==ye) 42 | yxe_correct <- 1*(label==yxe) 43 | yx_correct <- 1*(label==yx) 44 | model <- c(rep('human',9824),rep('re',9824),rep('ra',9824),rep('cage_re',9824),rep('cage_ra',9824)) 45 | nli_gather <- tibble(yxe=yxe, 46 | label=label, 47 | ye=ye, 48 | ye_prob=ye_probs, 49 | ye_correct=ye_correct, 50 | yxe_correct=yxe_correct, 51 | yx_correct=yx_correct, 52 | model=model, 53 | LAS=yxe_correct - yx_correct) 54 | 55 | hist(nli_data$human_ye_probs) 56 | hist(nli_data$re_ye_probs) 57 |
hist(nli_data$ra_ye_probs) 58 | hist(nli_data$cage_re_ye_probs) 59 | hist(nli_data$cage_ra_ye_probs) 60 | 61 | hist(ye_probs) 62 | binnedplot(nli_data$human_ye_probs, nli_data$human_ye_correct) 63 | binnedplot(nli_data$re_ye_probs, nli_data$re_ye_correct) 64 | binnedplot(nli_data$ra_ye_probs, nli_data$ra_ye_correct) 65 | binnedplot(nli_data$cage_re_ye_probs, nli_data$cage_re_ye_correct) 66 | binnedplot(nli_data$cage_ra_ye_probs, nli_data$cage_ra_ye_correct) 67 | 68 | model <- glm(ye_correct ~ model * ye_prob, data=nli_gather, family = 'binomial') 69 | nli_cal_model = model 70 | 71 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'human') 72 | preds <- predict(model, new_data, type='response') 73 | new_data <- new_data %>% 74 | mutate(preds=preds) 75 | new_data %>% 76 | ggplot(aes(ye_prob, preds)) + 77 | geom_point() 78 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 're') 79 | preds <- predict(model, new_data, type='response') 80 | new_data <- new_data %>% 81 | mutate(preds=preds) 82 | new_data %>% 83 | ggplot(aes(ye_prob, preds)) + 84 | geom_point() 85 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'ra') 86 | preds <- predict(model, new_data, type='response') 87 | new_data <- new_data %>% 88 | mutate(preds=preds) 89 | new_data %>% 90 | ggplot(aes(ye_prob, preds)) + 91 | geom_point() 92 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'cage_re') 93 | preds <- predict(model, new_data, type='response') 94 | new_data <- new_data %>% 95 | mutate(preds=preds) 96 | new_data %>% 97 | ggplot(aes(ye_prob, preds)) + 98 | geom_point() 99 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'cage_ra') 100 | preds <- predict(model, new_data, type='response') 101 | new_data <- new_data %>% 102 | mutate(preds=preds) 103 | new_data %>% 104 | ggplot(aes(ye_prob, preds)) + 105 | geom_point() 106 | 107 | preds = predict(model, nli_gather, type='response') 108 | nli_gather <- nli_gather %>% 109 | mutate(preds = preds) 110 | binnedplot(nli_gather$preds, nli_gather$ye_correct) 111 | binnedplot(nli_data$human_ye_prob, nli_data$human_ye_correct) 112 | binnedplot(nli_data$re_ye_prob, nli_data$re_ye_correct) 113 | binnedplot(nli_data$ra_ye_prob, nli_data$ra_ye_correct) 114 | binnedplot(nli_data$cage_re_ye_prob, nli_data$cage_re_ye_correct) 115 | binnedplot(nli_data$cage_ra_ye_prob, nli_data$cage_ra_ye_correct) 116 | 117 | 118 | 119 | ``` 120 | 121 | 122 | ```{r nli binning} 123 | 124 | nli_gather <- nli_gather %>% 125 | mutate(bins100=bin(nli_gather$preds, nbins=100, method = 'length'), 126 | bins4=bin(nli_gather$preds, nbins=4, method = 'length'), 127 | bins3=bin(nli_gather$preds, nbins=3, method = 'length'), 128 | bins2=bin(nli_gather$preds, nbins=2, method = 'length'), 129 | ) 130 | nli_gather %>% 131 | group_by(bins4) %>% 132 | summarise(n = n(), 133 | ye_mean=mean(ye_correct)) 134 | 135 | nli_gather <- nli_gather %>% 136 | mutate(bins10= ifelse(between(preds, 0, .1), 0, 137 | ifelse(between(preds, .1, .2), 1, 138 | ifelse(between(preds, .2, .3), 2, 139 | ifelse(between(preds, .3, .4), 3, 140 | ifelse(between(preds, .4, .5), 4, 141 | ifelse(between(preds, .5, .6), 5, 142 | ifelse(between(preds, .6, .7), 6, 143 | ifelse(between(preds, .7, .8), 7, 144 | ifelse(between(preds, .8, .9), 8, 145 | ifelse(between(preds, .9, 1), 9, NA 146 | ))))))))))) 147 | nli_gather %>% 148 | group_by(bins10) %>% 149 | summarise(n = n(), 150 | ye_mean=mean(ye_correct)) 151 | 152 | 153 | ``` 154 | 155 | 156 | ```{r nli causal estimation} 157 | 158 | # two bins approach 159 | nli_gather %>% 160 | 
group_by(model, ye_correct) %>% 161 | summarise( 162 | mean(yxe_correct), 163 | mean(yx_correct), 164 | LAS = mean(yxe_correct)-mean(yx_correct), 165 | n=n()) %>% 166 | ungroup() %>% 167 | group_by(model) %>% 168 | summarise(LAS=mean(LAS), 169 | n= sum(n)) 170 | 171 | # multiple bins 172 | nli_gather %>% 173 | group_by(model, bins2) %>% 174 | summarise( 175 | mean(yxe_correct), 176 | mean(yx_correct), 177 | LAS = mean(yxe_correct)-mean(yx_correct), 178 | n=n()) %>% 179 | ungroup() %>% 180 | group_by(model) %>% 181 | summarise(LAS=mean(LAS), 182 | n= sum(n)) 183 | 184 | 185 | nli_gather %>% 186 | group_by(model, bins4) %>% 187 | summarise( 188 | mean(yxe_correct), 189 | mean(yx_correct), 190 | LAS = mean(yxe_correct)-mean(yx_correct), 191 | n=n()) %>% 192 | ungroup() %>% 193 | group_by(model) %>% 194 | summarise(LAS=mean(LAS), 195 | n= sum(n)) 196 | 197 | 198 | nli_gather %>% 199 | group_by(model, bins100) %>% 200 | summarise( 201 | mean(yxe_correct), 202 | mean(yx_correct), 203 | LAS = mean(LAS,na.rm=TRUE), 204 | n=n()) %>% 205 | ungroup() %>% 206 | group_by(model) %>% 207 | summarise(LAS=mean(LAS), 208 | n= sum(n)) 209 | 210 | nli_gather %>% 211 | group_by(model, bins4) %>% 212 | summarise( 213 | mean(yxe_correct), 214 | mean(yx_correct), 215 | LAS = mean(LAS), 216 | n=n()) %>% 217 | ungroup() %>% 218 | ggplot(aes(bins4, LAS, color=model)) + 219 | geom_boxplot() 220 | 221 | nli_gather %>% 222 | ggplot(aes(preds, LAS, color=model)) + 223 | geom_smooth(se=TRUE) 224 | 225 | nli_gather %>% 226 | filter(preds > 0) %>% 227 | ggplot(aes(preds, LAS, color=model)) + 228 | geom_smooth(se=FALSE, method='lm', formula = y ~ poly(x,1)) + 229 | geom_hline(aes(yintercept=0)) + 230 | xlab("prob. leaking") 231 | 232 | nli_gather %>% 233 | group_by(bins10) %>% 234 | summarise( 235 | mean(yxe_correct), 236 | mean(yx_correct), 237 | LAS = mean(yxe_correct)-mean(yx_correct), 238 | n=n()) %>% 239 | arrange(bins10) 240 | 241 | nli_gather %>% 242 | filter(model!='re') %>% 243 | group_by(bins10) %>% 244 | summarise( 245 | mean(yxe_correct), 246 | mean(yx_correct), 247 | LAS = mean(yxe_correct)-mean(yx_correct), 248 | n=n()) %>% 249 | arrange(bins10) 250 | 251 | 252 | ``` 253 | 254 | 255 | 256 | ```{r QA calibration} 257 | 258 | qa_data <- qa_data %>% 259 | mutate(re_ye_correct = 1*(re_ye==re_label), 260 | ra_ye_correct = 1*(ra_ye==ra_label), 261 | cage_re_ye_correct = 1*(cage_re_ye==cage_re_label), 262 | cage_ra_ye_correct = 1*(cage_ra_ye==cage_ra_label), 263 | human_ye_correct = 1*(human_ye==label), 264 | human_LAS=1*(human_yxe==label)-(human_yx==label), 265 | re_LAS=1*(re_yxe==re_label)-(re_yx==re_label), 266 | ra_LAS=1*(ra_yxe==ra_label)-(ra_yx==ra_label), 267 | cage_re_LAS=1*(cage_re_yxe==cage_re_label)-(cage_re_yx==cage_re_label), 268 | cage_ra_LAS=1*(cage_ra_yxe==cage_ra_label)-(cage_ra_yx==cage_ra_label)) 269 | 270 | yxe <- c(qa_data$human_yxe, qa_data$re_yxe, qa_data$ra_yxe, qa_data$cage_re_yxe, qa_data$cage_ra_yxe) 271 | label <- c(qa_data$label, qa_data$re_label, qa_data$ra_label, qa_data$cage_re_label, qa_data$cage_ra_label) 272 | ye <- c(qa_data$human_ye, qa_data$re_ye, qa_data$ra_ye, qa_data$cage_re_ye, qa_data$cage_ra_ye) 273 | yx <- c(qa_data$human_yx, qa_data$re_yx, qa_data$ra_yx, qa_data$cage_re_yx, qa_data$cage_ra_yx) 274 | ye_probs <- c(qa_data$human_ye_prob, qa_data$re_ye_prob, qa_data$ra_ye_prob, qa_data$cage_re_ye_prob, qa_data$cage_ra_ye_prob) 275 | ye_correct <- 1*(label==ye) 276 | yxe_correct <- 1*(label==yxe) 277 | yx_correct <- 1*(label==yx) 278 | model <- 
c(rep('human',950),rep('re',950),rep('ra',950),rep('cage_re',950),rep('cage_ra',950)) 279 | qa_gather <- tibble(yxe=yxe, 280 | label=label, 281 | ye=ye, 282 | ye_prob=ye_probs, 283 | ye_correct=ye_correct, 284 | yxe_correct=yxe_correct, 285 | yx_correct=yx_correct, 286 | model=model, 287 | LAS=yxe_correct - yx_correct) 288 | 289 | hist(qa_data$human_ye_prob) 290 | hist(qa_data$re_ye_prob) 291 | hist(qa_data$ra_ye_prob) 292 | hist(qa_data$cage_re_ye_prob) 293 | hist(qa_data$cage_ra_ye_prob) 294 | 295 | hist(ye_probs) 296 | binnedplot(qa_data$human_ye_prob, qa_data$human_ye_correct) 297 | binnedplot(qa_data$re_ye_prob, qa_data$re_ye_correct) 298 | binnedplot(qa_data$ra_ye_prob, qa_data$ra_ye_correct) 299 | binnedplot(qa_data$cage_re_ye_prob, qa_data$cage_re_ye_correct) 300 | binnedplot(qa_data$cage_ra_ye_prob, qa_data$cage_ra_ye_correct) 301 | 302 | model <- glm(ye_correct ~ model * ye_prob, data=qa_gather, family = 'binomial') 303 | qa_cal_model = model 304 | 305 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'human') 306 | preds <- predict(model, new_data, type='response') 307 | new_data <- new_data %>% 308 | mutate(preds=preds) 309 | new_data %>% 310 | ggplot(aes(ye_prob, preds)) + 311 | geom_point() 312 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 're') 313 | preds <- predict(model, new_data, type='response') 314 | new_data <- new_data %>% 315 | mutate(preds=preds) 316 | new_data %>% 317 | ggplot(aes(ye_prob, preds)) + 318 | geom_point() 319 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'ra') 320 | preds <- predict(model, new_data, type='response') 321 | new_data <- new_data %>% 322 | mutate(preds=preds) 323 | new_data %>% 324 | ggplot(aes(ye_prob, preds)) + 325 | geom_point() 326 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'cage_re') 327 | preds <- predict(model, new_data, type='response') 328 | new_data <- new_data %>% 329 | mutate(preds=preds) 330 | new_data %>% 331 | ggplot(aes(ye_prob, preds)) + 332 | geom_point() 333 | new_data <- tibble(ye_prob = seq(.3,.4,.001), model = 'cage_ra') 334 | preds <- predict(model, new_data, type='response') 335 | new_data <- new_data %>% 336 | mutate(preds=preds) 337 | new_data %>% 338 | ggplot(aes(ye_prob, preds)) + 339 | geom_point() 340 | 341 | preds = predict(model, qa_gather, type='response') 342 | qa_gather <- qa_gather %>% 343 | mutate(preds = preds) 344 | binnedplot(qa_gather$preds, qa_gather$ye_correct) 345 | binnedplot(nli_data$human_ye_prob, nli_data$human_ye_correct) 346 | binnedplot(nli_data$re_ye_prob, nli_data$re_ye_correct) 347 | binnedplot(nli_data$ra_ye_prob, nli_data$ra_ye_correct) 348 | binnedplot(nli_data$cage_re_ye_prob, nli_data$cage_re_ye_correct) 349 | binnedplot(nli_data$cage_ra_ye_prob, nli_data$cage_ra_ye_correct) 350 | 351 | plot(model) 352 | 353 | 354 | ``` 355 | 356 | 357 | 358 | 359 | ```{r qa binning} 360 | 361 | qa_gather <- qa_gather %>% 362 | mutate(bins100=bin(qa_gather$preds, nbins=100, method = 'length'), 363 | bins4=bin(qa_gather$preds, nbins=4, method = 'length'), 364 | bins3=bin(qa_gather$preds, nbins=3, method = 'length'), 365 | bins2=bin(qa_gather$preds, nbins=2, method = 'length'), 366 | ) 367 | qa_gather %>% 368 | group_by(bins4) %>% 369 | summarise(n = n(), 370 | ye_mean=mean(ye_correct)) 371 | 372 | qa_gather <- qa_gather %>% 373 | mutate(bins10= ifelse(between(preds, 0, .1), 0, 374 | ifelse(between(preds, .1, .2), 1, 375 | ifelse(between(preds, .2, .3), 2, 376 | ifelse(between(preds, .3, .4), 3, 377 | ifelse(between(preds, .4, .5), 4, 378 | 
ifelse(between(preds, .5, .6), 5, 379 | ifelse(between(preds, .6, .7), 6, 380 | ifelse(between(preds, .7, .8), 7, 381 | ifelse(between(preds, .8, .9), 8, 382 | ifelse(between(preds, .9, 1), 9, NA 383 | ))))))))))) 384 | qa_gather %>% 385 | group_by(bins10) %>% 386 | summarise(n = n(), 387 | ye_mean=mean(ye_correct)) 388 | 389 | 390 | ``` 391 | 392 | 393 | ```{r qa causal estimation} 394 | 395 | qa_gather %>% 396 | group_by(model, ye_correct) %>% 397 | summarise( 398 | mean(yxe_correct), 399 | mean(yx_correct), 400 | LAS = mean(yxe_correct)-mean(yx_correct), 401 | n=n()) %>% 402 | ungroup() %>% 403 | group_by(model) %>% 404 | summarise(LAS=mean(LAS), 405 | n= sum(n)) 406 | 407 | qa_gather %>% 408 | group_by(model, bins2) %>% 409 | summarise( 410 | mean(yxe_correct), 411 | mean(yx_correct), 412 | LAS = mean(yxe_correct)-mean(yx_correct), 413 | n=n()) %>% 414 | ungroup() %>% 415 | group_by(model) %>% 416 | summarise(LAS=mean(LAS), 417 | n= sum(n)) 418 | 419 | 420 | qa_gather %>% 421 | group_by(model, bins4) %>% 422 | summarise( 423 | mean(yxe_correct), 424 | mean(yx_correct), 425 | LAS = mean(yxe_correct)-mean(yx_correct), 426 | n=n()) %>% 427 | ungroup() %>% 428 | group_by(model) %>% 429 | summarise(LAS=mean(LAS), 430 | n= sum(n)) 431 | 432 | qa_gather %>% 433 | group_by(model, bins10) %>% 434 | summarise( 435 | mean(yxe_correct), 436 | mean(yx_correct), 437 | LAS = mean(yxe_correct)-mean(yx_correct), 438 | n=n()) %>% 439 | ungroup() %>% 440 | group_by(model) %>% 441 | summarise(LAS=mean(LAS), 442 | yxe=mean(yxe_correct), 443 | n= sum(n)) 444 | 445 | 446 | qa_gather %>% 447 | group_by(model, bins100) %>% 448 | summarise( 449 | mean(yxe_correct), 450 | mean(yx_correct), 451 | LAS = mean(LAS,na.rm=TRUE), 452 | n=n()) %>% 453 | ungroup() %>% 454 | group_by(model) %>% 455 | summarise(LAS=mean(LAS), 456 | n= sum(n)) 457 | 458 | qa_gather %>% 459 | group_by(model, bins4) %>% 460 | summarise( 461 | mean(yxe_correct), 462 | mean(yx_correct), 463 | LAS = mean(LAS), 464 | n=n()) %>% 465 | ungroup() %>% 466 | ggplot(aes(bins4, LAS, color=model)) + 467 | geom_boxplot() 468 | 469 | qa_gather %>% 470 | ggplot(aes(preds, LAS, color=model)) + 471 | geom_smooth(se=TRUE) 472 | 473 | qa_gather %>% 474 | filter(preds > 0) %>% 475 | ggplot(aes(preds, LAS, color=model)) + 476 | geom_smooth(se=FALSE, method='lm', formula = y ~ poly(x,2)) + 477 | geom_hline(aes(yintercept=0)) + 478 | xlab("prob. leaking") 479 | 480 | qa_gather %>% 481 | group_by(bins10) %>% 482 | summarise( 483 | mean(yxe_correct), 484 | mean(yx_correct), 485 | LAS = mean(yxe_correct)-mean(yx_correct), 486 | n=n()) %>% 487 | arrange(bins10) 488 | 489 | qa_gather %>% 490 | filter(model!='re') %>% 491 | group_by(bins10) %>% 492 | summarise( 493 | mean(yxe_correct), 494 | mean(yx_correct), 495 | LAS = mean(yxe_correct)-mean(yx_correct), 496 | n=n()) %>% 497 | arrange(bins10) 498 | 499 | ``` 500 | 501 | 502 | 503 | ```{r plots} 504 | 505 | 506 | nli_gather %>% 507 | filter(preds > 0) %>% 508 | ggplot(aes(preds, LAS, color=model)) + 509 | geom_smooth(se=FALSE, method='lm', formula = y ~ poly(x,2)) + 510 | geom_hline(aes(yintercept=0)) + 511 | ylim(c(-.4,.1)) + 512 | labs(y="Avg. 
\n Effect", 513 | x = "Leakage Probability", 514 | title = "Human Ratings by Simulator Predictions") + 515 | theme_classic() + 516 | theme(axis.text.x = element_text(family = "Times New Roman", 517 | size = 16, 518 | color = "black"), 519 | axis.text.y = element_text(family = "Times New Roman", 520 | size = 16, 521 | color = "black"), 522 | axis.title.x = element_text(family = "Times New Roman", 523 | size=16), 524 | axis.title.y = element_text(family = "Times New Roman", 525 | size=18, 526 | angle=0, 527 | vjust = .5, 528 | color = "white"), 529 | plot.title = element_text(family = "Times New Roman", 530 | size = 18, 531 | hjust = .5)) 532 | 533 | 534 | 535 | 536 | ``` 537 | 538 | 539 | 540 | -------------------------------------------------------------------------------- /sim_experiments/classes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | class Report(): 4 | """Report stores evaluation results during the training process as text files.""" 5 | 6 | def __init__(self, args, file_path, score_names): 7 | self.fn = file_path 8 | self.args = args 9 | self.text = '' 10 | 11 | # write input arguments at the top 12 | self.text += 'Input: python %s %s \n\n' % \ 13 | (sys.argv[0], 14 | ' '.join([arg for arg in sys.argv[1:]])) 15 | 16 | # make header 17 | header = 'epoch |' 18 | for n, score_name in enumerate(score_names): 19 | header += ' %15s ' % score_name 20 | if n < len(score_names) - 1: header += '|' 21 | self.header = header 22 | 23 | # write header 24 | self.blank_line = '-'*len(header) 25 | self.text += self.blank_line + \ 26 | f"\nTraining report for model: {args.model_name}" + \ 27 | '\n' + self.blank_line + \ 28 | '\n' 29 | self.text += header 30 | 31 | 32 | def write_epoch_scores(self, epoch, scores): 33 | # write scores 34 | self.text += '\n%5s |' % str(epoch) 35 | for n, score in enumerate(scores.values()): 36 | self.text += ' %15s ' % ('%1.2f' % score) 37 | if n < len(scores) - 1: self.text += '|' 38 | self.__save() 39 | 40 | def write_final_score(self, args, final_score_str, time_msg=None): 41 | self.text += '\n' + self.blank_line 42 | self.text += '\n%s' % final_score_str 43 | self.text += '\n' + self.blank_line + '\n' 44 | 45 | if time_msg is not None: 46 | self.text += '\n%s\n' % time_msg 47 | 48 | self.text += '\n' 49 | self.write_all_arguments(args) 50 | 51 | self.__save() 52 | 53 | def write_msg(self, msg): 54 | self.text += self.blank_line 55 | self.text += msg 56 | self.__save() 57 | 58 | def write_all_arguments(self, args): 59 | self.text += "\nAll arguments:\n" 60 | self.text += '\n'.join(['\t' + hp for hp in str(args).replace('Namespace(', '').replace(')', '').split(', ')]) 61 | self.__save() 62 | 63 | def print_epoch_scores(self, epoch, scores, time_msg=None): 64 | epoch_text = ' epoch |' 65 | for n, score_name in enumerate(scores.keys()): 66 | epoch_text += ' %15s ' % score_name 67 | if n < len(scores) - 1: epoch_text += '|' 68 | epoch_text += '\n %5s |' % ('%d'% epoch) 69 | for n, score in enumerate(scores.values()): 70 | epoch_text += ' %15s ' % ('%1.2f' % score) 71 | if n < len(scores) - 1: epoch_text += '|' 72 | print(epoch_text + '\n') 73 | 74 | def full_print(self): 75 | print('\n' + self.text + '\n') 76 | 77 | def __save(self): 78 | if self.fn is not None: 79 | with open(self.fn, mode='w') as text_file: 80 | text_file.write(self.text) 81 | 82 | -------------------------------------------------------------------------------- /sim_experiments/compute_sim.py: 
-------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import time 4 | import os 5 | import argparse 6 | from utils import str2bool 7 | 8 | 9 | def run_analysis(args, gpu, data, model_name, explanations_to_use, labels_to_use, seed, split_name, model_size): 10 | ''' 11 | compute sim metric for a model by writing to file (or checking if these in data) 12 | ''' 13 | if data == 'QA': 14 | extension = 'csv' 15 | sep = ',' 16 | folder = 'data/v1.0' 17 | elif data == 'NLI': 18 | extension = 'tsv' 19 | sep = '\t' 20 | folder = 'data/e-SNLI-data' 21 | save_dir = os.path.join(args.base_dir, 'saved_models') 22 | cache_dir = os.path.join(args.base_dir, 'cached_models') 23 | pretrained_name = args.task_pretrained_name + '-' + model_size 24 | train_file = os.path.join(folder, 'train.%s' % extension) 25 | dev_file = os.path.join(folder, 'dev.%s' % extension) 26 | test_file = os.path.join(folder, 'test.%s' % extension) 27 | write_base = 'preds' 28 | xe_col = '%s_%s_%s_%s_seed%s_XE' % (write_base, data, pretrained_name, model_name, seed) 29 | e_col = '%s_%s_%s_%s_seed%s_E' % (write_base, data, pretrained_name, model_name, seed) 30 | x_col = '%s_%s_%s_%s_seed%s_X' % (write_base, data, pretrained_name, model_name, seed) 31 | 32 | train = pd.read_csv(train_file, sep=sep) 33 | dev = pd.read_csv(dev_file, sep=sep) 34 | test = pd.read_csv(test_file, sep=sep) 35 | to_use = dev if split_name == 'dev' else test 36 | script = 'main' 37 | if args.small_data: 38 | small_data_add = '-s -ss 100 ' 39 | else: 40 | small_data_add = '' 41 | if xe_col not in to_use.columns or args.overwrite: 42 | print("\nWriting XE predictions...") 43 | os.system(f"python {script}.py --gpu {gpu} --model_name {model_name} --do_explain false --task_pretrained_name {pretrained_name} --multi_explanation false " 44 | f"--data_dir {folder} --condition_on_explanations true --explanations_to_use {explanations_to_use} " 45 | f"--dev_batch_size 20 " 46 | f"--labels_to_use {labels_to_use} --do_train false --do_eval false --write_predictions --preds_suffix XE " 47 | f"--save_dir {save_dir} --cache_dir {cache_dir} --seed {seed} {small_data_add}" 48 | ) 49 | if x_col not in to_use.columns or args.overwrite: 50 | print("Writing X predictions...") 51 | os.system(f"python {script}.py --gpu {gpu} --model_name {model_name} --do_explain false --task_pretrained_name {pretrained_name} --multi_explanation false " 52 | f"--data_dir {folder} --condition_on_explanations false " 53 | f"--dev_batch_size 20 " 54 | f"--labels_to_use {labels_to_use} --do_train false --do_eval false --write_predictions --preds_suffix X " 55 | f"--save_dir {save_dir} --cache_dir {cache_dir} --seed {seed} {small_data_add}" 56 | ) 57 | if e_col not in to_use.columns or args.overwrite: 58 | print("Writing E predictions...") 59 | os.system(f"python {script}.py --gpu {gpu} --model_name {model_name} --do_explain false --task_pretrained_name {pretrained_name} --multi_explanation false " 60 | f"--data_dir {folder} --condition_on_explanations true --explanations_to_use {explanations_to_use} --explanations_only true " 61 | f"--dev_batch_size 20 " 62 | f"--labels_to_use {labels_to_use} --do_train false --do_eval false --write_predictions --preds_suffix E " 63 | f"--save_dir {save_dir} --cache_dir {cache_dir} --seed {seed} {small_data_add}" 64 | ) 65 | train = pd.read_csv(train_file, sep=sep) 66 | dev = pd.read_csv(dev_file, sep=sep) 67 | test = pd.read_csv(test_file, sep=sep) 68 | to_use = dev if split_name == 'dev' else 
test 69 | 70 | _ = compute_sim(args, to_use, labels_to_use, data, pretrained_name, model_name, seed, print_results = True) 71 | 72 | if args.bootstrap: 73 | start = time.time() 74 | boot_times = 10000 75 | print(f"Starting bootstrap with {boot_times/1000:.0f}k samples...") 76 | leaking_diff_list = [] 77 | nonleaking_diff_list = [] 78 | overall_metric_list = [] 79 | for b in range(boot_times): 80 | boot_idx = np.random.choice(np.arange(len(to_use)), replace=True, size = len(to_use)) 81 | to_use_boot = to_use.iloc[boot_idx,:] 82 | mean, leaking_diff, nonleaking_diff = compute_sim(args, to_use_boot, labels_to_use, data, pretrained_name, model_name, seed, print_results = False) 83 | overall_metric_list.append(mean) 84 | leaking_diff_list.append(leaking_diff) 85 | nonleaking_diff_list.append(nonleaking_diff) 86 | 87 | lb, ub = np.quantile(nonleaking_diff_list, (.025, .975)) 88 | CI = (ub - lb) / 2 89 | print("\nnonleaking diff: %.2f (+/- %.2f)" % (np.mean(nonleaking_diff_list)*100, 100*CI)) 90 | 91 | lb, ub = np.quantile(leaking_diff_list, (.025, .975)) 92 | CI = (ub - lb) / 2 93 | print("\nleaking diff: %.2f (+/- %.2f)" % (np.mean(leaking_diff_list)*100, 100*CI)) 94 | 95 | lb, ub = np.quantile(overall_metric_list, (.025, .975)) 96 | CI = (ub - lb) / 2 97 | print("\nunweighted mean: %.2f (+/- %.2f)\n" % (np.mean(overall_metric_list)*100, 100*CI)) 98 | 99 | print("time for bootstrap: %.1f minutes" % ((time.time() - start)/60)) 100 | print("--------------------------\n") 101 | 102 | 103 | 104 | 105 | def compute_sim(args, to_use, labels_to_use, data, pretrained_name, model_name, seed, print_results = False): 106 | labels = to_use[labels_to_use] 107 | xe_col = '%s_%s_%s_%s_seed%s_XE' % ('preds', data, pretrained_name, model_name, seed) 108 | e_col = '%s_%s_%s_%s_seed%s_E' % ('preds', data, pretrained_name, model_name, seed) 109 | x_col = '%s_%s_%s_%s_seed%s_X' % ('preds', data, pretrained_name, model_name, seed) 110 | xe = to_use[xe_col] 111 | e = to_use[e_col] 112 | x = to_use[x_col] 113 | xe_correct = np.array(1*(labels==xe)) 114 | x_correct = np.array(1*(labels==x)) 115 | e_correct = np.array(1*(labels==e)) 116 | 117 | # baseline and leaking proxy variable 118 | baseline_correct = 1*(x_correct) 119 | leaking = 1*(e_correct) 120 | leaked = np.argwhere(leaking.tolist()).reshape(-1) 121 | 122 | # get subgroups 123 | nonleaked = np.setdiff1d(np.arange(len(e_correct)), leaked) 124 | xe_correct_leaked = xe_correct[leaked] 125 | e_correct_leaked = e_correct[leaked] 126 | x_correct_leaked = x_correct[leaked] 127 | xe_correct_nonleaked = xe_correct[nonleaked] 128 | e_correct_nonleaked = e_correct[nonleaked] 129 | x_correct_nonleaked = x_correct[nonleaked] 130 | num_leaked = len(leaked) 131 | num_non_leaked = len(xe) - num_leaked 132 | 133 | unweighted_mean = np.mean([np.mean(xe_correct[split]) - np.mean(baseline_correct[split]) for split in [leaked,nonleaked]]) 134 | nonleaking_diff = np.mean(xe_correct_nonleaked) - np.mean(baseline_correct[nonleaked]) 135 | leaking_diff = np.mean(xe_correct_leaked) - np.mean(baseline_correct[leaked]) 136 | if print_results: 137 | print("\n------------------------") 138 | print("num (probably) leaked: %d" % num_leaked) 139 | print("y|x,e : %.4f baseline : %.4f y|x,e=null: %.4f" % (np.mean(xe_correct_leaked), np.mean(baseline_correct[leaked]), np.mean(x_correct_leaked))) 140 | print("diff: %.4f" % (leaking_diff)) 141 | print() 142 | print("num (probably) nonleaked: %d" % num_non_leaked) 143 | print("y|x,e : %.4f baseline : %.4f y|x,e=null: %.4f" % 
(np.mean(xe_correct_nonleaked), np.mean(baseline_correct[nonleaked]), np.mean(x_correct_nonleaked))) 144 | print("diff: %.4f" % (nonleaking_diff)) 145 | print() 146 | print("overall: ") 147 | print("y|x : %.4f y|e : %.4f" % (np.mean(x_correct), np.mean(e_correct))) 148 | print("y|x,e: %.4f baseline : %.4f" % (np.mean(xe_correct), np.mean(baseline_correct))) 149 | print("\nunweighted mean: %.2f" % (unweighted_mean*100)) 150 | print("--------------------------") 151 | return unweighted_mean, leaking_diff, nonleaking_diff 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("--gpu", default=-1, type=int, help='') 156 | parser.add_argument("--condition", default = "get_sim_metric", type=str, help='') 157 | parser.add_argument("--data", default = 'NLI', help='') 158 | parser.add_argument("--model_name", default ='', type=str, help='') 159 | parser.add_argument("--explanations_to_use", default = 'ground_truth', type=str, help='') 160 | parser.add_argument("--labels_to_use", default = 'label', type=str, help='') 161 | parser.add_argument("--seed", default = '42', type=str, help='') 162 | parser.add_argument('--leaking_param', default = 0, type=float, help='') 163 | parser.add_argument('--split_name', default='dev', type=str, help='see get_sim_metric') 164 | parser.add_argument('--task_pretrained_name', default='t5', type=str, help='') 165 | parser.add_argument('--model_size', default='base', type=str, help='') 166 | parser.add_argument('--server_number', '-s', default='13', type=str, help='') 167 | parser.add_argument('--bootstrap', action='store_true', help='') 168 | parser.add_argument('--small_data', action='store_true', help='Flag for using just a few datapoints for debugging purposes') 169 | parser.add_argument('--overwrite', action='store_true', help='rewrite predictions') 170 | parser.add_argument("--base_dir", default='', required=True, type=str, help="folders for saved_models and cached_models should be in this directory") 171 | args = parser.parse_args() 172 | 173 | if args.condition == "get_sim_metric": 174 | run_analysis(args, 175 | args.gpu, 176 | args.data, 177 | args.model_name, 178 | args.explanations_to_use, 179 | args.labels_to_use, 180 | args.seed, 181 | args.split_name, 182 | args.model_size) 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /sim_experiments/models/T5ForMC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import T5PreTrainedModel 4 | from transformers.modeling_t5 import T5Stack 5 | from torch.nn import CrossEntropyLoss 6 | import copy 7 | 8 | class T5ModelForMC(T5PreTrainedModel): 9 | """ 10 | Wrapper for T5PreTrainedModel to use T5 for multiple choice under a closed choice set. 11 | - adds .QA_forward method 12 | 13 | (decoder) QA_forward 14 | Input: 15 | input_ids of shape: batch_size x num_choices x max_seq_len 16 | Output: 17 | outputs[0] is loss of shape batch_size x num_choices. 
preds should be torch.argmax(loss, dim = -1) 18 | 19 | """ 20 | 21 | def __init__(self, config, project_to_small = False): 22 | super().__init__(config) 23 | self.model_dim = config.d_model 24 | 25 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 26 | 27 | encoder_config = copy.deepcopy(config) 28 | self.encoder = T5Stack(encoder_config) 29 | 30 | decoder_config = copy.deepcopy(config) 31 | decoder_config.is_decoder = True 32 | self.decoder = T5Stack(decoder_config) 33 | 34 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 35 | 36 | if project_to_small: 37 | self.project_to_small = nn.Linear(768, 512) # projection matrix for use in self.project_base_to_small 38 | 39 | self.init_weights() 40 | 41 | def get_input_embeddings(self): 42 | return self.shared 43 | 44 | def set_input_embeddings(self, new_embeddings): 45 | self.shared = new_embeddings 46 | 47 | def get_output_embeddings(self): 48 | return self.lm_head 49 | 50 | 51 | def forward(self, reduce_batch = True, loss_weights = None, **kwargs): 52 | # keyword arguments come in 3 flavors: encoder-specific (prefixed by 53 | # `encoder_`), decoder-specific (prefixed by `decoder_`) and those 54 | # that apply to the model as whole. 55 | # We let the specific kwargs override the common ones in case of conflict. 56 | 57 | # IF decoder_input_ids has NUM CHOICES dimension, return .QA_forward. this exists so we can wrap model with torch.nn.DataParallel, which must call forward 58 | if 'decoder_input_ids' in kwargs.keys(): 59 | if kwargs['decoder_input_ids'].dim() == 3: # batch_size x num_choices x max_seq_len 60 | return self.QA_forward(**kwargs) 61 | 62 | 63 | lm_labels = kwargs.pop("decoder_lm_labels", None) 64 | 65 | kwargs_common = dict( 66 | (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") 67 | ) 68 | kwargs_encoder = kwargs_common.copy() 69 | kwargs_decoder = kwargs_common.copy() 70 | kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) 71 | kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) 72 | 73 | # import ipdb; ipdb.set_trace() 74 | 75 | # Encode if needed (training, first prediction pass) 76 | encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) 77 | if encoder_hidden_states is None: 78 | # Convert encoder inputs in embeddings if needed 79 | hidden_states = kwargs_encoder.pop("inputs_embeds", None) 80 | if hidden_states is None: 81 | encoder_inputs_ids = kwargs_encoder.pop("input_ids") 82 | hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings 83 | 84 | encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) 85 | encoder_hidden_states = encoder_outputs[0] 86 | else: 87 | encoder_outputs = () 88 | 89 | # Decode 90 | # Convert decoder inputs in embeddings if needed 91 | hidden_states = kwargs_decoder.pop("inputs_embeds", None) 92 | if hidden_states is None: 93 | decoder_inputs_ids = kwargs_decoder.pop("input_ids") 94 | hidden_states = self.shared(decoder_inputs_ids) 95 | 96 | kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states 97 | kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) 98 | # print(kwargs_decoder.keys()) 99 | # print(kwargs_decoder["encoder_hidden_states"].shape) 100 | # print(kwargs_decoder["encoder_attention_mask"].shape) 101 | # print(hidden_states.shape) 102 | # import ipdb; ipdb.set_trace() 103 | decoder_outputs = 
self.decoder(hidden_states, **kwargs_decoder) 104 | 105 | sequence_output = decoder_outputs[0] 106 | # Rescale output before projecting on vocab 107 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 108 | sequence_output = sequence_output * (self.model_dim ** -0.5) 109 | lm_logits = self.lm_head(sequence_output) 110 | 111 | decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here 112 | 113 | if lm_labels is not None: 114 | shift_logits = lm_logits[..., :-1, :].contiguous() 115 | shift_labels = lm_labels[..., 1:].contiguous() 116 | loss_fct = CrossEntropyLoss(ignore_index=-100, reduction = 'none') 117 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 118 | 119 | # reshape to batch shape (shifted) 120 | batch_shape = shift_labels.shape 121 | loss = loss.view(batch_shape) 122 | # get per data point loss 123 | loss = torch.mean(loss, dim=-1) 124 | if reduce_batch: 125 | loss = loss.mean() 126 | if loss_weights is not None: 127 | loss = loss * loss_weights 128 | 129 | decoder_outputs = ( 130 | loss, 131 | ) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 132 | 133 | return decoder_outputs + encoder_outputs 134 | 135 | 136 | def QA_forward(self, **kwargs): 137 | ''' 138 | so this is basically just .forward that maintains the num_choices dimension, plus doesn't reduce the token loss into a scalar 139 | 140 | # keyword arguments come in 3 flavors: encoder-specific (prefixed by 141 | # `encoder_`), decoder-specific (prefixed by `decoder_`) and those 142 | # that apply to the model as whole. 143 | # We let the specific kwargs override the common ones in case of conflict. 
144 | ''' 145 | 146 | batch_size = kwargs['decoder_input_ids'].size(0) 147 | num_choices = kwargs['decoder_input_ids'].size(1) 148 | seq_len = kwargs['decoder_input_ids'].size(2) 149 | 150 | lm_labels = kwargs.pop("decoder_lm_labels", None) 151 | 152 | # kwargs_encoder/decoder are initialized from kwargs_common, and then overwritten by any encoder_/decoder_ prefixed arguments 153 | # arguments inside of kwargs_encoder/decoder are NOT prefixed 154 | kwargs_common = dict( 155 | (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") 156 | ) 157 | kwargs_encoder = kwargs_common.copy() 158 | kwargs_decoder = kwargs_common.copy() 159 | kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) 160 | kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) 161 | 162 | # combine batch_size and num_choices dimension in all inputs 163 | for kwargs in [kwargs_encoder, kwargs_decoder]: 164 | for k, v in kwargs.items(): 165 | if hasattr(v,'dim'): 166 | if v.dim() == 3: 167 | kwargs[k] = v.reshape(-1, v.size(-1)) 168 | elif v.dim() == 4: 169 | kwargs[k] = v.reshape(-1, v.size(-2), v.size(-1)) 170 | 171 | # Encode if needed (training, first prediction pass) 172 | encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) 173 | if encoder_hidden_states is None: 174 | # Convert encoder inputs in embeddings if needed 175 | hidden_states = kwargs_encoder.pop("inputs_embeds", None) 176 | if hidden_states is None: 177 | encoder_inputs_ids = kwargs_encoder.pop("input_ids") 178 | hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings 179 | 180 | encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) 181 | encoder_hidden_states = encoder_outputs[0] 182 | else: 183 | encoder_outputs = () 184 | 185 | # Decode 186 | # Convert decoder inputs in embeddings if needed 187 | hidden_states = kwargs_decoder.pop("inputs_embeds", None) 188 | if hidden_states is None: 189 | decoder_inputs_ids = kwargs_decoder.pop("input_ids") 190 | hidden_states = self.shared(decoder_inputs_ids) 191 | 192 | kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states 193 | kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) 194 | 195 | decoder_outputs = self.decoder(hidden_states, **kwargs_decoder) 196 | 197 | sequence_output = decoder_outputs[0] 198 | sequence_output = sequence_output.reshape(batch_size, num_choices, seq_len, -1) 199 | # Rescale output before projecting on vocab 200 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 201 | sequence_output = sequence_output * (self.model_dim ** -0.5) 202 | lm_logits = self.lm_head(sequence_output) 203 | 204 | decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here 205 | if lm_labels is not None: 206 | 207 | shift_logits = lm_logits[..., :-1, :].contiguous() 208 | shift_labels = lm_labels[..., 1:].contiguous() 209 | loss_fct = CrossEntropyLoss(ignore_index=-100, reduction = 'none') 210 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 211 | # reshape to batch_size x num choices x -1 then take mean over the token dim 212 | loss = loss.reshape(batch_size, num_choices, -1) 213 | loss = torch.mean(loss, dim=-1) 214 | 215 | decoder_outputs = ( 216 | loss, 217 | ) + decoder_outputs # TODO(thom): Add z_loss 
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 218 | 219 | return decoder_outputs + encoder_outputs 220 | 221 | 222 | def project_base_to_small(self, embeddings): 223 | ''' 224 | written for 2-agent experiments with task model as t5-base and simulator as t5-small. 225 | this will project a set of embeddings down to t5-small's embedding dim 226 | ''' 227 | return self.project_to_small(embeddings) -------------------------------------------------------------------------------- /sim_experiments/run_tasks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | # --- BEGIN QA --- # 5 | 6 | seed_variance_test = [21] 7 | 8 | def QA_task(args): 9 | for seed in seed_variance_test: 10 | os.system(f"python main.py --do_explain false " 11 | f"--model_name baseline " 12 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 20 --warmup_proportion .1 " 13 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 14 | ) 15 | 16 | def QA_SIM_human(args): 17 | for seed in seed_variance_test: 18 | os.system(f"python main.py --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use ground_truth " 19 | f"--model_name sim.human --explanation_dropout .5 " 20 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 30 --warmup_proportion .1 " 21 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 22 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 23 | ) 24 | 25 | def QA_CLM_reason(args): 26 | for seed in seed_variance_test: 27 | os.system(f"python main.py --do_task false --task_coef 0 --multi_explanation false --select_for bleu " 28 | f"--model_name CLM.reason --write_predictions --max_sample_len 20 " 29 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 5 --warmup_proportion .1 " 30 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 31 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 32 | ) 33 | 34 | def QA_CLM_rationalize(args): 35 | for seed in seed_variance_test: 36 | os.system(f"python main.py --do_task false --task_coef 0 --multi_explanation true --select_for bleu " 37 | f"--model_name CLM.rationalize --write_predictions --max_sample_len 20 " 38 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 5 --warmup_proportion .1 " 39 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 40 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 41 | ) 42 | 43 | def QA_CLM_reason_MT(args): 44 | for seed in seed_variance_test: 45 | os.system(f"python main.py --task_coef .5 --multi_explanation false " 46 | f"--model_name MT.RE --write_predictions --max_sample_len 20 " 47 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 20 --warmup_proportion .1 " 48 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 49 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 50 | ) 51 | 52 | def QA_SIM_CLM_reason_MT(args): 53 | for seed in seed_variance_test: 54 | os.system(f"python main.py --task_pretrained_name t5-base --do_explain false --multi_explanation false 
--condition_on_explanations true --explanations_to_use t5-MT-single-exp-seed{seed} --labels_to_use preds_QA_t5-base_MT.RE_seed{seed} " 55 | f"--model_name sim.MT.RE --explanation_dropout .5 " 56 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 30 --warmup_proportion .1 " 57 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 58 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 59 | ) 60 | 61 | def QA_CLM_rationalize_MT(args): 62 | for seed in seed_variance_test: 63 | os.system(f"python main.py --task_coef .5 --multi_explanation true " 64 | f"--model_name MT.RA --write_predictions --max_sample_len 20 " 65 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 20 --warmup_proportion .1 " 66 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 67 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 68 | ) 69 | 70 | def QA_SIM_CLM_rationalize_MT(args): 71 | for seed in seed_variance_test: 72 | os.system(f"python main.py --task_pretrained_name t5-base --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-MT-multi-exp-pred-seed{seed} --labels_to_use preds_QA_t5-base_MT.RA_seed{seed} " 73 | f"--model_name sim.MT.RA --explanation_dropout .5 " 74 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 30 --warmup_proportion .1 " 75 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 76 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 77 | ) 78 | 79 | def QA_ST_reason(args): 80 | for seed in seed_variance_test: 81 | os.system(f"python main.py --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-single-exp-seed{seed} " 82 | f"--model_name ST.RE --write_predictions " 83 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 20 --warmup_proportion .1 " 84 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 85 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 86 | ) 87 | 88 | def QA_SIM_ST_reason(args): 89 | for seed in seed_variance_test: 90 | os.system(f"python main.py --task_pretrained_name t5-base --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-single-exp-seed{seed} --labels_to_use preds_QA_t5-base_ST.RE_seed{seed} " 91 | f"--model_name sim.ST.RE --explanation_dropout .5 --print_examples " 92 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 30 --warmup_proportion .1 " 93 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 94 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 95 | ) 96 | 97 | def QA_ST_rationalize(args): 98 | for seed in seed_variance_test: 99 | os.system(f"python main.py --do_explain false --multi_explanation true --condition_on_explanations true --explanations_to_use t5-multi-exp-seed{seed} " 100 | f"--model_name ST.RA --write_predictions " 101 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 20 --warmup_proportion .1 " 102 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 103 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 104 | ) 105 | 106 | 
def QA_SIM_ST_rationalize(args): 107 | for seed in seed_variance_test: 108 | os.system(f"python main.py --task_pretrained_name t5-base --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-multi-exp-pred-seed{seed} --labels_to_use preds_QA_t5-base_ST.RA_seed{seed} " 109 | f"--model_name sim.ST.RA --explanation_dropout .5 --print_examples " 110 | f"--data_dir data/v1.0 --gpu {args.gpu} --seed {seed} --num_train_epochs 30 --warmup_proportion .1 " 111 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} {small_data_addin} " 112 | f"--save_dir {save_dir} --cache_dir {cache_dir} " 113 | ) 114 | 115 | # --- BEGIN NLI --- # 116 | 117 | def NLI_task(args): 118 | for seed in seed_variance_test: 119 | os.system(f"python main.py --do_explain false " 120 | f"--model_name baseline " 121 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 122 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 10 " 123 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 124 | ) 125 | 126 | def NLI_SIM_human(args): 127 | LR = 1e-4 if 't5' in args.model else 1e-5 128 | for seed in seed_variance_test: 129 | os.system(f"python main.py --do_explain false --task_pretrained_name {args.model} --multi_explanation false --condition_on_explanations true --explanations_to_use ground_truth " 130 | f"--model_name sim.human --input_dropout .2 --explanation_dropout .4 --lr {LR} " 131 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 132 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 15 " 133 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 134 | ) 135 | 136 | def NLI_CLM_reason(args): 137 | for seed in seed_variance_test: 138 | os.system(f"python main.py --do_task false --task_coef 0 --multi_explanation false --select_for bleu " 139 | f"--model_name CLM.reason --write_predictions " 140 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 141 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 5 " 142 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 143 | ) 144 | 145 | def NLI_CLM_rationalize(args): 146 | for seed in seed_variance_test: 147 | os.system(f"python main.py --do_task false --task_coef 0 --multi_explanation true --select_for bleu " 148 | f"--model_name CLM.rationalize --write_predictions " 149 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 150 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 5 " 151 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 152 | ) 153 | 154 | def NLI_CLM_reason_MT(args): 155 | for seed in seed_variance_test: 156 | os.system(f"python main.py --task_coef .5 --multi_explanation false --do_train false --do_eval false " 157 | f"--model_name MT.RE --write_predictions " 158 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 159 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 10 " 160 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 161 | ) 162 | 163 
| def NLI_SIM_CLM_reason_MT(args): 164 | for seed in seed_variance_test: 165 | os.system(f"python main.py --task_pretrained_name t5-base --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-MT-single-exp-seed{seed} --labels_to_use preds_NLI_t5-base_MT.RE_seed{seed} " 166 | f"--model_name sim.MT.RE --input_dropout .2 --explanation_dropout .4 " 167 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 168 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 15 " 169 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 170 | ) 171 | 172 | 173 | def NLI_CLM_rationalize_MT(args): 174 | for seed in seed_variance_test: 175 | os.system(f"python main.py --task_coef .5 --multi_explanation true --do_train false --do_eval false " 176 | f"--model_name MT.RA --write_predictions " 177 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 178 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 10 " 179 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 180 | ) 181 | 182 | def NLI_SIM_CLM_rationalize_MT(args): 183 | for seed in seed_variance_test: 184 | os.system(f"python main.py --task_pretrained_name t5-base --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-MT-multi-exp-pred-seed{seed} --labels_to_use preds_NLI_t5-base_MT.RA_seed{seed} " 185 | f"--model_name sim.MT.RA --input_dropout .2 --explanation_dropout .4 " 186 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 187 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 15 " 188 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 189 | ) 190 | 191 | def NLI_ST_reason(args): 192 | LR = 1e-4 if 't5' in args.model else 1e-5 193 | for seed in seed_variance_test: 194 | os.system(f"python main.py --task_pretrained_name {args.model} --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-single-exp-seed{seed} " 195 | f"--model_name ST.RE --write_predictions --lr {LR} " 196 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 197 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 10 " 198 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 199 | ) 200 | 201 | def NLI_SIM_ST_reason(args): 202 | LR = 1e-4 if 't5' in args.model else 1e-5 203 | for seed in seed_variance_test: 204 | os.system(f"python main.py --task_pretrained_name {args.model} --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-single-exp-seed{seed} --labels_to_use preds_NLI_{args.model}_ST.RE_seed{seed} " 205 | f"--model_name sim.ST.RE --input_dropout .2 --explanation_dropout .4 --lr {LR} " 206 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 207 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 15 " 208 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 209 | ) 210 | 211 | def NLI_ST_rationalize(args): 212 | LR = 1e-4 if 't5' in args.model else 1e-5 213 | for seed in 
seed_variance_test: 214 | os.system(f"python main.py --task_pretrained_name {args.model} --do_explain false --multi_explanation true --condition_on_explanations true --explanations_to_use t5-multi-exp-seed{seed} " 215 | f"--model_name ST.RA --write_predictions --lr {LR} " 216 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 217 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 10 " 218 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 219 | ) 220 | 221 | def NLI_SIM_ST_rationalize(args): 222 | LR = 1e-4 if 't5' in args.model else 1e-5 223 | for seed in seed_variance_test: 224 | os.system(f"python main.py --task_pretrained_name {args.model} --do_explain false --multi_explanation false --condition_on_explanations true --explanations_to_use t5-multi-exp-exp-seed{seed} --labels_to_use preds_NLI_{args.model}_ST.RA_seed{seed} " 225 | f"--model_name sim.ST.RA --input_dropout .2 --explanation_dropout .4 --lr {LR} " 226 | f"--data_dir data/e-SNLI-data --gpu {args.gpu} --seed {seed} " 227 | f"--train_batch_size {args.train_batch_size} --grad_accumulation_factor {args.grad_accumulation_factor} --warmup_proportion .01 --num_train_epochs 15 " 228 | f"--save_dir {save_dir} --cache_dir {cache_dir} {small_data_addin}" 229 | ) 230 | 231 | if __name__ == '__main__': 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument("--gpu", default=0, type=int, help='') 234 | parser.add_argument("--experiment", '-e', type=str, help='') 235 | parser.add_argument("--server_number", '-s', required=True, type=str, help='') 236 | parser.add_argument("--model", default='t5-base', type=str, help='HuggingFace transformer model') 237 | parser.add_argument("--train_batch_size", '-b', default=3, type=int, help="ONLY FOR QA. Total batch size for training. Effective batch size is this times grad_accumulation_factor") 238 | parser.add_argument('--grad_accumulation_factor', '-g', type=int, default=4, help="ONLY FOR QA. 
Number of updates steps to accumulate before performing a backward pass and step.") 239 | parser.add_argument('--small_data', action='store_true', help='Flag for using just a few datapoints for debugging purposes') 240 | parser.add_argument("--save_dir", default='', required=True, type=str, 241 | help="The output directory where the model checkpoints will be written.") 242 | parser.add_argument("--cache_dir", default='', required=True, type=str, 243 | help="Directory for cacheing pretrained models.") 244 | args = parser.parse_args() 245 | save_dir = args.save_dir 246 | cache_dir = args.cache_dir 247 | 248 | if args.small_data: 249 | small_data_addin = '-s -ss 64 ' # uses 64 points per split in main.py 250 | else: 251 | small_data_addin = '' 252 | 253 | print("Starting experiment %s " % args.experiment) 254 | print("Using seeds ", seed_variance_test) 255 | print("Saving models in %s" % save_dir) 256 | 257 | # --- begin QA --- # 258 | 259 | if args.experiment == 'QA.task': 260 | QA_task(args) 261 | 262 | if args.experiment == 'QA.SIM.human': 263 | QA_SIM_human(args) 264 | 265 | if args.experiment == 'QA.CLM.reason': 266 | QA_CLM_reason(args) 267 | 268 | if args.experiment == 'QA.CLM.rationalize': 269 | QA_CLM_rationalize(args) 270 | 271 | if args.experiment == 'QA.CLM.reason.MT': 272 | QA_CLM_reason_MT(args) 273 | 274 | if args.experiment == 'QA.SIM.MT.RE': 275 | QA_SIM_CLM_reason_MT(args) 276 | 277 | if args.experiment == 'QA.CLM.rationalize.MT': 278 | QA_CLM_rationalize_MT(args) 279 | 280 | if args.experiment == 'QA.SIM.MT.RA': 281 | QA_SIM_CLM_rationalize_MT(args) 282 | 283 | if args.experiment == 'QA.ST.RE': 284 | QA_ST_reason(args) 285 | 286 | if args.experiment == 'QA.SIM.ST.RE': 287 | QA_SIM_ST_reason(args) 288 | 289 | if args.experiment == 'QA.ST.RA': 290 | QA_ST_rationalize(args) 291 | 292 | if args.experiment == 'QA.SIM.ST.RA': 293 | QA_SIM_ST_rationalize(args) 294 | 295 | 296 | # --- begin NLI --- # 297 | 298 | if args.experiment == 'NLI.task': 299 | NLI_task(args) 300 | 301 | if args.experiment == 'NLI.SIM.human': 302 | NLI_SIM_human(args) 303 | 304 | if args.experiment == 'NLI.CLM.reason': 305 | NLI_CLM_reason(args) 306 | 307 | if args.experiment == 'NLI.CLM.rationalize': 308 | NLI_CLM_rationalize(args) 309 | 310 | if args.experiment == 'NLI.CLM.reason.MT': 311 | NLI_CLM_reason_MT(args) 312 | 313 | if args.experiment == 'NLI.SIM.MT.RE': 314 | NLI_SIM_CLM_reason_MT(args) 315 | 316 | if args.experiment == 'NLI.CLM.rationalize.MT': 317 | NLI_CLM_rationalize_MT(args) 318 | 319 | if args.experiment == 'NLI.SIM.MT.RA': 320 | NLI_SIM_CLM_rationalize_MT(args) 321 | 322 | if args.experiment == 'NLI.ST.RE': 323 | NLI_ST_reason(args) 324 | 325 | if args.experiment == 'NLI.ST.RA': 326 | NLI_ST_rationalize(args) 327 | 328 | if args.experiment == 'NLI.SIM.ST.RE': 329 | NLI_SIM_ST_reason(args) 330 | 331 | if args.experiment == 'NLI.SIM.ST.RA': 332 | NLI_SIM_ST_rationalize(args) 333 | -------------------------------------------------------------------------------- /sim_experiments/training_reports/README.md: -------------------------------------------------------------------------------- 1 | Model training reports will be saved here (see Report class in ../classes.py). 
-------------------------------------------------------------------------------- /sim_experiments/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sacrebleu import corpus_bleu 3 | import csv 4 | import argparse 5 | import logging 6 | import json 7 | import time 8 | import torch 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import pandas as pd 12 | 13 | ### METRICS ### 14 | 15 | def computeBLEU(outputs, targets): 16 | # see https://github.com/mjpost/sacreBLEU 17 | targets = [[t[i] for t in targets] for i in range(len(targets[0]))] 18 | return corpus_bleu(outputs, targets, lowercase=True).score 19 | 20 | def CE_Loss(probabilities, labels): 21 | pred_probs = probabilities.gather(-1, labels.unsqueeze(-1)) 22 | return torch.mean(-torch.log(pred_probs)) 23 | 24 | ### END METRICS ### 25 | 26 | 27 | ### SAMPLING FUNCTIONS ### 28 | 29 | def T5_sample(model, encoder_hidden_states, decoder_input_ids, encoder_attention_mask, tokenizer, max_sample_len): 30 | ''' 31 | Uses model to sample based on context_ids, until max_sample_len is hit, with the expectation that decoding will stop at a specified [end] token 32 | This function is batched, meaning predictions are placed at the end of each running sequence within a tensor of shape (batch_size x num_choices x max_seq_len) 33 | Before returning samples, the original contexts in running_contexts are set to the pad_token_id 34 | ''' 35 | batch_size = decoder_input_ids.size(0) 36 | vocab_size = len(tokenizer) # NOT tokenizer.vocab_size, this attr does not update when tokens are added 37 | pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 38 | eos_token_id = tokenizer.eos_token_id 39 | running_contexts = decoder_input_ids.clone() 40 | seq_len = decoder_input_ids.size(-1) 41 | device = decoder_input_ids.device 42 | start = time.time() 43 | context_len = (running_contexts!=pad_token_id).sum(dim=-1).max().item() 44 | if eos_token_id is not None: where_eos_sampled = [False] * batch_size 45 | 46 | # pad the input contexts up to max_sample_len 47 | if running_contexts.size(-1) < (max_sample_len+context_len): 48 | extend_by = (max_sample_len+context_len) - running_contexts.size(-1) 49 | extension_shape = [batch_size, extend_by] 50 | padding = pad_token_id * torch.ones(extension_shape, dtype = torch.long) 51 | padding = padding.to(running_contexts.device) 52 | running_contexts = torch.cat((running_contexts,padding),dim=-1) 53 | seq_len = max_sample_len+context_len 54 | 55 | where_last_tokens = (running_contexts!=pad_token_id).sum(-1) - 1 56 | 57 | # BEGIN SAMPLING 58 | for k in range(max_sample_len): 59 | 60 | attention_mask = (running_contexts!=pad_token_id).float() 61 | 62 | # hold onto the starting point of sampling for each context 63 | if k==0: init_where_last_tokens = where_last_tokens 64 | 65 | with torch.no_grad(): 66 | outputs = model(encoder_hidden_states = encoder_hidden_states, 67 | encoder_attention_mask = encoder_attention_mask, 68 | decoder_input_ids = running_contexts, 69 | decoder_attention_mask = attention_mask) 70 | logits = outputs[0] 71 | 72 | # get logits corresponding to last tokens in each sequence 73 | logits = torch.stack([logits[i, last_idx, :] for i, last_idx in enumerate(where_last_tokens)]) 74 | preds = torch.argmax(logits, dim = -1) 75 | 76 | # assign preds to the first pad location in each running_contexts[i,j,:] sequence. 
check if eos_token sampled in each sequence 77 | for i in range(batch_size): 78 | last_token_index = where_last_tokens[i] 79 | running_contexts[i,last_token_index+1] = preds[i].item() 80 | if eos_token_id is not None: 81 | if preds[i].item() == eos_token_id: where_eos_sampled[i] = True 82 | 83 | # if eos tokens sampled in every sequence, quit sampling 84 | if all(where_eos_sampled): 85 | break 86 | 87 | # iterate where_last_tokens 88 | where_last_tokens = where_last_tokens + 1 89 | 90 | # lastly, set the context portion of each sample to the pad_token_id 91 | samples = running_contexts 92 | for i in range(batch_size): 93 | end_of_context_index = init_where_last_tokens[i] 94 | samples[i,:(end_of_context_index+1)] = pad_token_id 95 | 96 | # print("sample time per input: %.2f" % ((time.time()-start)/batch_size)) 97 | del outputs, logits 98 | 99 | return samples 100 | 101 | def get_differentiable_explanations(speaker_model, listener_model, context_ids, tokenizer, max_sample_len, method = 'differentiable_argmax', eos_token_id = None, 102 | input_ids = None, input_masks = None, encoder_hidden_states = None, listener_context_ids = None): 103 | ''' 104 | - Differentiable decoding based on context_ids as the beginning of the output sequence. Context_ids of shape: batch_size x max_seq_len 105 | - Returns indices of the last 'valid' sample in each sequence (i.e. one right before first eos-token or pad-token) 106 | ''' 107 | assert context_ids.dim() == 2, "Should be sampling one sequence per data point" 108 | # get accessible models in multi-gpu case 109 | if hasattr(speaker_model, 'module'): 110 | _speaker_model = speaker_model.module 111 | else: 112 | _speaker_model = speaker_model 113 | if hasattr(listener_model, 'module'): 114 | _listener_model = listener_model.module 115 | else: 116 | _listener_model = listener_model 117 | batch_size = context_ids.size(0) 118 | seq_len = context_ids.size(1) 119 | vocab_size = _speaker_model.lm_head.out_features # NOT tokenizer.vocab_size, that would lead to an error 120 | pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 121 | running_contexts = context_ids.clone() 122 | device = context_ids.device 123 | start = time.time() 124 | softmax = torch.nn.Softmax(dim=-1) 125 | context_len = (running_contexts!=pad_token_id).sum(dim=-1).max().item() 126 | max_explanation_len = max_sample_len + context_len 127 | pad_token_embed = _speaker_model.shared(torch.tensor([pad_token_id]).to(context_ids.device)) 128 | pad_token_embed = pad_token_embed.detach() 129 | if eos_token_id is not None: where_eos_sampled = [False] * batch_size 130 | 131 | # pad running_contexts up to max_sample_len 132 | if running_contexts.size(-1) < (max_explanation_len): 133 | expand_by = (max_explanation_len) - running_contexts.size(-1) 134 | padding = torch.tensor(pad_token_id).expand(batch_size,expand_by) 135 | padding = padding.to(running_contexts.device) 136 | return_running_contexts = torch.cat((running_contexts,padding),dim=-1) 137 | seq_len = return_running_contexts.size(-1) 138 | else: 139 | return_running_contexts = running_contexts.clone() 140 | 141 | if encoder_hidden_states is None: 142 | outputs = speaker_model(input_ids = input_ids, 143 | attention_mask = input_masks) 144 | encoder_hidden_states = outputs[1] 145 | 146 | if listener_context_ids is not None: 147 | listener_embeds = _listener_model.shared(running_contexts.clone()) 148 | else: 149 | # make context tensor 150 | listener_context = "My commonsense tells me that" 151 | listener_context_ids = 
torch.tensor(tokenizer.encode(listener_context), dtype = torch.long).to(running_contexts.device) 152 | listener_context_ids = listener_context_ids.unsqueeze(0).expand(batch_size,len(listener_context_ids)) 153 | # pad context tensor up to max len 154 | max_explanation_len = listener_context_ids.size(-1) + max_sample_len 155 | expand_by = (max_explanation_len) - listener_context_ids.size(-1) 156 | padding = torch.tensor(pad_token_id).expand(batch_size,expand_by) 157 | padding = padding.to(running_contexts.device) 158 | listener_context_ids = torch.cat((listener_context_ids, padding),dim=-1) 159 | # look up embeddings 160 | listener_embeds = _listener_model.shared(listener_context_ids.clone()) 161 | 162 | 163 | # BEGIN SAMPLING 164 | for k in range(max_sample_len): 165 | 166 | # get locations of last non-pad tokens in each sequence for purposes of: getting predictions from logits, and updating running_contexts 167 | speaker_where_last_tokens = [] 168 | listener_where_last_tokens = [] 169 | for sequence in return_running_contexts.tolist(): 170 | if pad_token_id in sequence: 171 | speaker_where_last_tokens.append(sequence.index(pad_token_id)-1) 172 | else: 173 | speaker_where_last_tokens.append(return_running_contexts.size(-1)-1) 174 | for sequence in listener_context_ids.tolist(): 175 | if pad_token_id in sequence: 176 | listener_where_last_tokens.append(sequence.index(pad_token_id)-1) 177 | else: 178 | listener_where_last_tokens.append(listener_context_ids.size(-1)-1) 179 | 180 | # make logits mask 181 | logits_mask = torch.zeros(batch_size, seq_len, vocab_size) 182 | logits_mask = logits_mask.to(device) 183 | for i in range(batch_size): 184 | last_token_index = speaker_where_last_tokens[i] 185 | logits_mask[i,last_token_index,:] = 1 186 | 187 | # hold onto the starting point of sampling for each context 188 | if k == 0: 189 | return_sample_embeds = _speaker_model.shared(return_running_contexts.clone()) 190 | running_decoder_input_embeds = return_sample_embeds.clone() 191 | embed_dim = running_decoder_input_embeds.size(-1) 192 | 193 | # forward pass 194 | outputs = speaker_model(encoder_hidden_states = encoder_hidden_states, 195 | encoder_attention_mask = input_masks, 196 | decoder_inputs_embeds = running_decoder_input_embeds) 197 | logits = outputs[0] 198 | 199 | # get logits corresponding to last tokens in each sequence, then get preds 200 | logits = logits.view(batch_size, seq_len, vocab_size) 201 | logits = logits * logits_mask 202 | logits = torch.sum(logits, dim = 1) 203 | preds = torch.argmax(logits, dim = -1) 204 | 205 | # get the predicted token's embeddings for both the speaker and the listener 206 | if method == 'differentiable_argmax': 207 | preds_onehot = differentiable_argmax(logits, temperature = 1) 208 | pred_speaker_embeds = torch.mm(preds_onehot, _speaker_model.shared.weight) # these get passed in as decoder_inputs_embeds at next step 209 | pred_listener_embeds = torch.mm(preds_onehot, _listener_model.shared.weight) # these will get returned to be passed to the listening simulator model 210 | 211 | if method == 'averaged_embeddings': 212 | # expected input embedding under the predicted token distribution 213 | probs = softmax(logits) 214 | 215 | # soft analogue of the one-hot products above, weighted by probs 216 | pred_speaker_embeds = torch.mm(probs, _speaker_model.shared.weight) # these get passed in as decoder_inputs_embeds at next step 217 | pred_listener_embeds = torch.mm(probs, _listener_model.shared.weight) # these will get returned to be passed to the listening simulator model 218 | 219 | # assign preds to the first pad location in each running_contexts[i,j,:] sequence, and decoder_hidden_states to the running_decoder_input_embeds 220 | for i in 
range(batch_size): 221 | speaker_last_token_index = speaker_where_last_tokens[i] 222 | listener_last_token_index = listener_where_last_tokens[i] 223 | return_running_contexts[i,speaker_last_token_index+1] = preds[i].item() 224 | return_sample_embeds[i,speaker_last_token_index+1,:] = pred_speaker_embeds[i,:] 225 | listener_context_ids[i,listener_last_token_index+1] = preds[i].item() 226 | listener_embeds[i,listener_last_token_index+1,:] = pred_listener_embeds[i,:] 227 | if eos_token_id is not None: 228 | if preds[i].item() == eos_token_id: 229 | where_eos_sampled[i] = True 230 | 231 | if eos_token_id is not None: 232 | if all(where_eos_sampled): break 233 | 234 | # reassign decoder_input_embeds 235 | running_decoder_input_embeds = return_sample_embeds # .clone() appears inconsequential here 236 | 237 | # now we return a list of embeddings and token_ids. 238 | return_sample_embeds_list = [] 239 | return_messages_list = [] 240 | 241 | # for any samples in a sequence after the first eos-token or pad-token, record only up to the eos. keep track of explanation_lens 242 | context_ids_list = listener_context_ids.tolist() 243 | explanation_lens = [] 244 | for i in range(batch_size): 245 | if eos_token_id in context_ids_list[i]: 246 | begin_id = context_ids_list[i].index(eos_token_id) 247 | elif pad_token_id in context_ids_list[i]: 248 | begin_id = context_ids_list[i].index(pad_token_id) 249 | else: 250 | begin_id = None 251 | # keep up to eos 252 | if begin_id is not None: 253 | return_sample_embeds_list.append(listener_embeds[i,:begin_id,:]) 254 | return_messages_list.append(listener_context_ids[i,:begin_id]) 255 | explanation_lens.append(begin_id) 256 | # no eos, keep whole sequence 257 | else: 258 | return_sample_embeds_list.append(listener_embeds[i,:,:]) 259 | return_messages_list.append(listener_context_ids[i,:]) 260 | explanation_lens.append(max_explanation_len) 261 | 262 | del outputs, logits, encoder_hidden_states, running_decoder_input_embeds 263 | if method == 'differentiable_argmax': del preds_onehot 264 | if method == 'averaged_embeddings': del probs 265 | 266 | return return_sample_embeds_list, return_messages_list, explanation_lens 267 | 268 | 269 | ### END SAMPLING FUNCTIONS ### 270 | 271 | 272 | ### TOKENIZATION FUNCTIONS ### 273 | 274 | def trim_unks(x): 275 | try: 276 | unk_id = x.index('_end_') 277 | return x[:unk_id] 278 | except: 279 | return x 280 | 281 | def detok_batch(tokenizer, x, ignore_tokens = None, eos_token = None): 282 | ''' 283 | - convert x into strings using tokenizer 284 | - x is either tensor of dim 2 or dim 3 or a .tolist() of such a tensor 285 | - stop decoding if eos_token hit, if eos_token provided 286 | - skip over tokens in ignore_tokens 287 | ''' 288 | 289 | if ignore_tokens is not None: 290 | ignore_tokens_idx = tokenizer.convert_tokens_to_ids(ignore_tokens) 291 | ignore_tokens_idx += [-100,-1] 292 | else: 293 | ignore_tokens = [] 294 | ignore_tokens_idx = [-100,-1] 295 | 296 | # always ignore id 0 (the pad token id for T5 tokenizers) 297 | ignore_tokens_idx += [0] 298 | if not isinstance(x, list): 299 | x = x.tolist() 300 | dim = 3 if isinstance(x[0][0], list) else 2 301 | eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) 302 | texts = [] 303 | 304 | for i in range(len(x)): 305 | if dim == 2: 306 | current_idx = [] 307 | for j in range(len(x[i])): 308 | current_id = x[i][j] 309 | if current_id == eos_token_id: 310 | break 311 | elif current_id not in ignore_tokens_idx: 312 | current_idx.append(current_id) 313 | 
decoded_sequence = tokenizer.decode(current_idx) 314 | # some ids decode to ignore_tokens anyway (e.g., many token_ids map to [UNK], which itself has id=100), so strip any that slipped through 315 | if any([ignore_token in decoded_sequence for ignore_token in ignore_tokens]): 316 | decoded_sequence = ' '.join([token for token in decoded_sequence.split() if token not in ignore_tokens]) 317 | # APPEND 318 | texts.append(decoded_sequence) 319 | elif dim == 3: 320 | decoded_sequences = [] 321 | for j in range(len(x[i])): 322 | current_idx = [] 323 | for k in range(len(x[i][j])): 324 | current_id = x[i][j][k] 325 | if current_id == eos_token_id: 326 | break 327 | elif current_id not in ignore_tokens_idx: 328 | current_idx.append(current_id) 329 | 330 | decoded_sequence = tokenizer.decode(current_idx) 331 | 332 | # some ids decode to ignore_tokens anyway (e.g., many token_ids map to [UNK], which itself has id=100), so strip any that slipped through 333 | if any([ignore_token in decoded_sequence for ignore_token in ignore_tokens]): 334 | decoded_sequence = ' '.join([token for token in decoded_sequence.split() if token not in ignore_tokens]) 335 | # APPEND single decoding 336 | decoded_sequences.append(decoded_sequence) 337 | 338 | # APPEND list of n decodings 339 | texts.append(decoded_sequences) 340 | 341 | return texts 342 | 343 | 344 | ### END TOKENIZATION FUNCTIONS ### 345 | 346 | 347 | ### MISC ### 348 | 349 | def str2bool(v): 350 | # used for boolean argparse values 351 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 352 | return True 353 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 354 | return False 355 | else: 356 | raise argparse.ArgumentTypeError('Boolean value expected.') 357 | 358 | def isNaN(x): 359 | return (x!=x) 360 | 361 | def gumbel_softmax(logits, temperature): 362 | """ 363 | based on implementation here: https://github.com/dev4488/VAE_gumble_softmax/blob/master/vae_gumbel_softmax.py 364 | the point is that the derivative of the output is taken w.r.t. y_soft, which is a differentiable function of the logits 365 | """ 366 | logits_shape = list(logits.shape) 367 | # draw a soft sample from the Gumbel-Softmax distribution (no hard sampling here) 368 | y_soft = F.gumbel_softmax(logits, tau=temperature, hard=False) 369 | shape = y_soft.size() 370 | ind = y_soft.argmax(dim=-1) 371 | y_hard = torch.zeros_like(y_soft).view(-1, shape[-1]) 372 | y_hard.scatter_(1, ind.view(-1, 1), 1) 373 | y_hard = y_hard.view(*y_soft.shape) 374 | # straight-through estimator: one-hot on the forward pass, 375 | # gradients flow through y_soft on the backward pass 376 | y_hard = (y_hard - y_soft).detach() + y_soft 377 | return y_hard 378 | 379 | def differentiable_argmax(logits, temperature): 380 | """ 381 | take argmax on forward pass; use softmax for backward pass 382 | """ 383 | logits_shape = list(logits.shape) 384 | y_soft = F.softmax(logits / temperature, dim=-1) 385 | shape = y_soft.size() 386 | ind = y_soft.argmax(dim=-1) 387 | y_hard = torch.zeros_like(y_soft).view(-1, shape[-1]) 388 | y_hard.scatter_(1, ind.view(-1, 1), 1) 389 | y_hard = y_hard.view(*y_soft.shape) 390 | y_hard = (y_hard - y_soft).detach() + y_soft 391 | return y_hard 392 | 393 | def removeNonAscii(s): 394 | if isinstance(s, str): 395 | return "".join(i for i in s if ord(i)<128) 396 | else: 397 | return s 398 | 399 | def bootstrap_diff_in_means(means1, means2, boottimes=1e5): # not implemented in this release; see the hedged sketch after this file 400 | return 401 | 402 | ### END MISC ### --------------------------------------------------------------------------------
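`computeBLEU` above takes `outputs` as a flat list of hypothesis strings and `targets` as one list of reference strings per example, then transposes the latter into the list-of-reference-streams layout that sacreBLEU expects. A minimal usage sketch with made-up sentences, assuming `sacrebleu` is installed:

```
from sacrebleu import corpus_bleu

outputs = ["a man is sleeping", "two dogs play outside"]        # one hypothesis per example
targets = [["a man sleeps"], ["two dogs are playing outside"]]  # one list of references per example

# the same transposition computeBLEU performs internally:
ref_streams = [[t[i] for t in targets] for i in range(len(targets[0]))]
print(corpus_bleu(outputs, ref_streams, lowercase=True).score)
```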
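The straight-through trick used by `differentiable_argmax` (and by the repaired `gumbel_softmax`) is easy to misread. The stand-alone sketch below is not repository code; it uses hypothetical toy shapes only to show that the forward pass yields an exact one-hot vector while gradients still reach the logits through the softmax term, mirroring how `preds_onehot` is multiplied into `model.shared.weight` above:

```
import torch
import torch.nn.functional as F

def straight_through_argmax(logits, temperature=1.0):
    # softmax provides the backward path; scatter builds the one-hot forward value
    y_soft = F.softmax(logits / temperature, dim=-1)
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return (y_hard - y_soft).detach() + y_soft

vocab_size, embed_dim = 5, 3                     # toy sizes, for illustration only
embedding = torch.randn(vocab_size, embed_dim)   # stands in for a model's shared embedding matrix
logits = torch.randn(2, vocab_size, requires_grad=True)

one_hot = straight_through_argmax(logits)
pred_embeds = torch.mm(one_hot, embedding)       # analogous to the pred_*_embeds products above
pred_embeds.sum().backward()

print(one_hot)      # exact one-hot rows on the forward pass
print(logits.grad)  # non-zero: gradients flow via the softmax term
```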
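`bootstrap_diff_in_means` is left as an empty stub in this file. If a reader needs the quantity its name suggests, a standard percentile bootstrap over the difference in means might look like the sketch below; this is an assumption on my part, not the repository's implementation, and the function name, seed, and confidence level here are illustrative choices:

```
import numpy as np

def bootstrap_diff_in_means_sketch(means1, means2, boottimes=int(1e5), seed=0):
    # hypothetical percentile bootstrap for mean(means1) - mean(means2); not repository code
    rng = np.random.RandomState(seed)
    means1 = np.asarray(means1, dtype=float)
    means2 = np.asarray(means2, dtype=float)
    diffs = np.empty(int(boottimes))
    for b in range(int(boottimes)):
        s1 = rng.choice(means1, size=len(means1), replace=True)
        s2 = rng.choice(means2, size=len(means2), replace=True)
        diffs[b] = s1.mean() - s2.mean()
    point_estimate = means1.mean() - means2.mean()
    lower, upper = np.percentile(diffs, [2.5, 97.5])  # 95% percentile interval
    return point_estimate, (lower, upper)
```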