├── scripts ├── run_contrastive_decoding.sh └── run_counterfactual_training.sh ├── utils.py ├── prompts ├── creak.explanation.txt ├── strategyQA.explanation.txt ├── csqa.explanation.txt └── qasc.explanation.txt ├── README.md ├── generate_utils.py ├── contrastive_decoding_rationalization.py ├── data_helper.py └── main.py /scripts/run_contrastive_decoding.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dataset="strategyQA" 4 | prompt="explanation" 5 | 6 | output_prefix="outputs/${dataset}/" 7 | mkdir -p $output_prefix 8 | 9 | python -u \ 10 | contrastive_decoding_rationalization.py \ 11 | --output_prefix $output_prefix \ 12 | --dataset $dataset \ 13 | --prompt $prompt \ 14 | >${output_prefix}/rationalization.log 2>&1 & 15 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def get_logger(name, log_path=None): 4 | 5 | logger = logging.getLogger(name) 6 | logger.setLevel(logging.DEBUG) 7 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S') 8 | 9 | if log_path: 10 | handler = logging.FileHandler(log_path, 'w') 11 | handler.setLevel(logging.INFO) 12 | handler.setFormatter(formatter) 13 | logger.addHandler(handler) 14 | 15 | return logger 16 | 17 | -------------------------------------------------------------------------------- /prompts/creak.explanation.txt: -------------------------------------------------------------------------------- 1 | {"prompt": "Q: Only people named Floyd wearing pink are allowed to attend Pink Floyd concerts.\nA: The statement is false. The rock group would not be as popular if they had such requirements for their concerts.\n\nQ: Marlboro used iconic imagery to promote its brand.\nA: The statement is true. Marlboro (cigarette) used cowboys as an advertising campaign.\n\nQ: Fax works without any internet connection.\nA: The statement is false. Internet connection is required for a fax to function well.\n\nQ: Larry King served tea during his show.\nA: The statement is false. He had a set format that did not involve tea.\n\nQ: The crack in the Liberty Bell sets it apart from other famous bells.\nA: The statement is true. The Liberty Bell is famous for having a large crack in its side.\n\nQ: A Jury may decide whether a criminal is guilty or innocent.\nA: The statement is true. Juries are used in the American legal system to decide the outcomes of criminal trials.\n\nQ: {}\nA: The statement is {}."} 2 | -------------------------------------------------------------------------------- /prompts/strategyQA.explanation.txt: -------------------------------------------------------------------------------- 1 | {"prompt": "Q: Do hamsters provide food for any animals?\nA: The answer is yes. Hamsters are prey animals. Prey animals provide food for predators.\n\nQ: Could Brooke Shields succeed at University of Pennsylvania?\nA: The answer is yes. Brooke Shields went to Princeton University. Princeton University is about as academically rigorous as the University of Pennsylvania.\n\nQ: Hydrogen's atomic number squared exceeds number of Spice Girls?\nA: The answer is no. Hydrogen has an atomic number of 1. 1 squared is 1. There are 5 Spice Girls.\n\nQ: Is it common to see frost during some college commencements?\nA: The answer is yes. College commencement ceremonies can happen in December, May, and June. 
December is in the winter, so there can be frost.\n\nQ: Could a llama birth twice during War in Vietnam (1945-46)?\nA: The answer is no. The War in Vietnam was 6 months. The gestation period for a llama is 11 months, which is more than 6 months.\n\nQ: Would a pear sink in water?\nA: The answer is no. The density of a pear is about 0.6 g/cm^3, which is less than water. Objects less dense than water float.\n\nQ: {}\nA: The answer is {}."} 2 | -------------------------------------------------------------------------------- /scripts/run_counterfactual_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dataset="strategyQA" 4 | model_name="t5-3b" 5 | max_enc_length=128 6 | max_dec_length=128 7 | train_batch_size=16 8 | eval_batch_size=32 9 | grad_step=1 10 | learning_rate=3e-5 11 | warmup_ratio=0.06 12 | weight_decay=0 13 | num_epoch=10 14 | num_epoch_early_stopping=1 15 | 16 | counterfactual_alpha=0.5 17 | 18 | save_dir="checkpoints/${dataset}/counterfactual${counterfactual_alpha}_${model_name}_bs${train_batch_size}_gs${grad_step}_lr${learning_rate}_wd${weight_decay}_e${num_epoch}" 19 | mkdir -p $save_dir 20 | 21 | python -u \ 22 | main.py \ 23 | --num_epoch_early_stopping $num_epoch_early_stopping \ 24 | --add_task_prefix \ 25 | --counterfactual_alpha $counterfactual_alpha \ 26 | --dataset $dataset \ 27 | --save_dir $save_dir \ 28 | --model_name $model_name \ 29 | --max_enc_length $max_enc_length \ 30 | --max_dec_length $max_dec_length \ 31 | --train_batch_size $train_batch_size \ 32 | --eval_batch_size $eval_batch_size \ 33 | --grad_step $grad_step \ 34 | --learning_rate $learning_rate \ 35 | --warmup_ratio $warmup_ratio \ 36 | --weight_decay $weight_decay \ 37 | --num_epoch $num_epoch \ 38 | > ${save_dir}/debug.log 2>&1 & 39 | 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SCOTT 2 | 3 | This is a Pytorch implementation for our ACL 2023 outstanding paper "SCOTT: Self-Consistent Chain-of-Thought Distillation" [[arxiv](https://arxiv.org/abs/2305.01879)]. 4 | 5 | ## Contrastive Decoding 6 | 7 | ### 1. Prepare the data 8 | 9 | Download the dataset: [StrategyQA](https://allenai.org/data/strategyqa)/[CommonsenseQA](https://www.tau-nlp.sites.tau.ac.il/commonsenseqa)/[CREAK](https://github.com/yasumasaonoe/creak)/[QASC](https://allenai.org/data/qasc). 10 | 11 | Split the dataset into `train/dev/test.jsonl` subsets. Also build a `train.counterfactual.jsonl` subset by perturbing the answers in the `train.jsonl` subset. Organize all the subsets in the folder `data/DATASET`. 12 | 13 | ### 2. Obtain the rationales 14 | 15 | ```bash 16 | ./scripts/run_contrastive_decoding.sh 17 | ``` 18 | The generated rationales would be stored at `outputs/DATASET`. 19 | 20 | ## Counterfactual Training 21 | 22 | ```bash 23 | ./scripts/run_counterfactual_training.sh 24 | ``` 25 | After training, the checkpoints and the evaluation result will be stored at `checkpoints/DATASET`. 
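The `evaluation_results.json` written there aggregates test accuracy over the five training seeds (mean and standard deviation), covering standard inference plus the oracle- and perturbed-rationale settings unless `--without_explanation` is set. Below is a minimal sketch for inspecting that file, assuming the default hyperparameters from `run_counterfactual_training.sh`; the directory name is only an example and should be adjusted to your own run.

```python
import json

# Example checkpoint directory produced by the default settings in run_counterfactual_training.sh.
path = "checkpoints/strategyQA/counterfactual0.5_t5-3b_bs16_gs1_lr3e-5_wd0_e10/evaluation_results.json"

with open(path) as f:
    results = json.load(f)

# Keys: accuracy_inference, plus accuracy_oracle and accuracy_perturb when rationales are used.
for metric, stats in results.items():
    print(f"{metric}: {stats['accuracy_mean']:.2f} +/- {stats['accuracy_std']:.2f}")
```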
26 | 27 | ## Citation 28 | 29 | ``` 30 | @inproceedings{wang-etal-2023-scott, 31 | title = "{SCOTT}: Self-Consistent Chain-of-Thought Distillation", 32 | author = "Wang, Peifeng and 33 | Wang, Zhengyang and 34 | Li, Zheng and 35 | Gao, Yifan and 36 | Yin, Bing and 37 | Ren, Xiang", 38 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 39 | month = jul, 40 | year = "2023", 41 | address = "Toronto, Canada", 42 | publisher = "Association for Computational Linguistics", 43 | url = "https://aclanthology.org/2023.acl-long.304", 44 | doi = "10.18653/v1/2023.acl-long.304", 45 | pages = "5546--5558", 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /prompts/csqa.explanation.txt: -------------------------------------------------------------------------------- 1 | {"prompt": "Q: What do people use to absorb extra ink from a fountain pen?\nAnswer Choices:\n(a) shirt pocket\n(b) calligrapher's hand\n(c) inkwell\n(d) desk drawer\n(e) blotter\nA: The answer is blotter. Blotting paper absorbs liquids like ink well.\n\nQ: What home entertainment equipment requires cable?\nAnswer Choices:\n(a) radio shack\n(b) substation\n(c) cabinet\n(d) television\n(e) desk\nA: The answer is television. Cable can be fixed to a television. Television is a home entertainment equipment.\n\nQ: The fox walked from the city into the forest, what was it looking for?\nAnswer Choices:\n(a) pretty flowers\n(b) hen house\n(c) natural habitat\n(d) storybook\n(e) dense forest\nA: The answer is natural habitat. Forests are one of the main natural habitats of foxes.\n\nQ: Sammy wanted to go to where the people were. Where might he go?\nAnswer Choices:\n(a) populated areas\n(b) race track\n(c) desert\n(d) apartment\n(e) roadblock\nA: The answer is populated areas. Populated areas are where there are a lot of people.\n\nQ: Where do you put your grapes just before checking out?\nAnswer Choices:\n(a) mouth\n(b) grocery cart\n(c) super market\n(d) fruit basket\n(e) fruit market\nA: The answer is grocery cart. Grocery cart is used in stores by customers to collect purchases. Checking out of purchases is done in checkout area of stores.\n\nQ: Google Maps and other highway and street GPS services have replaced what?\nAnswer Choices:\n(a) united states\n(b) mexico\n(c) countryside\n(d) atlas\n(e) oceans\nA: The answer is atlas. Atlas are replaced by more precise Google maps, other highway and street GPS services. One can get much more precise data with the help of Google maps and Street GPS services.\n\nQ: Before getting a divorce, what did the wife feel who was doing all the work?\nAnswer Choices:\n(a) harder\n(b) anguish\n(c) bitterness\n(d) tears\n(e) sadness\nA: The answer is bitterness. Bitterness is the resentful feeling of anger at being treated unfairly. Doing all the work means being treated unfairly.\n\nQ: {}\nA: The answer is {}."} 2 | -------------------------------------------------------------------------------- /prompts/qasc.explanation.txt: -------------------------------------------------------------------------------- 1 | {"prompt": "Q: How do you reduce pollution?\nAnswer choices:\n(a) igniting fuel and oxidiser\n(b) transportation technology\n(c) wasting\n(d) not recycling\n(e) burning fossil fuels\n(f) converting electricity to heat\n(g) water conservation\n(h) using less resources\nA: The answer is using less resources. Conserving resources has a positive impact on the environment. 
Use of resources affects the environment such as pollution.\n\nQ: what will move to another area if their habitat will no longer support them?\nAnswer choices:\n(a) density\n(b) Birds\n(c) squids\n(d) humans\n(e) clouds\n(f) gravity\n(g) cows\n(h) Whales\nA: The answer is cows. If a habitat can no longer support animals then those animals will move to another area. Cows are social animals.\n\nQ: With the exception of allergies, what may cause a person to seek medical attention?\nAnswer choices:\n(a) Contact with latex\n(b) a tree falling\n(c) Organs within the body.\n(d) Contact with baby chicks\n(e) prolactin release\n(f) Contact with peanut butter\n(g) hypothyroidism\n(h) Contact with microorganisms\nA: The answer is Contact with microorganisms. Microorganisms can cause infections. Infections usually require medical treatment.\n\nQ: Lavender can induce\nAnswer choices:\n(a) healing\n(b) energy\n(c) hormones\n(d) mutations\n(e) Heart rate\n(f) growth\n(g) symptoms\n(h) warmth\nA: The answer is healing. Healing requires rest. Lavender induces restful sleep.\n\nQ: what state is a liquid in when frozen?\nAnswer choices:\n(a) vapor\n(b) dense\n(c) gas\n(d) cooled\n(e) steam\n(f) solid\n(g) boiling\n(h) cold\nA: The answer is solid. Freezing means changing from a liquid into a solid by reducing heat energy. Liquids freeze when they change to the solid state.\n\nQ: what unites to form a diploid zygote?\nAnswer choices:\n(a) plant reproduction\n(b) Most plants\n(c) orchids\n(d) sperm and ova\n(e) salt and pepper\n(f) predator and prey\n(g) honeybees\n(h) diploids and zygotes\nA: The answer is sperm and ova. Gametes then unite in fertilization and form a diploid zygote. Collectively, the sperm and the ova are also referred to as gametes .\n\nQ: What absorbs all visible light?\nAnswer choices:\n(a) apples\n(b) coal\n(c) Green\n(d) coral\n(e) skin\n(f) bamboo\n(g) glass\n(h) eyes\nA: The answer is coal. If an object is black then that object absorbs all visible light. 
Light grains are quartz, Black grains are coal.\n\nQ: {}\nA: The answer is {}."} -------------------------------------------------------------------------------- /generate_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from tqdm import tqdm, trange 5 | import numpy as np 6 | import math 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 11 | 12 | from transformers import set_seed, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup 13 | from transformers.optimization import Adafactor 14 | 15 | from utils import get_logger 16 | 17 | def generation(inputs, model, tokenizer, args): 18 | 19 | with torch.no_grad(): 20 | pred_ids = model.generate( 21 | input_ids=inputs.input_ids, 22 | attention_mask=inputs.attention_mask, 23 | max_length=args.max_dec_length, 24 | decoder_start_token_id=model.config.decoder_start_token_id, 25 | eos_token_id=tokenizer.eos_token_id, 26 | pad_token_id=tokenizer.pad_token_id, 27 | early_stopping=True, 28 | num_return_sequences=1, #args.num_return_sequences, 29 | num_beams=args.num_beams, 30 | do_sample=args.sample, 31 | top_p=args.top_p, 32 | top_k=args.top_k, 33 | use_cache=True 34 | ) 35 | 36 | batch_output = [tokenizer.decode(beam, skip_special_tokens=True) for beam in pred_ids] 37 | 38 | return batch_output 39 | 40 | def generation_with_prefix(inputs, decoder_input_ids, model, tokenizer, args): 41 | 42 | input_length = len(decoder_input_ids[0]) 43 | with torch.no_grad(): 44 | pred_ids = model.generate( 45 | input_ids=inputs.input_ids, 46 | attention_mask=inputs.attention_mask, 47 | max_length=args.max_dec_length, 48 | decoder_start_token_id=model.config.decoder_start_token_id, 49 | decoder_input_ids=decoder_input_ids, 50 | eos_token_id=tokenizer.eos_token_id, 51 | pad_token_id=tokenizer.pad_token_id, 52 | early_stopping=True, 53 | num_return_sequences=1, #args.num_return_sequences, 54 | num_beams=args.num_beams, 55 | do_sample=args.sample, 56 | top_p=args.top_p, 57 | top_k=args.top_k, 58 | use_cache=True 59 | ) 60 | 61 | batch_output = [tokenizer.decode(beam[input_length:], skip_special_tokens=True) for beam in pred_ids] 62 | 63 | return batch_output 64 | 65 | 66 | -------------------------------------------------------------------------------- /contrastive_decoding_rationalization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import random 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, set_seed 8 | from tqdm import tqdm 9 | 10 | torch.set_num_threads(4) 11 | 12 | # for REPRODUCIBILITY 13 | set_seed(42) 14 | 15 | # ----------------------------------------------------- # 16 | # hyper-parameters 17 | num_return_sequences = 1 18 | generation_length = 128 19 | 20 | # ----------------------------------------------------- # 21 | 22 | def contrastive_decoding(input_seq1, input_seq2, model, tokenizer, indicator_token_ids, args): 23 | inputs1 = tokenizer(input_seq1, truncation=True, return_tensors='pt').to(args.device) 24 | input_length1 = len(inputs1.input_ids[0]) 25 | generated1 = inputs1.input_ids 26 | past_key_values1 = None 27 | 28 | inputs2 = tokenizer(input_seq2, truncation=True, return_tensors='pt').to(args.device) 29 | input_length2 = 
len(inputs2.input_ids[0]) 30 | generated2 = inputs2.input_ids 31 | past_key_values2 = None 32 | 33 | with torch.no_grad(): 34 | for step in range(generation_length): 35 | # get probs given by the original teacher 36 | attention_mask1 = generated1.new_ones(generated1.shape) 37 | outputs1 = model( 38 | input_ids=generated1 if past_key_values1 is None else generated1[:, -1:], 39 | past_key_values=past_key_values1, 40 | attention_mask=attention_mask1, 41 | ) 42 | logits1 = outputs1.logits[:, -1, :] 43 | past_key_values1 = outputs1.past_key_values 44 | prob1 = F.log_softmax(logits1 / args.temperature, dim=-1) 45 | 46 | candidate_next_token = prob1.argmax(dim=-1, keepdim=True) 47 | if candidate_next_token[0].item() == indicator_token_ids["stop"]: 48 | break 49 | 50 | # get probs given by the hallucinating teacher 51 | attention_mask2 = generated2.new_ones(generated2.shape) 52 | outputs2 = model( 53 | input_ids=generated2 if past_key_values2 is None else generated2[:, -1:], 54 | past_key_values=past_key_values2, 55 | attention_mask=attention_mask2, 56 | ) 57 | logits2 = outputs2.logits[:, -1, :] 58 | past_key_values2 = outputs2.past_key_values 59 | prob2 = F.log_softmax(logits2, dim=-1) 60 | 61 | # contrastive decoding 62 | debiased_prob = prob1 - args.interpolation * prob2 63 | next_token = debiased_prob.argmax(dim=-1, keepdim=True) 64 | 65 | if next_token[0] == indicator_token_ids["stop"]: 66 | break 67 | 68 | generated1 = torch.cat((generated1, next_token), dim=1) 69 | generated2 = torch.cat((generated2, next_token), dim=1) 70 | 71 | generation = tokenizer.decode(generated1[0][input_length1:], skip_special_tokens=True) 72 | return generation 73 | 74 | def main(args): 75 | # ----------------------------------------------------- # 76 | # load LM 77 | model_path = 'EleutherAI/gpt-neox-20b' 78 | model_name = "GPT-neox" 79 | tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir='../cache') #, use_fast=False) 80 | n_gpus = 1 81 | free_in_GB = 49 82 | max_memory = {i: "{}GB".format(free_in_GB) for i in range(args.gpu, args.gpu + n_gpus)} 83 | 84 | from transformers import GPTNeoXForCausalLM 85 | model = GPTNeoXForCausalLM.from_pretrained( 86 | model_path, 87 | device_map='auto', 88 | max_memory = max_memory, 89 | cache_dir='../cache', 90 | torch_dtype='auto' 91 | ) 92 | 93 | indicator_token_ids = { 94 | "stop": tokenizer.encode("\n\nQ")[-2], 95 | } 96 | 97 | model.eval() 98 | 99 | # ----------------------------------------------------- # 100 | # prepare data 101 | with open('./prompts/{}.{}.txt'.format(args.dataset, args.prompt), 'r') as fr: 102 | prompt = json.load(fr)["prompt"] 103 | 104 | print(prompt) 105 | prompt_without_question = '\n\n'.join(prompt.split('\n\n')[:-1])+'\n\n' 106 | 107 | for split in args.eval_split.split(','): 108 | with open('./data/{}/{}.jsonl'.format(args.dataset, split), 'r') as fr: 109 | examples = [json.loads(line) for line in fr.readlines()] 110 | 111 | # ----------------------------------------------------- # 112 | # inference 113 | 114 | output_path = os.path.join(args.output_prefix, '{}.jsonl'.format(split)) 115 | 116 | fw = open(output_path, 'w', buffering=1) 117 | for example in tqdm(examples): 118 | if "context" in example: 119 | formatted_question = example["context"] 120 | choices = ["false", "true"] 121 | if "counterfactual" in split: 122 | answer = "false" if example["answer"] == 1 else "true" 123 | wrong_answer = "false" if example["answer"] == 0 else "true" 124 | else: 125 | answer = "false" if example["answer"] == 0 else "true" 126 | wrong_answer 
= "false" if example["answer"] == 1 else "true" 127 | question = example["context"] 128 | else: 129 | formatted_question = example["question"] 130 | choices = example["choices"] if "choices" in example else ["no", "yes"] 131 | question = example["question"] 132 | if "choices" in example and len(example["choices"]) > 2: 133 | if "counterfactual" in split: 134 | answer = random.choice(example["choices"][:example["answer"]] + example["choices"][example["answer"]+1:]) 135 | wrong_answer = example["choices"][example["answer"]] 136 | else: 137 | answer = example["choices"][example["answer"]] 138 | wrong_answer = random.choice(example["choices"][:example["answer"]] + example["choices"][example["answer"]+1:]) 139 | else: 140 | if "counterfactual" in split: 141 | answer = "yes" if example["answer"] == 0 else "no" 142 | wrong_answer = "yes" if example["answer"] == 1 else "no" 143 | else: 144 | answer = "yes" if example["answer"] == 1 else "no" 145 | wrong_answer = "yes" if example["answer"] == 0 else "no" 146 | 147 | if "choices" in example and len(example["choices"]) > 2: 148 | choices_seq = "" 149 | formatted_question += "\nAnswer Choices:" 150 | for choice_id, choice in enumerate(example["choices"]): 151 | formatted_question += "\n({}) {}".format(chr(ord('a')+choice_id), choice) 152 | choices_seq += " ({}) {}".format(chr(ord('A')+choice_id), choice) 153 | 154 | input_seq1 = prompt.format(formatted_question, answer) 155 | input_seq2 = prompt.format(formatted_question, wrong_answer) # replace wrong_answer with "" if using empty string as the perturbed answer 156 | if args.debug: 157 | print(input_seq2) 158 | print(input_seq3) 159 | generation = contrastive_decoding(input_seq1, input_seq2, model, tokenizer, indicator_token_ids, args) 160 | 161 | if "context" in example: 162 | fw.write(json.dumps({"id": example["id"], "answer": answer, "statement": question, "explanation": generation_list}) + "\n") 163 | else: 164 | if "choices" in example and len(example["choices"]) > 2: 165 | fw.write(json.dumps({"id": example["id"], "answer": answer, "question": question, "choices": choices_seq.strip(), "explanation": generation.strip()}) + "\n") 166 | else: 167 | fw.write(json.dumps({"id": example["id"], "answer": answer, "question": question, "explanation": generation.strip()}) + "\n") 168 | fw.close() 169 | 170 | # ----------------------------------------------------- # 171 | 172 | if __name__ == "__main__": 173 | 174 | parser = argparse.ArgumentParser(description='Run main.') 175 | parser.add_argument('--dataset', '-d', type=str) 176 | parser.add_argument('--output_prefix', '-o', type=str) 177 | parser.add_argument('--prompt', '-p', type=str) 178 | parser.add_argument('--num_process', type=int, default=1) 179 | parser.add_argument('--eval_split', type=str, default='test,dev,train,train.counterfactual') 180 | parser.add_argument("--debug", action='store_true') 181 | 182 | # debiased factor 183 | parser.add_argument('--interpolation', type=float, default=0.5) 184 | 185 | # decoding strategy 186 | parser.add_argument('--temperature', type=float, default=1.0) 187 | 188 | # gpu and workers option 189 | parser.add_argument('--gpu', type=int, default=0) 190 | 191 | args = parser.parse_args() 192 | args.device = torch.device('cuda:{}'.format(args.gpu)) 193 | 194 | main(args) 195 | 196 | -------------------------------------------------------------------------------- /data_helper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 
4 | from dataclasses import dataclass 5 | from typing import List, Optional 6 | import random 7 | 8 | import torch 9 | from torch.utils.data import Dataset, TensorDataset 10 | 11 | @dataclass(frozen=True) 12 | class InputExample: 13 | 14 | qid: str 15 | question: str 16 | explanation: List[str] 17 | choices: str 18 | answer: str 19 | is_statement: bool 20 | 21 | class TrainingDataset(Dataset): 22 | 23 | features: List[InputExample] 24 | 25 | def __init__(self, features): 26 | self.features = features 27 | 28 | def __len__(self): 29 | return len(self.features) 30 | 31 | def __getitem__(self, i) -> InputExample: 32 | return self.features[i] 33 | 34 | def load_raw_dataset(split, args): 35 | data_path = os.path.join('./outputs', args.dataset, '{}.jsonl'.format(split)) 36 | dataset = [] 37 | 38 | with open(data_path, 'r') as fr: 39 | for line_idx, line in tqdm(enumerate(fr), desc='processing {}'.format(data_path)): 40 | example = json.loads(line) 41 | dataset.append( 42 | InputExample( 43 | qid=example["id"], 44 | question=example["statement"] if "statement" in example else example["question"], 45 | explanation=example["explanation"], 46 | choices=example["choices"] if "choices" in example else None, 47 | answer=example["answer"], 48 | is_statement="statement" in example, 49 | ) 50 | ) 51 | 52 | 53 | for example in dataset[:2]: 54 | print("*** Example ***") 55 | print(example) 56 | 57 | return dataset 58 | 59 | def get_label_tensor(raw_label, tokenizer, args): 60 | label_ids = tokenizer.encode(raw_label, add_special_tokens=False) 61 | label_ids += [tokenizer.eos_token_id] 62 | label_ids = label_ids[:args.max_dec_length] 63 | label_ids += [-100] * (args.max_dec_length - len(label_ids)) 64 | return label_ids 65 | 66 | def get_label_tensor_answer_only(raw_label, raw_label_without_answer, tokenizer, args): 67 | label_ids = tokenizer.encode(raw_label, add_special_tokens=False) 68 | label_ids += [tokenizer.eos_token_id] 69 | label_ids = label_ids[:args.max_dec_length] 70 | label_ids += [-100] * (args.max_dec_length - len(label_ids)) 71 | 72 | label_ids_without_answer = tokenizer.encode(raw_label_without_answer, add_special_tokens=False) 73 | label_ids_without_answer = label_ids_without_answer[:args.max_dec_length] 74 | 75 | label_ids_answer_only = label_ids.copy() 76 | for idx in range(len(label_ids_without_answer)): 77 | label_ids_answer_only[idx] = -100 78 | 79 | decoder_input_ids = [tokenizer.pad_token_id] + [tokenizer.pad_token_id if _id == -100 else _id for _id in label_ids[:-1]] 80 | 81 | return decoder_input_ids, label_ids_answer_only 82 | 83 | def format_input(context, choices=None, counterfactual=False, add_task_prefix=True): 84 | input_seq = "" 85 | if add_task_prefix: 86 | if counterfactual: 87 | input_seq += "[counterfactual] " 88 | else: 89 | input_seq += "[factual] " 90 | input_seq += context.strip() 91 | if choices is not None: 92 | input_seq += " \\n {}".format(choices.strip()) 93 | return input_seq 94 | 95 | def format_output(explanation, answer, counterfactual=False, without_explanation=False, add_task_prefix=True): 96 | output_seq = "" 97 | if add_task_prefix: 98 | if counterfactual: 99 | output_seq += "[counterfactual] " 100 | else: 101 | output_seq += "[factual] " 102 | 103 | if not without_explanation: 104 | output_seq += explanation.strip() 105 | output_seq += ' So the answer is ' 106 | output_seq_with_answer = output_seq + answer.strip() 107 | return output_seq_with_answer, output_seq.strip() 108 | 109 | class Data_Collator_for_Training(object): 110 | def __init__(self, 
tokenizer, args, counterfactual=False): 111 | self.tokenizer = tokenizer 112 | self.args = args 113 | self.counterfactual = counterfactual 114 | 115 | def __call__(self, examples): 116 | encoder_input_tensor = [] 117 | encoder_attention_mask_tensor = [] 118 | decoder_label_tensor = [] 119 | decoder_input_ids_tensor = [] 120 | 121 | for example_idx, example in enumerate(examples): 122 | input_seq = format_input(example.question, example.choices, counterfactual=self.counterfactual, add_task_prefix=self.args.add_task_prefix) 123 | inputs = self.tokenizer(input_seq, padding='max_length', max_length=self.args.max_enc_length, truncation=True) 124 | 125 | if isinstance(example.explanation, list): 126 | explanation = random.choice(example.explanation) 127 | else: 128 | explanation = example.explanation 129 | output_seq, output_seq_without_answer = format_output(explanation, example.answer, counterfactual=self.counterfactual, without_explanation=self.args.without_explanation, add_task_prefix=self.args.add_task_prefix) 130 | 131 | encoder_input_tensor.append(inputs['input_ids']) 132 | encoder_attention_mask_tensor.append(inputs['attention_mask']) 133 | 134 | if self.counterfactual: 135 | decoder_input_ids, decoder_label = get_label_tensor_answer_only(output_seq, output_seq_without_answer, self.tokenizer, self.args) 136 | decoder_input_ids_tensor.append(decoder_input_ids) 137 | decoder_label_tensor.append(decoder_label) 138 | else: 139 | decoder_label_tensor.append(get_label_tensor(output_seq, self.tokenizer, self.args)) 140 | 141 | if self.counterfactual: 142 | return tuple(torch.tensor(t) for t in [encoder_input_tensor, encoder_attention_mask_tensor, decoder_label_tensor, decoder_input_ids_tensor]) 143 | else: 144 | return tuple(torch.tensor(t) for t in [encoder_input_tensor, encoder_attention_mask_tensor, decoder_label_tensor]) 145 | 146 | def get_tensor_dataset(split, tokenizer, args, counterfactual=False): 147 | data_path = os.path.join('./data', args.dataset, '{}.jsonl'.format(split)) 148 | 149 | encoder_input_tensor = [] 150 | encoder_attention_mask_tensor = [] 151 | decoder_label_tensor = [] 152 | decoder_input_ids_tensor = [] 153 | 154 | with open(data_path, 'r') as fr: 155 | for line_idx, line in tqdm(enumerate(fr), desc='processing {}'.format(data_path)): 156 | example = json.loads(line) 157 | 158 | if "question" in example: 159 | if "choices" in example: 160 | input_seq = format_input(example["question"], example["choices"], counterfactual=counterfactual, add_task_prefix=args.add_task_prefix) 161 | else: 162 | input_seq = format_input(example["question"], counterfactual=counterfactual, add_task_prefix=args.add_task_prefix) 163 | else: 164 | input_seq = format_input(example["statement"], counterfactual=counterfactual, add_task_prefix=args.add_task_prefix) 165 | 166 | inputs = tokenizer(input_seq, padding='max_length', max_length=args.max_enc_length, truncation=True) 167 | 168 | if isinstance(example["explanation"], list): 169 | for explanation in example["explanation"][:5]: 170 | output_seq, output_seq_without_answer = format_output(explanation, example["answer"], counterfactual=counterfactual, without_explanation=args.without_explanation, add_task_prefix=args.add_task_prefix) 171 | 172 | encoder_input_tensor.append(inputs['input_ids']) 173 | encoder_attention_mask_tensor.append(inputs['attention_mask']) 174 | 175 | if counterfactual: 176 | decoder_input_ids, decoder_label = get_label_tensor_answer_only(output_seq, output_seq_without_answer, tokenizer, args) 177 | 
decoder_input_ids_tensor.append(decoder_input_ids) 178 | decoder_label_tensor.append(decoder_label) 179 | else: 180 | decoder_label_tensor.append(get_label_tensor(output_seq, tokenizer, args)) 181 | 182 | else: 183 | output_seq, output_seq_without_answer = format_output(example["explanation"], example["answer"], counterfactual=counterfactual, without_explanation=args.without_explanation, add_task_prefix=args.add_task_prefix) 184 | 185 | encoder_input_tensor.append(inputs['input_ids']) 186 | encoder_attention_mask_tensor.append(inputs['attention_mask']) 187 | 188 | if counterfactual: 189 | decoder_input_ids, decoder_label = get_label_tensor_answer_only(output_seq, output_seq_without_answer, tokenizer, args) 190 | decoder_input_ids_tensor.append(decoder_input_ids) 191 | decoder_label_tensor.append(decoder_label) 192 | else: 193 | decoder_label_tensor.append(get_label_tensor(output_seq, tokenizer, args)) 194 | 195 | encoder_input_tensor = torch.tensor(encoder_input_tensor, dtype=torch.long) 196 | encoder_attention_mask_tensor= torch.tensor(encoder_attention_mask_tensor, dtype=torch.long) 197 | decoder_label_tensor = torch.tensor(decoder_label_tensor, dtype=torch.long) 198 | if counterfactual: 199 | decoder_input_ids_tensor = torch.tensor(decoder_input_ids_tensor, dtype=torch.long) 200 | for f1, f2, f3 in zip(encoder_input_tensor[:2], encoder_attention_mask_tensor[:2], decoder_label_tensor[:2]): 201 | print("*** Example ***") 202 | print("encoder input: %s" % tokenizer.decode(f1)) 203 | print("encoder attention mask: %s" % f2) 204 | print("decoder output: %s" % tokenizer.decode([tid for tid in f3 if not tid == -100])) 205 | if counterfactual: 206 | for f4 in decoder_input_ids_tensor[:2]: 207 | print("decoder input: %s" % tokenizer.decode(f4)) 208 | 209 | return TensorDataset(encoder_input_tensor, encoder_attention_mask_tensor, decoder_label_tensor, decoder_input_ids_tensor) 210 | else: 211 | return TensorDataset(encoder_input_tensor, encoder_attention_mask_tensor, decoder_label_tensor) 212 | 213 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from tqdm import tqdm, trange 5 | import numpy as np 6 | import math 7 | import random 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 12 | 13 | from transformers import set_seed, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup 14 | from transformers.optimization import Adafactor 15 | 16 | from data_helper import get_tensor_dataset, load_raw_dataset, format_input, format_output, Data_Collator_for_Training 17 | from generate_utils import generation, generation_with_prefix 18 | 19 | import logging 20 | def get_logger(name, log_path=None): 21 | 22 | logger = logging.getLogger(name) 23 | logger.setLevel(logging.DEBUG) 24 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S') 25 | 26 | if log_path: 27 | handler = logging.FileHandler(log_path, 'w') 28 | handler.setLevel(logging.INFO) 29 | handler.setFormatter(formatter) 30 | logger.addHandler(handler) 31 | 32 | return logger 33 | 34 | def evaluate(dataset, model, args): 35 | 36 | data_sampler = SequentialSampler(dataset) 37 | dataloader = DataLoader(dataset, 38 | sampler=data_sampler, 39 | batch_size=args.eval_batch_size) 40 | model.eval() 41 | 
epoch_iterator = tqdm(dataloader, desc="Eval Iteration") 42 | 43 | loss_sum = 0. 44 | ppl_sum = 0. 45 | tokens_sum = 0. 46 | for step, batch in enumerate(epoch_iterator): 47 | 48 | input_ids, attention_mask, text_labels = tuple(t.to(args.device) for t in batch) 49 | 50 | with torch.no_grad(): 51 | outputs = model( 52 | input_ids=input_ids, 53 | attention_mask=attention_mask, 54 | labels=text_labels 55 | ) 56 | 57 | loss = outputs.loss 58 | num_tokens = (text_labels != -100).sum().item() 59 | tokens_sum += num_tokens 60 | ppl_sum += outputs.loss.item() * num_tokens 61 | 62 | loss_sum += loss.item() 63 | if args.debug and step > 10: 64 | break 65 | 66 | loss_sum /= (step + 1) 67 | ppl_sum = math.exp(ppl_sum / tokens_sum) 68 | 69 | return {"loss": loss_sum, "perplexity": ppl_sum} 70 | 71 | def inference(dataset, output_path, model, tokenizer, args): 72 | batch_input = [] 73 | batch_output_prefix = [] 74 | batch_example = [] 75 | example_idx = 0 76 | if output_path is not None: 77 | fw = open(output_path, 'w') 78 | accuracy = 0. 79 | generated_explanation = [] 80 | if args.add_task_prefix: 81 | output_prefix = ' [factual]' 82 | else: 83 | output_prefix = ' ' 84 | model.eval() 85 | for example in tqdm(dataset): 86 | batch_example.append(example) 87 | input_seq = format_input(example.question, example.choices) 88 | 89 | batch_input.append(input_seq) 90 | batch_output_prefix.append(output_prefix) 91 | 92 | if len(batch_input) == args.eval_batch_size or example_idx == len(dataset) - 1: 93 | inputs = tokenizer(batch_input, padding='max_length', max_length=args.max_enc_length, truncation=True, return_tensors='pt').to(args.device) 94 | decoder_input_ids = tokenizer(batch_output_prefix, add_special_tokens=False, return_tensors='pt').to(args.device).input_ids 95 | batch_output = generation_with_prefix(inputs, decoder_input_ids, model, tokenizer, args) 96 | for example, output in zip(batch_example, batch_output): 97 | answer_prefix = "So the answer is " 98 | generation_split = output.split(answer_prefix) 99 | generated_explanation.append(generation_split[0].strip()) 100 | if len(generation_split) == 1: 101 | continue 102 | explanation = generation_split[0].strip() 103 | prediction = generation_split[1].strip() 104 | if prediction == example.answer: 105 | accuracy += 1 106 | if output_path is not None: 107 | output_example = {"id": example.qid} 108 | output_example["question"] = example.question 109 | output_example["answer"] = prediction 110 | if example.choices is not None: 111 | output_example["choices"] = [span.split(') ')[1].strip() for span in example.choices.split('(')[1:]] 112 | else: 113 | if example.is_statement: 114 | output_example["choices"] = ["false", "true"] 115 | else: 116 | output_example["choices"] = ["no", "yes"] 117 | if not args.without_explanation: 118 | output_example["explanation"] = explanation 119 | 120 | fw.write(json.dumps(output_example)+'\n') 121 | batch_input = [] 122 | batch_example = [] 123 | batch_output_prefix = [] 124 | example_idx += 1 125 | if args.debug and example_idx > 50: 126 | break 127 | 128 | if output_path is not None: 129 | fw.close() 130 | return accuracy * 100. / len(dataset), generated_explanation 131 | 132 | def inference_with_oracle(dataset, model, tokenizer, args): 133 | example_idx = 0 134 | accuracy = 0. 
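# Oracle-rationale evaluation: for each example, the rationale stored in
# example.explanation plus " So the answer is" is fed to the decoder as a fixed
# prefix, so the model only has to generate the answer conditioned on that
# rationale. The returned accuracy therefore measures how consistently the
# model answers when the rationale is supplied.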
135 | model.eval() 136 | for example in tqdm(dataset): 137 | input_seq = format_input(example.question, example.choices) 138 | answer_prefix = " So the answer is" 139 | 140 | inputs = tokenizer(input_seq, padding='max_length', max_length=args.max_enc_length, truncation=True, return_tensors='pt').to(args.device) 141 | if args.add_task_prefix: 142 | output_prefix = ' [factual]' + example.explanation + answer_prefix 143 | else: 144 | output_prefix = ' ' + example.explanation + answer_prefix 145 | decoder_input_ids = tokenizer(output_prefix, add_special_tokens=False, return_tensors='pt').to(args.device).input_ids 146 | prediction = generation_with_prefix(inputs, decoder_input_ids, model, tokenizer, args)[0].strip() 147 | if prediction == example.answer: 148 | accuracy += 1 149 | example_idx += 1 150 | if args.debug and example_idx > 50: 151 | break 152 | 153 | return accuracy * 100. / len(dataset) 154 | 155 | def inference_with_perturb(dataset, explanations, model, tokenizer, args, replace_ratio=0.5): 156 | example_idx = 0 157 | accuracy = 0. 158 | model.eval() 159 | for example in tqdm(dataset): 160 | input_seq = format_input(example.question, example.choices) 161 | answer_prefix = " So the answer is" 162 | answer_prefix_ids = tokenizer.encode(answer_prefix, add_special_tokens=False) 163 | 164 | inputs = tokenizer(input_seq, padding='max_length', max_length=args.max_enc_length, truncation=True, return_tensors='pt').to(args.device) 165 | explanation_ids = tokenizer.encode(explanations[example_idx], add_special_tokens=False) 166 | explanation_length = len(explanation_ids) 167 | mask_idx = random.sample(range(explanation_length), int(explanation_length * replace_ratio)) 168 | pert_explanation_ids = [random.choice(range(len(tokenizer))) if _idx in mask_idx else explanation_ids[_idx] for _idx in range(explanation_length)] 169 | if args.add_task_prefix: 170 | decoder_input_ids = [tokenizer.pad_token_id] + tokenizer.encode('[factual]', add_special_tokens=False) + pert_explanation_ids + answer_prefix_ids 171 | else: 172 | decoder_input_ids = [tokenizer.pad_token_id] + pert_explanation_ids + answer_prefix_ids 173 | decoder_input_ids = torch.tensor([decoder_input_ids]).to(args.device) 174 | prediction = generation_with_prefix(inputs, decoder_input_ids, model, tokenizer, args)[0].strip() 175 | if prediction == example.answer: 176 | accuracy += 1 177 | example_idx += 1 178 | if args.debug and example_idx > 50: 179 | break 180 | 181 | return accuracy * 100. 
/ len(dataset) 182 | 183 | def main(args, seed): 184 | # ----------------------------------------------------- # 185 | # prepare logger 186 | log_path = os.path.join(args.save_dir, 'train_seed{}.log'.format(seed)) 187 | logger = get_logger("model", log_path) 188 | logger.info('args: {}'.format(args)) 189 | 190 | # ----------------------------------------------------- # 191 | # model 192 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir='../cache/') 193 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, cache_dir='../cache/') 194 | model.to(args.device) 195 | 196 | # ----------------------------------------------------- # 197 | # data 198 | 199 | trainset = get_tensor_dataset('train', tokenizer, args) 200 | train_sampler = RandomSampler(trainset) 201 | train_dataloader = DataLoader(trainset, 202 | collate_fn=None, 203 | sampler=train_sampler, 204 | batch_size=args.train_batch_size, 205 | ) 206 | 207 | if args.counterfactual_alpha > 0: 208 | trainset1 = get_tensor_dataset('train.counterfactual' , tokenizer, args, counterfactual=True) 209 | train_sampler1 = RandomSampler(trainset1) 210 | train_dataloader_counterfactual = DataLoader(trainset1, collate_fn=None, sampler=train_sampler1, batch_size=args.train_batch_size) 211 | 212 | devset = get_tensor_dataset('dev', tokenizer, args) 213 | 214 | # ----------------------------------------------------- # 215 | # optimization 216 | no_decay = ["bias", "LayerNorm.weight"] 217 | optimizer_grouped_parameters = [ 218 | { 219 | "params": [p for n, p in model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay)], 220 | "weight_decay": args.weight_decay, 221 | }, 222 | { 223 | "params": [p for n, p in model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay)], 224 | "weight_decay": 0.0 225 | }, 226 | ] 227 | optimizer = Adafactor( 228 | optimizer_grouped_parameters, 229 | lr=args.learning_rate, 230 | weight_decay=0.0, 231 | relative_step=False, 232 | scale_parameter=False, 233 | warmup_init=False 234 | ) 235 | 236 | num_update_steps_per_epoch = len(train_dataloader) 237 | t_total = num_update_steps_per_epoch // args.grad_step * args.num_epoch 238 | warmup_steps = int(t_total * args.warmup_ratio) 239 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total) 240 | 241 | # ----------------------------------------------------- # 242 | # training loop 243 | model_ckpt = os.path.join(args.save_dir, 'model_seed{}.ckpt'.format(seed)) 244 | output_path = os.path.join(args.save_dir, 'validation_seed{}.jsonl'.format(seed)) 245 | global_step = 0 246 | best_dev_loss = 1e19 247 | step_nogress = 0 248 | optimizer.zero_grad() 249 | loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=args.smoothing_factor) 250 | if args.debug: 251 | args.num_epoch = 1 252 | for epoch in trange(int(args.num_epoch), desc="Epoch"): 253 | train_loss = 0. 254 | counterfactual_loss = 0. 
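# Each training step below mixes two objectives:
#   (1 - counterfactual_alpha) * cross-entropy on the factual batch, whose target is the
#   rationale followed by " So the answer is <answer>" (with the [factual] task prefix when
#   --add_task_prefix is set), and
#   counterfactual_alpha * cross-entropy on the counterfactual batch, where the decoder is
#   teacher-forced on the counterfactual rationale but only the answer tokens remain as
#   labels (everything else is masked to -100 by get_label_tensor_answer_only in data_helper.py).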
255 | model.train() 256 | epoch_iterator = tqdm(train_dataloader, desc="Train Iteration at Epoch {}".format(epoch), total=num_update_steps_per_epoch) 257 | if args.counterfactual_alpha > 0: 258 | counterfactual_iterator = iter(train_dataloader_counterfactual) 259 | for step, batch in enumerate(epoch_iterator): 260 | 261 | input_ids, attention_mask, labels = tuple(t.to(args.device) for t in batch) 262 | 263 | outputs = model( 264 | input_ids=input_ids, 265 | attention_mask=attention_mask, 266 | labels=labels, 267 | ) 268 | outputs_loss = loss_fct(outputs.logits.view(-1, outputs.logits.size(-1)), labels.view(-1)) 269 | 270 | loss = (1 - args.counterfactual_alpha) * outputs_loss 271 | 272 | if args.counterfactual_alpha > 0: 273 | try: 274 | counterfactual_batch = next(counterfactual_iterator) 275 | except StopIteration: 276 | counterfactual_iterator = iter(train_dataloader_counterfactual) 277 | counterfactual_batch = next(counterfactual_iterator) 278 | input_ids, attention_mask, labels, decoder_input_ids = tuple(t.to(args.device) for t in counterfactual_batch) 279 | 280 | counterfactual_outputs = model( 281 | input_ids=input_ids, 282 | attention_mask=attention_mask, 283 | decoder_input_ids=decoder_input_ids, 284 | # labels=labels, 285 | ) 286 | counterfactual_outputs_loss = loss_fct(counterfactual_outputs.logits.view(-1, counterfactual_outputs.logits.size(-1)), labels.view(-1)) 287 | loss += args.counterfactual_alpha * counterfactual_outputs_loss 288 | 289 | loss /= args.grad_step 290 | loss.backward() 291 | if (global_step + 1) % args.grad_step == 0: 292 | 293 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 294 | optimizer.step() 295 | 296 | scheduler.step() # Update learning rate schedule 297 | optimizer.zero_grad() 298 | 299 | train_loss += outputs_loss.item() # * args.grad_step 300 | if args.counterfactual_alpha > 0: 301 | counterfactual_loss += counterfactual_outputs_loss.item() 302 | global_step += 1 303 | epoch_iterator.set_description("Epoch {} loss {:.4f} counter {:.4f}".format(epoch, train_loss / (step + 1), counterfactual_loss / (step + 1))) 304 | if args.debug and global_step > 10: 305 | break 306 | 307 | train_loss /= (step + 1) 308 | counterfactual_loss /= (step + 1) 309 | log = 'Epoch: {:03d} Train loss: {:.4f} Counterfacual loss: {:.4f}' 310 | logger.info(log.format(epoch, train_loss, counterfactual_loss)) 311 | 312 | dev_result = evaluate(devset, model, args) 313 | log = 'Epoch: {:03d}, dev loss {:.4f}, perplexity {:.4f}' 314 | if dev_result["loss"] < best_dev_loss: 315 | torch.save({'ckpt': model.state_dict(), 'args': args}, model_ckpt) 316 | log += ' best' 317 | best_dev_loss = dev_result["loss"] 318 | step_nogress = 0 319 | else: 320 | step_nogress += 1 321 | logger.info(log.format(epoch, dev_result["loss"], dev_result["perplexity"])) 322 | if step_nogress > args.num_epoch_early_stopping and global_step > warmup_steps: 323 | break 324 | 325 | return_result = {} 326 | model.load_state_dict(torch.load(model_ckpt)['ckpt']) 327 | for split in ['test']: 328 | testset = load_raw_dataset(split, args) 329 | output_path = os.path.join(args.save_dir, '{}_seed{}.jsonl'.format(split, seed)) 330 | 331 | accuracy, explanations = inference(testset, output_path, model, tokenizer, args) 332 | if split == 'test': 333 | return_result["accuracy_inference"] = accuracy 334 | log = 'Epoch: {:03d}, inference accuracy: {:.4f}' 335 | logger.info(log.format(-1, accuracy)) 336 | if not args.without_explanation: 337 | return_result["accuracy_oracle"] = 
inference_with_oracle(testset, model, tokenizer, args) 338 | log = 'Epoch: {:03d}, oracle accuracy: {:.4f}' 339 | logger.info(log.format(-1, return_result["accuracy_oracle"])) 340 | return_result["accuracy_perturb"] = inference_with_perturb(testset, explanations, model, tokenizer, args) 341 | log = 'Epoch: {:03d}, perturb accuracy: {:.4f}' 342 | logger.info(log.format(-1, return_result["accuracy_perturb"])) 343 | 344 | if not args.save_ckpt: 345 | os.remove(model_ckpt) 346 | return return_result 347 | 348 | if __name__ == "__main__": 349 | 350 | parser = argparse.ArgumentParser(description='Run main.') 351 | parser.add_argument('--dataset', '-d', type=str) 352 | parser.add_argument('--save_dir', '-o', type=str) 353 | parser.add_argument("--debug", action='store_true') 354 | parser.add_argument("--save_ckpt", action='store_true') 355 | parser.add_argument("--add_task_prefix", action='store_true') 356 | 357 | # model 358 | parser.add_argument('--model_name', '-m', type=str) 359 | parser.add_argument('--max_enc_length', type=int, default=128) 360 | parser.add_argument('--max_dec_length', type=int, default=128) 361 | 362 | # training 363 | parser.add_argument('--train_batch_size', type=int, default=32) 364 | parser.add_argument('--grad_step', type=int, default=1) 365 | parser.add_argument('--learning_rate', type=float, default=1e-5) 366 | parser.add_argument("--warmup_ratio", type=float, default=0.06) 367 | parser.add_argument('--weight_decay', type=float, default=0.0) 368 | parser.add_argument("--max_grad_norm", default=1.0, type=float) 369 | parser.add_argument('--num_epoch', type=float, default=1000) 370 | parser.add_argument('--num_epoch_early_stopping', type=int, default=10) 371 | 372 | # method 373 | parser.add_argument("--without_explanation", action='store_true') 374 | parser.add_argument('--counterfactual_alpha', type=float, default=0) 375 | parser.add_argument('--smoothing_factor', type=float, default=0) 376 | 377 | # inference 378 | parser.add_argument("--inference", action='store_true') 379 | parser.add_argument("--evaluate", action='store_true') 380 | parser.add_argument('--eval_split', type=str, default='test') 381 | parser.add_argument('--eval_batch_size', type=int, default=8) 382 | parser.add_argument('--sample', action='store_true') 383 | parser.add_argument('--num_beams', type=int, default=1) 384 | parser.add_argument('--top_k', type=int, default=0) 385 | parser.add_argument('--top_p', type=float, default=1.0) 386 | parser.add_argument('--num_return_sequences', type=int, default=1) 387 | parser.add_argument("--overwrite_output", action='store_true') 388 | 389 | # gpu and workers option 390 | parser.add_argument('--gpu', type=int, default=0) 391 | 392 | args = parser.parse_args() 393 | args.device = torch.device('cuda:{}'.format(args.gpu)) 394 | 395 | eval_result_all_split = {} 396 | for seed in range(5): 397 | set_seed(seed) 398 | eval_result = main(args, seed) 399 | for split in eval_result: 400 | if split not in eval_result_all_split: 401 | eval_result_all_split[split] = [] 402 | eval_result_all_split[split].append(eval_result[split]) 403 | output_result = {} 404 | for split in eval_result_all_split: 405 | output_result[split] = { 406 | "accuracy_mean": np.mean(eval_result_all_split[split]), 407 | "accuracy_std": np.std(eval_result_all_split[split]), 408 | } 409 | with open(os.path.join(args.save_dir, 'evaluation_results.json'), 'w') as fw: 410 | json.dump(output_result, fw, indent=4) 411 | --------------------------------------------------------------------------------
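For quick reference, the per-step scoring rule at the heart of contrastive_decoding_rationalization.py is reproduced below as a self-contained sketch; the function name and the toy logits are illustrative only, not part of the repository. In the actual script, both logit vectors come from the same GPT-NeoX teacher, conditioned on a prompt that ends with the gold answer and on one that ends with a perturbed answer, and the rationale is decoded greedily token by token until a stop token derived from "\n\nQ" is produced.

```python
import torch
import torch.nn.functional as F

def contrastive_next_token(factual_logits, counterfactual_logits,
                           interpolation=0.5, temperature=1.0):
    # Mirrors the decoding rule in contrastive_decoding_rationalization.py:
    # the factual distribution is temperature-scaled, then the (log-)probability
    # assigned by the "hallucinating" teacher is subtracted before the argmax.
    factual_logprob = F.log_softmax(factual_logits / temperature, dim=-1)
    counterfactual_logprob = F.log_softmax(counterfactual_logits, dim=-1)
    debiased = factual_logprob - interpolation * counterfactual_logprob
    return debiased.argmax(dim=-1)

# Toy usage with random next-token logits; in the repository these come from GPT-NeoX.
vocab_size = 8
next_token = contrastive_next_token(torch.randn(vocab_size), torch.randn(vocab_size))
print(int(next_token))
```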