├── .gitignore ├── README.md ├── UniEval ├── LICENSE ├── evaluator.py ├── run.py ├── scorer.py └── utils.py ├── baseline ├── README.md ├── huggingface_run.py ├── openai_run.py └── test_prompt_baseline.py ├── human_eval └── sample_indices_for_human.py ├── kcd ├── attribute_classifier │ ├── __init__.py │ ├── attribute_classifier_model.py │ ├── attribute_dataloader.py │ ├── evaluate_classifier.py │ └── train_classifier.py ├── classifier_guidance │ ├── __init__.py │ ├── astar_decode.py │ ├── fudge_decode.py │ ├── guided_generation_predictor.py │ ├── metric_guidance.py │ ├── nado_decode.py │ ├── openai_fudge_decode.py │ ├── openai_ppl_mcts.py │ ├── ppl_mcts.py │ └── utils.py ├── configs.py ├── dstc11_task5 │ ├── __init__.py │ ├── dataset_walker.py │ └── knowledge_reader.py ├── evaluation │ ├── __init__.py │ ├── auto_evaluation.py │ └── token_f1_score.py ├── instructions.py ├── kilt │ ├── knowledge_source.py │ ├── load_kilt_data.py │ ├── preprocess_fever.py │ └── preprocess_kilt.py ├── openai_module.py ├── partial_negative.py ├── sample_negative.py ├── summarization │ ├── __init__.py │ └── load_data.py ├── text_data.py ├── token_classifier │ ├── __init__.py │ ├── dataloader.py │ ├── model.py │ ├── train.py │ └── trainer.py ├── util.py └── wizard_of_wikipedia │ ├── __init__.py │ ├── load_data.py │ └── preprocess.py ├── notebooks └── read_results.ipynb ├── requirements.txt ├── scripts ├── analyze_partial_hallucination_data.py ├── evaluate_generations_with_classifier.py ├── evaluate_summary_mfma.py ├── evaluate_zeroshot_wow_classification.py ├── run_guided_generation.py ├── run_openai_guided_generation.py ├── run_openai_ppl_mcts.py ├── run_ppl_mcts.py └── shell │ ├── baselines │ ├── baseline_run.sh │ ├── openai_run.sh │ └── sft_baseline_run.sh │ ├── data_process │ ├── partial_neg_gen.sh │ ├── preprocess_wow.sh │ └── random_neg.sh │ ├── eval │ ├── class_prob.sh │ ├── test_t5_token_classifier.sh │ └── unieval.sh │ ├── guided_run.sh │ ├── openai_guided_run.sh │ ├── openai_mcts_run.sh │ ├── ppl_mcts_run.sh │ └── train │ ├── sft_t5.sh │ ├── train_t5_token_classifier.sh │ ├── train_t5_token_classifier_cnn.sh │ └── train_token_classifier_gpt.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ 132 | 133 | # pure experimental script 134 | play.ipynb 135 | 136 | # ignore output generations folder 137 | outputs/ 138 | saved_models/ 139 | generations/ 140 | 141 | # wandb logs 142 | wandb/ 143 | logs/ 144 | 145 | fever/ 146 | cached/ 147 | eval_results/ 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Knowledge Constrained Decoding 2 | 3 | Official Code for EMNLP 2023 Paper "KCTS: Knowledge-Constrained Tree Search Decoding with Token-Level Hallucination Detection" (https://arxiv.org/abs/2310.09044). 4 | 5 | ## Environment 6 | 7 | ```bash 8 | pip install -r requirements.txt 9 | pip install -e . 10 | ``` 11 | ## Prepare Data 12 | 13 | 1. First, download WoW dataset through [ParlAI](https://github.com/facebookresearch/ParlAI). 14 | 2. Then, 15 | 16 | ```bash 17 | export WOW_PATH= 18 | sh scripts/shell/data_process/preprocess_wow.sh 20 $WOW_PATH 19 | ``` 20 | 21 | 3. Generate Partial Negative data 22 | 23 | ```bash 24 | bash scripts/shell/data_process/partial_neg_gen.sh 0 wow 16 # for wow 25 | bash scripts/shell/data_process/partial_neg_gen.sh 0 cnn_dailymail 16 # for cnn/dm data 26 | ``` 27 | 28 | 4. Sample Random Negative data (for WoW only) 29 | 30 | ```bash 31 | bash scripts/shell/data_process/random_neg.sh wow 32 | ``` 33 | 34 | 5. Mix the datasets to your liking. 
35 | 36 | ```python 37 | 38 | # merge the partial and random negative datasets 39 | from datasets import concatenate_datasets, load_from_disk 40 | 41 | partial_data_path = 42 | random_data_path = 43 | 44 | partial_data = load_from_disk(partial_data_path) 45 | random_data = load_from_disk(random_data_path) 46 | 47 | merged_dataset = concatenate_datasets([partial_data, random_data]) 48 | merged_dataset = merged_dataset.train_test_split(test_size=0.1) 49 | 50 | merged_dataset.save_to_disk(SAVE_PATH) 51 | ```
52 | ## Train RIPA discriminator 53 | 54 | ```bash 55 | # the numbers are the stdin options of the train script. Details can be found at the top of the script file. 56 | sh scripts/shell/train/train_t5_token_classifier.sh 0 EOS 0 0 0 0 # train f 57 | sh scripts/shell/train/train_t5_token_classifier.sh 0 RIPA 0 0 0 1 # finetune RIPA from f 58 | sh scripts/shell/train/train_t5_token_classifier_cnn.sh 0 RIPA 0 0 0 0 # cnn 59 | ``` 60 |
61 | ## Run Weighted Decoding 62 | 63 | ```bash 64 | sh scripts/shell/guided_run.sh 0 fudge RAND wow 8 0 0 0 '' 65 | sh scripts/shell/guided_run.sh 0 nado ALL wow 8 1 0 0 '' 66 | # KWD 67 | sh scripts/shell/guided_run.sh 0 fudge RIPA wow 8 0 0 0 '' 68 | ``` 69 |
70 | ## Run MCTS (KCTS) 71 | 72 | ```bash 73 | sh scripts/shell/ppl_mcts_run.sh 0 RIPA '' wow 8 0 0 0 0 0 74 | ``` 75 |
76 | ## Guide GPT 3.5 77 | 78 | - Need to train RIPA on GPT2 for this. Check out `scripts/shell/train/train_token_classifier_gpt.sh`. 79 | 80 | ```bash 81 | export EXP_ROOT= 82 | sh scripts/shell/openai_guided_run.sh 0 RIPA 4 $EXP_ROOT 0 0 3 0 0 0 83 | ``` 84 |
85 | ## Evaluation 86 | 87 | We use [UniEval](https://arxiv.org/abs/2210.07197) (Zhong et al., 2022) + [MFMA](https://aclanthology.org/2022.findings-naacl.76.pdf) (Lee et al., 2022, for summarization) + Token-based metrics. 88 | 89 | ```bash 90 | sh scripts/shell/eval/unieval.sh 91 | ``` 92 | 93 | - One can also evaluate the $f$ confidence using the `scripts/shell/eval/class_prob.sh` script. 94 | - Also see `scripts/shell/eval/test_t5_token_classifier.sh` to evaluate the classifier performance. 95 |
-------------------------------------------------------------------------------- /UniEval/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ming Zhong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /UniEval/scorer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM 4 | from tqdm import tqdm 5 | 6 | class UniEvaluator: 7 | def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): 8 | """ Set up model """ 9 | self.device = device 10 | self.max_length = max_length 11 | 12 | self.config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) 13 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir) 14 | self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config, 15 | cache_dir=cache_dir) 16 | 17 | self.model.eval() 18 | self.model.to(device) 19 | 20 | self.softmax = nn.Softmax(dim=1) 21 | 22 | self.pos_id = self.tokenizer("Yes")["input_ids"][0] 23 | self.neg_id = self.tokenizer("No")["input_ids"][0] 24 | 25 | def score(self, inputs, batch_size=8): 26 | """ 27 | Get scores for the given samples. 28 | final_score = postive_score / (postive_score + negative_score) 29 | """ 30 | 31 | # The implementation of "forward" in T5 still requires decoder_input_ids. 32 | # Therefore, we construct a random one-word target sequence. 33 | # The content of the target has no effect on the final scores. 34 | tgts = ["No" for _ in range(len(inputs))] 35 | 36 | pos_score_list, neg_score_list = [], [] 37 | for i in tqdm(range(0, len(inputs), batch_size)): 38 | src_list = inputs[i: i + batch_size] 39 | tgt_list = tgts[i: i + batch_size] 40 | try: 41 | with torch.no_grad(): 42 | encoded_src = self.tokenizer( 43 | src_list, 44 | max_length=self.max_length, 45 | truncation=True, 46 | padding=True, 47 | return_tensors='pt' 48 | ) 49 | encoded_tgt = self.tokenizer( 50 | tgt_list, 51 | max_length=self.max_length, 52 | truncation=True, 53 | padding=True, 54 | return_tensors='pt' 55 | ) 56 | 57 | src_tokens = encoded_src['input_ids'].to(self.device) 58 | src_mask = encoded_src['attention_mask'].to(self.device) 59 | 60 | tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1) 61 | 62 | output = self.model( 63 | input_ids=src_tokens, 64 | attention_mask=src_mask, 65 | labels = tgt_tokens 66 | ) 67 | logits = output.logits.view(-1, self.model.config.vocab_size) 68 | 69 | pos_score = self.softmax(logits)[:, self.pos_id] # Yes 70 | neg_score = self.softmax(logits)[:, self.neg_id] # No 71 | 72 | cur_pos_score = [x.item() for x in pos_score] 73 | cur_neg_score = [x.item() for x in neg_score] 74 | pos_score_list += cur_pos_score 75 | neg_score_list += cur_neg_score 76 | 77 | except RuntimeError: 78 | print(f'source: {src_list}') 79 | print(f'target: {tgt_list}') 80 | exit(0) 81 | 82 | score_list = [] 83 | for i in range(len(pos_score_list)): 84 | score_list.append(pos_score_list[i] / (pos_score_list[i] + neg_score_list[i])) 85 | 86 | return score_list 87 | -------------------------------------------------------------------------------- /UniEval/utils.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | 3 | def convert_to_json(output_list, src_list=None, ref_list=None, context_list=None, \ 4 | scores=None, doc_id=None, system_id=None): 5 | """ 6 | Convert the data into the json format. 7 | 8 | output_list: a list of model output 9 | src_list: source input for different NLG tasks. 
For example, source document for summarization 10 | and dialogue history for dialogue response generation 11 | ref_list: human-annotated groundtruth 12 | context_list: the context needed to evaluate several specific dimension. For example, 13 | additional factual information when evaluating engagingness and groundedness in dialogues 14 | scores: human scores for evaluating the model output. They can be used to calculate the correlation 15 | between evaluators and human judgements. The scores should be stored in a dictionary. For example, 16 | {'fluency': 2.0, 'coherence': 3.0} could be the human score for a sample. 17 | doc_id: the index of the input source. It can be used to calculate summary-level correlation for summarzation 18 | system_id: the index of the generation system. It can be used to calculate system-level correlation. 19 | """ 20 | json_data = [] 21 | for i in range(len(output_list)): 22 | cur = {} 23 | cur['system_output'] = output_list[i] 24 | if src_list is not None: 25 | cur['source'] = src_list[i] 26 | if ref_list is not None: 27 | cur['reference'] = ref_list[i] 28 | if context_list is not None: 29 | cur['context'] = context_list[i] 30 | if scores is not None: 31 | cur['scores'] = scores[i] 32 | if doc_id is not None: 33 | cur['doc_id'] = doc_id[i] 34 | if system_id is not None: 35 | cur['system_id'] = system_id[i] 36 | json_data.append(cur) 37 | return json_data 38 | 39 | 40 | def add_question(dimension, output, src=None, ref=None, context=None, task=None): 41 | """ 42 | Add questions to generate input in Bool-QA format for UniEval. 43 | 44 | dimension: specific dimension to be evaluated 45 | src: source input for different NLG tasks. For example, source document for summarization 46 | and dialogue history for dialogue response generation. 47 | output: output text generated by the models 48 | ref: human-annotataed groundtruth 49 | context: the context needed to evaluate several specific dimension. For example, 50 | additional factual information when evaluating engagingness and groundedness in dialogues. 51 | """ 52 | 53 | input_with_question = [] 54 | for i in range(len(output)): 55 | # For summarization 56 | if task == 'summarization': 57 | if dimension == 'fluency': 58 | cur_input = 'question: Is this a fluent paragraph? paragraph: ' + output[i] 59 | elif dimension == 'coherence': 60 | cur_input = 'question: Is this a coherent summary to the document? summary: ' + output[i] + ' document: ' + src[i] 61 | elif dimension == 'consistency': 62 | cur_input = 'question: Is this claim consistent with the document? claim: ' + output[i] + ' document: ' + src[i] 63 | elif dimension == 'relevance': 64 | cur_input = 'question: Is this summary relevant to the reference? summary: ' + output[i] + ' reference: ' + ref[i] 65 | else: 66 | raise NotImplementedError('The input format for this dimension is still undefined. Please customize it first.') 67 | # For dialogues 68 | elif task == 'dialogue': 69 | if dimension == 'naturalness': 70 | cur_input = 'question: Is this a natural response in the dialogue? response: ' + output[i] 71 | elif dimension == 'coherence': 72 | cur_input = 'question: Is this a coherent response given the dialogue history? response: '\ 73 | + output[i] + ' dialogue history: ' + src[i] 74 | elif dimension == 'engagingness': 75 | cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? 
response: '\ 76 | + output[i] + ' dialogue history: ' + src[i] + ' fact: ' + context[i] 77 | elif dimension == 'groundedness': 78 | cur_input = 'question: Is this response consistent with knowledge in the fact? response: '\ 79 | + output[i] + ' fact: ' + context[i] 80 | elif dimension == 'understandability': 81 | cur_input = 'question: Is this an understandable response in the dialogue? response: ' + output[i] 82 | else: 83 | raise NotImplementedError('The input format for this dimension is still undefined. Please customize it first.') 84 | # For data-to-text 85 | elif task == 'data2text': 86 | if dimension == 'naturalness': 87 | cur_input = 'question: Is this a fluent utterance? utterance: ' + output[i] 88 | elif dimension == 'informativeness': 89 | cur_input = 'question: Is this sentence informative according to the reference? sentence: '\ 90 | + output[i] + ' reference: ' + ref[i] 91 | else: 92 | raise NotImplementedError('The input format for this dimension is still undefined. Please customize it first.') 93 | # For factual consistency detection 94 | elif task == 'fact': 95 | if dimension == 'consistency': 96 | cur_input = 'question: Is this claim consistent with the document? claim: ' + output[i] + ' document: ' + src[i] 97 | else: 98 | raise NotImplementedError('No other dimensions for the factual consistency detection task.') 99 | # For new customized tasks 100 | else: 101 | raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.') 102 | input_with_question.append(cur_input) 103 | return input_with_question 104 | 105 | 106 | def print_scores(scores): 107 | table = PrettyTable(['Dimensions','Score']) 108 | print('\nEvaluation scores are shown below:') 109 | dims = list(scores[0].keys()) 110 | for dim in dims: 111 | cur_score = 0 112 | for i in range(len(scores)): 113 | cur_score += scores[i][dim] 114 | table.add_row([dim, round(cur_score / len(scores), 6)]) 115 | print(table) 116 | -------------------------------------------------------------------------------- /baseline/README.md: -------------------------------------------------------------------------------- 1 | OPENAI API References 2 | 3 | 4 | # Completion API 5 | 6 | * Parameters 7 | * NOTE: GPT3 uses the same tokenizer as GPT2 8 | 9 | ```python 10 | def completion( 11 | model: str, 12 | prompt: str | list[str], 13 | suffix: str = None, 14 | max_tokens: int = 16, 15 | temperature: float = 1, # between (0, 2), 16 | top_p: float = 1, 17 | n: int = 1, # number of completions for each prompt) 18 | logprobs: int = 0, # max 5, 19 | stop: str | list[str] = None, # stop token, <=4, 20 | best_of: int = 1, 21 | logit_bias: dict[str, int] = None, 22 | presence_penalty: float = 0, # between (-2, 2), 23 | frequency_penalty: float = 0, # between (-2, 2), 24 | ) 25 | ``` 26 | 27 | * `presence_penalty: Optional[number] = 0` 28 | * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. 29 | 30 | * `frequency_penalty: Optional[number] = 0` 31 | * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. 
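* Example: a minimal sketch of calling the Completion endpoint with the parameters above, using the legacy (pre-1.0) `openai` Python SDK; the model name, prompt, and stop sequence here are placeholder values, not ones taken from this repo.

```python
import openai  # assumes the legacy (<1.0) SDK, which exposes openai.Completion

# hypothetical example: near-deterministic completion with token log-probabilities
response = openai.Completion.create(
    model="text-davinci-003",   # placeholder model name
    prompt="The capital of France is",
    max_tokens=16,
    temperature=0.0,
    top_p=1.0,
    n=1,
    logprobs=5,                 # return top-5 token log-probs at each position
    stop=["\n"],
    presence_penalty=0.0,
    frequency_penalty=0.0,
)

print(response["choices"][0]["text"])
# log-probs of candidate tokens at the first generated position
print(response["choices"][0]["logprobs"]["top_logprobs"][0])
```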
32 | 33 | 34 | 35 | * Response 36 | 37 | ```json 38 | { 39 | "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", 40 | "object": "text_completion", 41 | "created": 1589478378, 42 | "model": "text-davinci-003", 43 | "choices": [ 44 | { 45 | "text": "\n\nThis is indeed a test", 46 | "index": 0, 47 | "logprobs": null, 48 | "finish_reason": "length" 49 | } 50 | ], 51 | "usage": { 52 | "prompt_tokens": 5, 53 | "completion_tokens": 7, 54 | "total_tokens": 12 55 | } 56 | } 57 | ``` 58 | 59 | # Chat API 60 | 61 | * Parameters 62 | ```python 63 | def completion( 64 | model: str, 65 | messages: list[{ 66 | "role": str = choice[system, user, assistant] 67 | "content": str 68 | "name": str = None 69 | }], 70 | max_tokens: int = 16, 71 | temperature: float = 1, # between (0, 2), 72 | top_p: float = 1, 73 | n: int = 1, # number of completions for each prompt) 74 | stop: str | list[str] = None, # stop token, <=4, 75 | best_of: int = 1, 76 | logit_bias: dict[str, int] = None, 77 | presence_penalty: float = 0, # between (-2, 2), 78 | frequency_penalty: float = 0, # between (-2, 2), 79 | ) 80 | ``` 81 | * Response 82 | ```json 83 | { 84 | "id": "chatcmpl-123", 85 | "object": "chat.completion", 86 | "created": 1677652288, 87 | "choices": [{ 88 | "index": 0, 89 | "message": { 90 | "role": "assistant", 91 | "content": "\n\nHello there, how may I assist you today?", 92 | }, 93 | "finish_reason": "stop" 94 | }], 95 | "usage": { 96 | "prompt_tokens": 9, 97 | "completion_tokens": 12, 98 | "total_tokens": 21 99 | } 100 | } 101 | ``` 102 | -------------------------------------------------------------------------------- /baseline/huggingface_run.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | from pprint import pprint 5 | 6 | import torch 7 | from tqdm.auto import tqdm 8 | from transformers import (DataCollatorForSeq2Seq, HfArgumentParser, Seq2SeqTrainer, 9 | Seq2SeqTrainingArguments) 10 | from peft import LoraConfig, get_peft_model 11 | 12 | from kcd.text_data import load_text_data 13 | from kcd.util import load_transformer_LM_tokenizer 14 | from kcd.configs import GenerationConfig 15 | 16 | 17 | @dataclass 18 | class ExperimentArgs: 19 | data_path: str = 'data/wow-dev-kilt-processed.jsonl' 20 | output_path: str = 'generations/baseline' 21 | model_name: str = "google/flan-t5-xl" 22 | dataset: str = 'wow' 23 | use_kilt_format: bool = True 24 | task: str = 'completion' 25 | continue_from: int = 0 26 | batch_size: int = 1 27 | load_8bit: bool = True 28 | load_checkpoint: str = '' 29 | print_output: bool = False 30 | instruction_model: str = 'basic' # choices=['basic', 'openai', 'alpaca'] 31 | 32 | 33 | def main(): 34 | parser = HfArgumentParser((ExperimentArgs, GenerationConfig, Seq2SeqTrainingArguments)) 35 | args, gen_cfg, train_args = parser.parse_args_into_dataclasses() 36 | args.output_path = train_args.output_dir 37 | 38 | load_kwargs = { 39 | 'device_map': 'auto' if args.load_8bit else None, 40 | 'load_in_8bit': args.load_8bit, 41 | 'torch_dtype': torch.float16 if args.load_8bit else torch.bfloat16, 42 | } 43 | 44 | model, tokenizer = load_transformer_LM_tokenizer(args.model_name, **load_kwargs) 45 | tokenizer.truncation_side = 'left' 46 | 47 | if args.load_checkpoint: 48 | peft_config_path = os.path.join(os.path.dirname(args.load_checkpoint), 'adapter_model') 49 | peft_config = LoraConfig.from_pretrained(peft_config_path) 50 | model = get_peft_model(model, peft_config) 51 | 52 | ckpt = torch.load(args.load_checkpoint) 
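# The SFT checkpoint stores the LM head weight under 'base_model.model.lm_head.0.weight'
# (presumably because the head was wrapped in a container module during training);
# remap it to the key the PEFT-wrapped model expects before calling load_state_dict.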
53 | ckpt['base_model.model.lm_head.weight'] = ckpt.pop("base_model.model.lm_head.0.weight") 54 | model.load_state_dict(ckpt, strict=True) 55 | 56 | tokenized_dataset = load_text_data(path=args.data_path, 57 | instruction_model=args.instruction_model, 58 | task=args.task, 59 | use_kilt_format=args.use_kilt_format, 60 | tokenize=True, 61 | tokenizer=tokenizer, 62 | no_label=True) 63 | text_dataset = load_text_data(path=args.data_path, 64 | instruction_model=args.instruction_model, 65 | task=args.task, 66 | use_kilt_format=args.use_kilt_format, 67 | tokenize=False) 68 | 69 | trainer = Seq2SeqTrainer( 70 | model=model, 71 | args=train_args, 72 | data_collator=DataCollatorForSeq2Seq(tokenizer, model=model), 73 | tokenizer=tokenizer, 74 | ) 75 | preds = trainer.predict(tokenized_dataset, **gen_cfg.__dict__) 76 | preds.predictions[preds.predictions == -100] = tokenizer.pad_token_id 77 | responses = tokenizer.batch_decode(preds.predictions, skip_special_tokens=True) 78 | 79 | os.makedirs(args.output_path, exist_ok=True) 80 | dataset_name = args.dataset 81 | if args.use_kilt_format: 82 | dataset_name += '-kilt' 83 | out_fname = f'{dataset_name}-{args.model_name.replace("/", "-")}' 84 | if args.load_checkpoint: 85 | out_fname = out_fname + '-sft' 86 | out_fname = os.path.join(args.output_path, out_fname + '.jsonl') 87 | fout = open(out_fname, 'a') 88 | for i, (example, response) in tqdm(enumerate(zip(text_dataset, responses)), 89 | total=len(responses)): 90 | result = { 91 | 'prompt': example['prompt'], 92 | 'response': response, 93 | 'gold': example['completion'], 94 | 'index': i, 95 | } 96 | fout.write(json.dumps(result) + '\n') 97 | if args.print_output: 98 | print("prompt:\n") 99 | pprint(example['prompt']) 100 | print() 101 | print(f"{args.model_name} Response:\n") 102 | print(response + '\n\n') 103 | print(f"Gold Response:\n") 104 | print(example['completion'] + '\n\n') 105 | fout.close() 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /baseline/openai_run.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import json 3 | import os 4 | import time 5 | from termcolor import colored 6 | 7 | from transformers import HfArgumentParser 8 | from tqdm.auto import tqdm 9 | 10 | from kcd.openai_module import OpenAIModel, OpenAIAPIParameters 11 | from kcd.text_data import load_text_data 12 | 13 | 14 | @dataclass 15 | class ExperimentArgs: 16 | data_path: str = 'data/wow-dev-kilt-processed.jsonl' 17 | output_path: str = 'generations/baseline' 18 | model_name: str = "gpt-3.5-turbo" 19 | dataset: str = 'wow' 20 | use_kilt_format: bool = True 21 | task: str = 'chat' 22 | continue_from: int = 0 23 | debug: bool = False 24 | skip_no_knowledge: bool = False 25 | instruction_model: str = 'basic' 26 | human_indices: str = None 27 | 28 | 29 | def main(): 30 | parser = HfArgumentParser((ExperimentArgs, OpenAIAPIParameters)) 31 | args, parameters = parser.parse_args_into_dataclasses() 32 | 33 | text_dataset = load_text_data(path=args.data_path, 34 | task=args.task, 35 | instruction_model=args.instruction_model, 36 | use_kilt_format=args.use_kilt_format) 37 | if args.human_indices: 38 | with open(args.human_indices) as f: 39 | indices = [int(i) for i in f.readlines()] 40 | text_dataset = text_dataset.select(indices) 41 | model = OpenAIModel(model_name=args.model_name, task=args.task) 42 | 43 | os.makedirs(args.output_path, exist_ok=True) 44 | 
dataset_name = args.dataset 45 | if args.use_kilt_format: 46 | dataset_name += '-kilt' 47 | out_fname = os.path.join(args.output_path, f'{dataset_name}-openai_{args.model_name}') 48 | if args.human_indices: 49 | out_fname += '_human' 50 | out_fname += '.jsonl' 51 | fout = open(out_fname, 'a') 52 | for i, example in tqdm(enumerate(text_dataset), total=len(text_dataset)): 53 | if i < args.continue_from: 54 | continue 55 | if args.debug: 56 | full_response = {'dummy': 'dummy'} 57 | completion = 'dummy' 58 | elif args.skip_no_knowledge and "no_passages_used" in example['prompt']: 59 | full_response = {'choices': [{'text': 'skipped'}]} 60 | completion = 'skipped due to no knowledge' 61 | else: 62 | completion = None 63 | try_count = 0 64 | start = time.time() 65 | while completion is None: 66 | if try_count > 5: 67 | print( 68 | f"Stop trying after {try_count} tries and {time.time() - start:.2f} seconds." 69 | ) 70 | print(f"You can resume by setting --continue_from={i}") 71 | exit(1) 72 | try: 73 | try_count += 1 74 | completion, full_response = model(example['prompt'], parameters) 75 | except: 76 | print("OpenAI Rate Limit reached. Sleeping for 5 minutes.") 77 | time.sleep(300) 78 | if try_count > 0: 79 | print( 80 | f"exited while loop after {try_count} tries and {time.time() - start:.2f} seconds" 81 | ) 82 | result = { 83 | 'prompt': example['prompt'], 84 | 'response': full_response, 85 | 'gold': example['completion'], 86 | 'index': i, 87 | } 88 | fout.write(json.dumps(result) + '\n') 89 | print("prompt:\n") 90 | if args.task == 'chat': 91 | for msg in example["prompt"]: 92 | if msg['role'] == 'user': 93 | print(colored('User: ' + msg['content'], 'green')) 94 | elif msg['role'] == 'assistant': 95 | print(colored('Assistant: ' + msg['content'], 'blue')) 96 | else: 97 | print(colored('System: ' + msg['content'], 'red')) 98 | else: 99 | print(example["prompt"]) 100 | print() 101 | print(f"{args.model_name} Response:\n") 102 | print(completion + '\n\n') 103 | print(f"Gold Response:\n") 104 | print(example["completion"] + '\n\n') 105 | fout.close() 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /baseline/test_prompt_baseline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | 5 | import pandas as pd 6 | import torch 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM 8 | import fire 9 | 10 | from kcd.util import get_logger 11 | 12 | 13 | def experiment(model, 14 | tokenizer, 15 | idx, 16 | evidence, 17 | original_claim, 18 | logger, 19 | device='cpu', 20 | prompt_len=3, 21 | max_new_tokens=40): 22 | ##################### input prep ################### 23 | if not isinstance(evidence, list): 24 | evidence = [evidence] 25 | original_claim = [original_claim] 26 | idx = [idx] 27 | 28 | basic_gen_inst = [] 29 | evidence_inst = [] 30 | zero_shot_inst = [] 31 | zero_shot_prompted_inst = [] 32 | prompts = [] 33 | for evid, claim in zip(evidence, original_claim): 34 | prompt = ' '.join(claim.split(' ')[:prompt_len]) 35 | 36 | basic_gen_inst.append(f"Complete the following sentence: {prompt}") 37 | evidence_inst.append(f"""Complete the following sentence: {evid} {prompt}""") 38 | zero_shot_inst.append(f"""Generate a claim that is supported by the evidence below. 
39 | evidence: {evid} 40 | claim:""") 41 | zero_shot_prompted_inst.append(f"""Generate a claim that is supported by the evidence below. 42 | evidence: {evid} 43 | claim: {prompt}""") 44 | prompts.append(prompt) 45 | 46 | gen_inst_ids = tokenizer(basic_gen_inst, return_tensors='pt', padding=True).to(device) 47 | evidence_inst_ids = tokenizer(evidence_inst, return_tensors='pt', padding=True).to(device) 48 | zero_shot_inst_ids = tokenizer(zero_shot_inst, return_tensors='pt', padding=True).to(device) 49 | zero_shot_prompted_inst_ids = tokenizer(zero_shot_prompted_inst, 50 | return_tensors='pt', 51 | padding=True).to(device) 52 | 53 | all_instructions = { 54 | 'completion': gen_inst_ids, 55 | '+ evidence': evidence_inst_ids, 56 | '+ zero-shot-inst': zero_shot_inst_ids, 57 | '+ zero-shot+prompted': zero_shot_prompted_inst_ids, 58 | } 59 | 60 | ################ Generation ######################## 61 | def _generate(inputs, top_p=0.95, temperature=0.8, num_beams=8): 62 | greedy = tokenizer.batch_decode(model.generate(**inputs, max_new_tokens=max_new_tokens), 63 | skip_special_tokens=False) 64 | topp = tokenizer.batch_decode(model.generate(**inputs, 65 | max_new_tokens=max_new_tokens, 66 | do_sample=True, 67 | top_p=top_p, 68 | temperature=temperature), 69 | skip_special_tokens=False) 70 | beam = tokenizer.batch_decode(model.generate(**inputs, 71 | num_beams=num_beams, 72 | max_new_tokens=max_new_tokens), 73 | skip_special_tokens=False) 74 | 75 | return greedy, topp, beam 76 | 77 | completions = {} 78 | for key, inst in all_instructions.items(): 79 | completions[key] = _generate(inst) 80 | 81 | data = [] 82 | for i in range(len(idx)): 83 | _data = { 84 | 'data_idx': idx[i], 85 | 'evidence': evidence[i], 86 | 'prompt': prompts[i], 87 | 'original_claim': original_claim[i], 88 | 'results': {}, 89 | } 90 | 91 | logger.info('#' * 40 + 'fever test set index %d ' + '#' * 40, idx[i]) 92 | logger.info('evidence: %s', evidence[i]) 93 | logger.info('prompt: %s', prompts[i]) 94 | logger.info('original claim: %s', original_claim[i]) 95 | 96 | for key, result in completions.items(): 97 | _data['results'][key] = {} 98 | _data['results'][key]['greedy'] = result[0][i] 99 | _data['results'][key]['top_p'] = result[1][i] 100 | _data['results'][key]['beam'] = result[2][i] 101 | 102 | logger.info( 103 | """[%s] 104 | \t[greedy] 105 | \t%s 106 | \t[top_p=0.95, temp=0.8] 107 | \t%s 108 | \t[beam=8] 109 | \t%s""", key, result[0][i], result[1][i], result[2][i]) 110 | 111 | data.append(_data) 112 | return data 113 | 114 | 115 | def main( 116 | pretrained_model='google/flan-t5-xl', 117 | fever_data_path='data/fever/paper_test.jsonl', 118 | outfname='outputs/fever_prompt_baseline.jsonl', 119 | prompt_len=1, 120 | max_new_tokens=20, 121 | batch_size=16, 122 | end_idx=1000, 123 | ): 124 | ############## logging #################### 125 | logging.basicConfig(level=logging.INFO) 126 | logger = get_logger('logs/prompt_baseline_test.log') 127 | ########################################### 128 | 129 | wiki_added_fever_path = os.path.splitext(fever_data_path)[0] + '+wiki.jsonl' 130 | fever = pd.read_json(wiki_added_fever_path, lines=True) 131 | 132 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 133 | #################### MODEL LOADING ######## 134 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model) 135 | 136 | if 't5' in pretrained_model or 't0' in pretrained_model: 137 | model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, 138 | device_map='balanced_low_0', 139 | load_in_8bit=True) 140 | else: 141 | 
model = AutoModelForCausalLM.from_pretrained(pretrained_model) 142 | # open-ended generation 143 | tokenizer.pad_token = tokenizer.eos_token 144 | model.config.pad_token_id = model.config.eos_token_id 145 | model.config.bos_token_id = model.config.eos_token_id 146 | # model = model.to(device) 147 | model.eval() 148 | print('model loading finished') 149 | ############################################ 150 | outfile = open(outfname, 'w') 151 | indices = [] 152 | evidences = [] 153 | claims = [] 154 | for idx, df in fever.iterrows(): 155 | if len(df['wiki_extracted']) == 0: 156 | continue 157 | evidence = df['wiki_extracted'][0] # first evidence 158 | evidence = evidence.replace('-LRB-', '(') 159 | evidence = evidence.replace('-RRB-', ')') 160 | claim = df['claim'] 161 | 162 | indices.append(idx) 163 | evidences.append(evidence) 164 | claims.append(claim) 165 | 166 | if (idx + 1) % batch_size == 0: 167 | batch_result = experiment(model, 168 | tokenizer, 169 | indices, 170 | evidences, 171 | claims, 172 | logger, 173 | device=device, 174 | prompt_len=prompt_len, 175 | max_new_tokens=max_new_tokens) 176 | for res in batch_result: 177 | json.dump(res, outfile) 178 | outfile.write('\n') 179 | indices = [] 180 | evidences = [] 181 | claims = [] 182 | if idx > end_idx: 183 | break 184 | outfile.close() 185 | 186 | 187 | if __name__ == '__main__': 188 | fire.Fire(main) 189 | -------------------------------------------------------------------------------- /human_eval/sample_indices_for_human.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import os 4 | 5 | import fire 6 | import pandas as pd 7 | 8 | 9 | def sample_indices_for_human(task, max_index, size=100): 10 | """Sample indices for human data.""" 11 | if task == 'wow': 12 | df = pd.read_json(f'generations/baseline/{task}-openai_gpt-3.5-turbo.jsonl', lines=True) 13 | good_indices = [] 14 | for i, prompt in enumerate(df['prompt']): 15 | for msg in prompt: 16 | if msg['role'] == 'system': 17 | knowledge = msg['content'].replace("Use the following knowledge," 18 | " but not directly copy, to" 19 | " generate a concise response: ", "") 20 | if knowledge.strip() == '' or 'no_passages_used' in knowledge: 21 | continue 22 | good_indices.append(i) 23 | else: 24 | good_indices = list(range(max_index)) 25 | random.shuffle(good_indices) 26 | indices = good_indices[:size] 27 | return indices 28 | 29 | def read_generations_from_index_file(generation_file, indices): 30 | """Read generations from index file.""" 31 | 32 | generations = pd.read_json(generation_file, lines=True) 33 | return generations.iloc[indices] 34 | 35 | def main(task='wow', size=100, seed=1234, do_sample=False, do_read=False): 36 | """Main function.""" 37 | max_indices = { 38 | 'wow': 3924, 39 | 'cnn_dailymail': 1780, 40 | } 41 | random.seed(seed) 42 | if do_sample: 43 | indices = sample_indices_for_human(task, max_indices[task], size=size) 44 | 45 | with open(f'generations/{task}_human_indices.txt', 'w') as f: 46 | for idx in indices: 47 | f.write(str(idx) + '\n') 48 | 49 | target_files = [ 50 | # 'generations/baseline/wow-openai_gpt-3.5-turbo.jsonl', 51 | # 'generations/fudge/wow-google-flan-t5-xl-fudge-DecoderDisc-wow-RAND.jsonl', 52 | # 'generations/ppl_mcts/wow-google-flan-t5-xl-DecoderDisc-wow-PARTIAL.jsonl', 53 | # 'generations/ppl_mcts/wow-google-flan-t5-xl-DecoderDisc-wow-EOS.jsonl', 54 | # 'generations/baseline/cnn_dailymail-openai_gpt-3.5-turbo.jsonl', 55 | # 
'generations/baseline/cnn_dailymail-google-flan-t5-xl.jsonl', 56 | # 'generations/ppl_mcts/cnn_dailymail-google-flan-t5-xl-token_f1.jsonl', 57 | # 'generations/ppl_mcts/cnn_dailymail-google-flan-t5-xl-DecoderDisc-cnn_dailymail-EOS-only_mlp.jsonl', 58 | 'generations/baseline/wow-google-flan-t5-xl.jsonl', 59 | 'generations/ppl_mcts/wow-google-flan-t5-xl-DecoderDisc-wow-EOS-PARTIAL.jsonl', 60 | 'generations/fudge/wow-google-flan-t5-xl-fudge-DecoderDisc-wow-EOS-PARTIAL.jsonl' 61 | ] 62 | 63 | if do_read: 64 | indices = [] 65 | with open(f'generations/{task}_human_indices.txt', 'r') as f: 66 | for line in f: 67 | idx = int(line.strip()) 68 | indices.append(idx) 69 | for generation_file in target_files: 70 | if not os.path.exists(generation_file): 71 | print(f'{generation_file} does not exist. skipping...') 72 | continue 73 | else: 74 | print(f'{generation_file}') 75 | generations = read_generations_from_index_file(generation_file, indices) 76 | out_fname = os.path.splitext(generation_file)[0] 77 | generations.to_json(f'{out_fname}_human.jsonl', orient='records', lines=True) 78 | 79 | if __name__ == '__main__': 80 | fire.Fire(main) 81 | -------------------------------------------------------------------------------- /kcd/attribute_classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-KnowComp/Knowledge-Constrained-Decoding/707f4de017c12ec6145b08249362e247ba8aa486/kcd/attribute_classifier/__init__.py -------------------------------------------------------------------------------- /kcd/attribute_classifier/attribute_dataloader.py: -------------------------------------------------------------------------------- 1 | from statistics import mean, median 2 | from datasets import load_dataset 3 | from torch.utils.data import DataLoader, Dataset 4 | from transformers.tokenization_utils_base import BatchEncoding 5 | 6 | 7 | def get_attribute_dataloader(dataname_or_data, 8 | tokenizer, 9 | max_length: int = 256, 10 | batch_size: int = 32, 11 | split: str = 'test', 12 | num_workers: int = 0): 13 | dataset = AttributeDataset(dataname_or_data, tokenizer, max_length=max_length, split=split) 14 | return DataLoader( 15 | dataset, 16 | batch_size=batch_size, 17 | shuffle=split == 'train', 18 | num_workers=num_workers, 19 | ) 20 | 21 | 22 | class AttributeDataset(Dataset): 23 | 24 | def __init__(self, 25 | dataname_or_data: str, 26 | tokenizer, 27 | max_length: int = 256, 28 | split: str = 'test', 29 | show_stats: bool = False) -> None: 30 | if isinstance(dataname_or_data, str): 31 | data = load_dataset(dataname_or_data) 32 | else: 33 | data = dataname_or_data 34 | data = data[split] 35 | 36 | self.labels = data['label'] 37 | self.texts = tokenizer(data['sentence'], 38 | return_tensors='pt', 39 | padding='max_length', 40 | truncation=True, 41 | max_length=max_length) 42 | 43 | if show_stats: 44 | print(f'[split]: {split}') 45 | lengths = [len(tokenizer.tokenize(x)) for x in data['sentence']] 46 | print('text length stats:') 47 | print( 48 | f'max: {max(lengths)}, mean: {mean(lengths)}, min: {min(lengths)}, median: {median(lengths)}' 49 | ) 50 | 51 | def __len__(self): 52 | return len(self.labels) 53 | 54 | def __getitem__(self, idx): 55 | data = {k: v[idx] for k, v in self.texts.items()} 56 | data = BatchEncoding(data, tensor_type='pt') 57 | label = self.labels[idx] 58 | 59 | data['labels'] = label 60 | 61 | return data -------------------------------------------------------------------------------- 
/kcd/attribute_classifier/evaluate_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.no_grad() 5 | def evaluate(model, dataloader, loss_fn, args, device='cpu'): 6 | """ Evaluates a given model and dataset. 7 | 8 | obtained from: 9 | https://github.com/fabio-deep/Distributed-Pytorch-Boilerplate/blob/master/src/evaluate.py 10 | """ 11 | model.eval() 12 | sample_count = 0 13 | running_loss = 0 14 | running_acc = 0 15 | 16 | for inputs in dataloader: 17 | inputs = inputs.to(device) 18 | labels = inputs.pop('labels') 19 | 20 | yhat = model(**inputs, pool_method=args.pool_method).logits 21 | loss = loss_fn(yhat, labels) 22 | 23 | sample_count += labels.size(0) 24 | running_loss += loss.item() * labels.size(0) # smaller batches count less 25 | running_acc += (yhat.argmax(-1) == labels).sum().item() # num corrects 26 | 27 | loss = running_loss / sample_count 28 | acc = running_acc / sample_count 29 | 30 | return loss, acc -------------------------------------------------------------------------------- /kcd/classifier_guidance/__init__.py: -------------------------------------------------------------------------------- 1 | from .guided_generation_predictor import GuidedGenerationPredictor 2 | from .metric_guidance import metric_guided_generation 3 | from .fudge_decode import fudge_generation 4 | from .openai_fudge_decode import openai_fudge_generation 5 | from .nado_decode import nado_generation 6 | from .astar_decode import astar_generation 7 | 8 | GENERATE_FN_REGISTRY = { 9 | 'metric_guidance': metric_guided_generation, 10 | 'fudge': fudge_generation, 11 | 'nado': nado_generation, 12 | 'astar': astar_generation, 13 | 'openai_fudge': openai_fudge_generation 14 | } 15 | 16 | 17 | def load_generate_fn(name, **kwargs): 18 | return GENERATE_FN_REGISTRY[name](**kwargs) 19 | -------------------------------------------------------------------------------- /kcd/classifier_guidance/astar_decode.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | import torch 5 | from transformers import top_k_top_p_filtering 6 | from tqdm.auto import tqdm 7 | 8 | from .fudge_decode import KnowledgeDiscriminator 9 | from .utils import zero_out_after_eos 10 | 11 | 12 | def astar_generation( 13 | model=None, 14 | max_new_tokens: int = 32, 15 | k: int = 200, 16 | future_steps: int = 5, 17 | lambda_weight: float = 0.25, 18 | disable_adapter_lm_forward=False, 19 | soft_forward=False, 20 | ): 21 | assert model is not None 22 | discriminator = KnowledgeDiscriminator(model) 23 | 24 | def _generate_fn(model, tokenizer, inputs): 25 | """input_ids: torch.LongTensor, knowledge_text: str""" 26 | input_ids = inputs['input_ids'] 27 | attention_mask = inputs['attention_mask'] 28 | if model.config.is_encoder_decoder: 29 | encoder_outputs = model.get_encoder()(input_ids=input_ids, 30 | attention_mask=attention_mask) 31 | else: 32 | encoder_outputs = None 33 | past_key_values = None 34 | generated_tokens = torch.full((input_ids.shape[0], 1), 35 | model.config.decoder_start_token_id, 36 | dtype=torch.long, 37 | device=input_ids.device) 38 | for i in tqdm(range(max_new_tokens), total=max_new_tokens): 39 | next_logit, past_key_values = astar_step( 40 | model, 41 | discriminator, 42 | generated_tokens, 43 | gen_inst_ids=input_ids, 44 | attention_mask=attention_mask, 45 | encoder_outputs=encoder_outputs, 46 | past_key_values=past_key_values, 47 | top_k=k, 48 | future_steps=future_steps, 49 | 
lambda_weight=lambda_weight, 50 | disable_adapter_lm_forward=disable_adapter_lm_forward, 51 | soft_forward=soft_forward) 52 | generated_tokens = torch.cat([generated_tokens, next_logit], dim=1) 53 | # early stopping based on eos 54 | if (generated_tokens == tokenizer.eos_token_id).any(dim=1).all(): 55 | break 56 | return zero_out_after_eos(generated_tokens, tokenizer.eos_token_id) 57 | 58 | return _generate_fn 59 | 60 | 61 | @torch.inference_mode() 62 | def astar_step(model, 63 | discriminator, 64 | decoded_ids: torch.LongTensor, 65 | gen_inst_ids: torch.LongTensor = None, 66 | attention_mask: torch.LongTensor = None, 67 | encoder_outputs=None, 68 | past_key_values=None, 69 | top_k: int = 200, 70 | top_p: float = 1.0, 71 | temperature: float = 1.0, 72 | future_steps: int = 5, 73 | lambda_weight: float = 0.25, 74 | disable_adapter_lm_forward: bool = False, 75 | soft_forward: bool = False): 76 | assert model.config.is_encoder_decoder 77 | # prepare input_ids for encoder_decoder or decoder-only models 78 | if model.config.is_encoder_decoder: 79 | if gen_inst_ids is None: 80 | raise ValueError("gen_inst_ids must be set.") 81 | input_ids = decoded_ids 82 | else: 83 | input_ids = decoded_ids 84 | if gen_inst_ids is not None: 85 | input_ids = torch.cat([gen_inst_ids, input_ids], dim=1) 86 | 87 | model_inputs = prepare_inputs_for_generation( 88 | model, 89 | encoder_outputs=encoder_outputs, 90 | input_ids=input_ids, 91 | attention_mask=attention_mask, 92 | past_key_values=past_key_values, 93 | use_cache=True, 94 | ) 95 | # first forward: get t+1 topk logits 96 | outputs = model(**model_inputs, 97 | output_attentions=False, 98 | output_hidden_states=False) 99 | if isinstance(outputs, tuple) and len(outputs) == 2: 100 | outputs, _ = outputs # ignore token classifier output 101 | past_key_values = outputs.past_key_values 102 | logits = outputs.logits 103 | next_logit = torch.softmax(logits[:, -1, :], dim=-1) # [B, V] 104 | values, indices = next_logit.topk(top_k, dim=-1) # [B, k] 105 | topk_logits = indices.T # [k, B] 106 | 107 | # future "rollout" for each topk token 108 | next_key_values = past_key_values 109 | all_probs = [] 110 | for logit in topk_logits: 111 | future_token = logit.unsqueeze(-1) 112 | future_tokens = [future_token] 113 | for i in range(future_steps): 114 | if soft_forward: 115 | raise NotImplementedError 116 | # do forward 117 | else: 118 | if disable_adapter_lm_forward: 119 | with model.disable_adapter(): 120 | outputs, _ = model(encoder_outputs=encoder_outputs, 121 | attention_mask=attention_mask, 122 | past_key_values=next_key_values, 123 | decoder_input_ids=future_token, 124 | use_cache=True) 125 | else: 126 | outputs, _ = model(encoder_outputs=encoder_outputs, 127 | attention_mask=attention_mask, 128 | past_key_values=next_key_values, 129 | decoder_input_ids=future_token, 130 | use_cache=True) 131 | future_logits = outputs.logits[:, -1, :] / temperature 132 | filtered_logits = top_k_top_p_filtering(future_logits, 133 | top_p=top_p, 134 | top_k=top_k) 135 | probs = torch.softmax(filtered_logits, dim=-1) 136 | future_token = torch.multinomial(probs, num_samples=1) # [B, 1] 137 | next_key_values = outputs.past_key_values 138 | future_tokens.append(future_token) 139 | future_tokens = torch.cat(future_tokens, dim=1) # [B, future_steps + 1] 140 | class_prob = discriminator(torch.cat([decoded_ids, future_tokens], dim=1), 141 | future_token.T, # [1, B] 142 | gen_inst_ids=gen_inst_ids, 143 | attention_mask=attention_mask, 144 | encoder_outputs=encoder_outputs) # [B,] 145 | 
all_probs.append(class_prob) 146 | all_probs = torch.stack(all_probs, dim=1) # [B, k] 147 | astar_topk = values + lambda_weight * all_probs # [B, k] 148 | max_idx = astar_topk.argmax(dim=-1) # [B,] 149 | max_logit = indices[range(indices.shape[0]), max_idx].unsqueeze(-1) # [B, 1] 150 | 151 | return max_logit, past_key_values 152 | 153 | 154 | def prepare_inputs_for_generation(model, encoder_outputs=None, **kwargs): 155 | if model.config.is_encoder_decoder: 156 | model_inputs = model.prepare_inputs_for_generation(encoder_outputs=encoder_outputs, 157 | **kwargs) 158 | else: 159 | model_inputs = model.prepare_inputs_for_generation(**kwargs) 160 | return model_inputs 161 | -------------------------------------------------------------------------------- /kcd/classifier_guidance/guided_generation_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from transformers import Seq2SeqTrainer 6 | from transformers.deepspeed import is_deepspeed_zero3_enabled 7 | 8 | 9 | class GuidedGenerationPredictor(Seq2SeqTrainer): 10 | 11 | def __init__(self, generate_fn: Callable = None, **kwargs): 12 | super().__init__(**kwargs) 13 | self.generate_fn = generate_fn 14 | 15 | def prediction_step( 16 | self, 17 | model: nn.Module, 18 | inputs: Dict[str, torch.Tensor | Any], 19 | prediction_loss_only: bool, 20 | ignore_keys: List[str] | None = None 21 | ) -> Tuple[float | None, torch.Tensor | None, torch.Tensor | None]: 22 | """ 23 | Copied from Seq2SeqTrainer.prediction_step, but with the following changes: 24 | - use `metric_guidance` instaed of model.generate 25 | """ 26 | if not self.args.predict_with_generate or prediction_loss_only: 27 | return super().prediction_step(model, 28 | inputs, 29 | prediction_loss_only=prediction_loss_only, 30 | ignore_keys=ignore_keys) 31 | 32 | has_labels = "labels" in inputs 33 | inputs = self._prepare_inputs(inputs) 34 | 35 | # XXX: adapt synced_gpus for fairscale as well 36 | # Priority (handled in generate): 37 | # gen_kwargs > model.generation_config > default GenerationConfig() 38 | gen_kwargs = self._gen_kwargs.copy() 39 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 40 | gen_kwargs["max_length"] = self.model.config.max_length 41 | gen_kwargs["num_beams"] = (gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") 42 | is not None else self.model.config.num_beams) 43 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 44 | gen_kwargs["synced_gpus"] = (gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") 45 | is not None else default_synced_gpus) 46 | 47 | # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate 48 | # (otherwise, it would continue generating from the padded `decoder_input_ids`) 49 | if ("labels" in inputs and "decoder_input_ids" in inputs and 50 | inputs["labels"].shape == inputs["decoder_input_ids"].shape): 51 | inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"} 52 | 53 | ############## NOTE: changes starts here ############################### 54 | generated_tokens = self.generate_fn(model, self.tokenizer, inputs) 55 | return None, generated_tokens, None 56 | -------------------------------------------------------------------------------- /kcd/classifier_guidance/metric_guidance.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Instead of classifiers, use metrics to score and re-rank tokens. 3 | """ 4 | import torch 5 | 6 | from .utils import zero_out_after_eos 7 | 8 | class MetricGuidance: 9 | """ 10 | Metric guidance class to be used with PplMCTS. 11 | """ 12 | 13 | def __init__(self, 14 | tokenizer, 15 | metric, 16 | metric_name=None, 17 | max_new_tokens: int = 32, 18 | k: int = 200) -> None: 19 | self.tokenizer = tokenizer 20 | self.metric = metric 21 | self.metric_name = metric_name 22 | self.max_new_tokens = max_new_tokens 23 | self.k = k 24 | 25 | def __call__(self, reference_ids, decoded_ids): 26 | """ 27 | reference_ids: torch.LongTensor of shape [B, T_1] 28 | decoder_ids: torch.LongTensor of shape [B, T_2] 29 | """ 30 | reference = self.tokenizer.batch_decode(reference_ids, skip_special_tokens=True) 31 | decoded = self.tokenizer.batch_decode(decoded_ids, skip_special_tokens=True) 32 | scores = self.metric(decoded, [[ref] for ref in reference]) 33 | if isinstance(scores, dict): 34 | assert self.metric_name is not None 35 | scores = scores[self.metric_name] 36 | if not isinstance(scores, torch.Tensor): 37 | scores = torch.tensor(scores) # [B,] 38 | return scores 39 | 40 | 41 | def metric_guided_generation(metric=None, metric_name=None, max_new_tokens: int = 32, k: int = 200): 42 | assert metric is not None 43 | 44 | def _generate_fn(model, tokenizer, inputs): 45 | """input_ids: torch.LongTensor, knowledge_text: str""" 46 | input_ids = inputs['input_ids'] 47 | knowledge_text = tokenizer.batch_decode(inputs['knowledge_ids'], skip_special_tokens=True) 48 | 49 | generated_tokens = torch.LongTensor([[model.config.decoder_start_token_id] 50 | ]).to(input_ids.device) 51 | generated_tokens = generated_tokens.repeat(input_ids.shape[0], 1) 52 | for i in range(max_new_tokens): 53 | next_logit = metric_guidance_step(model, 54 | tokenizer, 55 | metric, 56 | generated_tokens, 57 | knowledge_text, 58 | gen_inst_ids=input_ids, 59 | k=k, 60 | metric_name=metric_name) 61 | generated_tokens = torch.cat([generated_tokens, next_logit], dim=1) 62 | # early stopping based on eos 63 | if (generated_tokens == tokenizer.eos_token_id).any(dim=1).all(): 64 | break 65 | return zero_out_after_eos(generated_tokens, tokenizer.eos_token_id) 66 | 67 | return _generate_fn 68 | 69 | 70 | # the metric can be TokenF1Score.batch_compute 71 | @torch.inference_mode() 72 | def metric_guidance_step(model, 73 | tokenizer, 74 | metric, 75 | decoded_ids: torch.LongTensor, 76 | knowledge_text: list[str], 77 | gen_inst_ids: torch.LongTensor = None, 78 | k: int = 200, 79 | metric_name=None): 80 | if model.config.is_encoder_decoder: 81 | assert gen_inst_ids is not None 82 | input_ids = gen_inst_ids 83 | logits = model(input_ids=input_ids, decoder_input_ids=decoded_ids).logits # [1, T, V] 84 | else: 85 | input_ids = decoded_ids 86 | if gen_inst_ids is not None: 87 | input_ids = torch.cat([gen_inst_ids, input_ids], dim=1) 88 | else: 89 | input_ids = decoded_ids 90 | logits = model(input_ids=input_ids).logits 91 | next_logit = torch.softmax(logits[:, -1, :], dim=-1) # [B, V] 92 | values, indices = next_logit.topk(k, dim=-1) # [B, k] 93 | topk_logits = indices.T # [k, B] 94 | 95 | candidates = [] 96 | for idx in topk_logits: 97 | curr = torch.cat([decoded_ids[:, 1:], idx.unsqueeze(1)], dim=1) 98 | candidates.append(curr) 99 | 100 | candidates = torch.stack(candidates, dim=0) # [K, B, T] 101 | candidates = candidates.view(-1, candidates.shape[-1]) # [K * B, T] 102 | candidate_str = tokenizer.batch_decode(candidates, skip_special_tokens=True) # [K * B,] 103 | 
104 | scores = metric(candidate_str, [[kt] for kt in knowledge_text] * k) 105 | if isinstance(scores, dict): 106 | scores = scores[metric_name] 107 | if not isinstance(scores, torch.Tensor): 108 | scores = torch.tensor(scores).to(values.device) # [K * B,] 109 | 110 | reranked = values * scores.view(k, -1).T # [B, K] 111 | max_idx = reranked.argmax(dim=-1) # [B,] 112 | max_logit = indices[range(max_idx.shape[0]), max_idx].unsqueeze(-1) # [B, 1] 113 | return max_logit 114 | -------------------------------------------------------------------------------- /kcd/classifier_guidance/nado_decode.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | import torch 5 | from tqdm.auto import tqdm 6 | from transformers import top_k_top_p_filtering 7 | 8 | from .fudge_decode import KnowledgeDiscriminator 9 | from .utils import zero_out_after_eos 10 | 11 | def nado_generation( 12 | model=None, 13 | max_new_tokens: int = 32, 14 | k: int = 200, 15 | alpha: int = 1.0, 16 | disable_adapter_lm_forward=False, 17 | ): 18 | assert model is not None 19 | discriminator = KnowledgeDiscriminator(model) 20 | 21 | def _generate_fn(model, tokenizer, inputs): 22 | """input_ids: torch.LongTensor, knowledge_text: str""" 23 | input_ids = inputs['input_ids'] 24 | attention_mask = inputs['attention_mask'] 25 | if model.config.is_encoder_decoder: 26 | encoder_outputs = model.get_encoder()(input_ids=input_ids, 27 | attention_mask=attention_mask) 28 | else: 29 | encoder_outputs = None 30 | past_key_values = None 31 | generated_tokens = torch.full((input_ids.shape[0], 1), 32 | model.config.decoder_start_token_id, 33 | dtype=torch.long, 34 | device=input_ids.device) 35 | score = torch.ones_like(generated_tokens, dtype=torch.float) 36 | for i in tqdm(range(max_new_tokens), total=max_new_tokens): 37 | next_logit, score, past_key_values = nado_step( 38 | model, 39 | discriminator, 40 | generated_tokens, 41 | score, 42 | gen_inst_ids=input_ids, 43 | attention_mask=attention_mask, 44 | encoder_outputs=encoder_outputs, 45 | past_key_values=past_key_values, 46 | k=k, 47 | alpha=alpha, 48 | disable_adapter_lm_forward=disable_adapter_lm_forward) 49 | generated_tokens = torch.cat([generated_tokens, next_logit], dim=1) 50 | # early stopping based on eos 51 | if (generated_tokens == tokenizer.eos_token_id).any(dim=1).all(): 52 | break 53 | return zero_out_after_eos(generated_tokens, tokenizer.eos_token_id) 54 | 55 | return _generate_fn 56 | 57 | 58 | @torch.inference_mode() 59 | def nado_step(model, 60 | discriminator, 61 | decoded_ids: torch.LongTensor, 62 | current_ids_score: torch.Tensor, 63 | gen_inst_ids: torch.LongTensor = None, 64 | attention_mask: torch.LongTensor = None, 65 | encoder_outputs=None, 66 | past_key_values=None, 67 | k: int = 200, 68 | alpha: float = 1.0, 69 | disable_adapter_lm_forward: bool = False): 70 | assert model.config.is_encoder_decoder 71 | # check v2 discriminator 72 | if hasattr(model, 'base_model'): # peft 73 | v2 = getattr(model.base_model, 'v2', False) 74 | else: 75 | v2 = getattr(model, 'v2', False) 76 | 77 | # prepare input_ids for encoder_decoder or decoder-only models 78 | if model.config.is_encoder_decoder: 79 | if gen_inst_ids is None: 80 | raise ValueError("gen_inst_ids must be set.") 81 | input_ids = decoded_ids 82 | else: 83 | input_ids = decoded_ids 84 | if gen_inst_ids is not None: 85 | input_ids = torch.cat([gen_inst_ids, input_ids], dim=1) 86 | 87 | if model.config.is_encoder_decoder: 88 | model_inputs = model.prepare_inputs_for_generation( 
89 | input_ids=input_ids, 90 | encoder_outputs=encoder_outputs, 91 | attention_mask=attention_mask, 92 | past_key_values=past_key_values, 93 | use_cache=True) 94 | else: 95 | model_inputs = model.prepare_inputs_for_generation( 96 | input_ids=input_ids, 97 | attention_mask=attention_mask, 98 | past_key_values=past_key_values, 99 | use_cache=True) 100 | 101 | # do forward 102 | if disable_adapter_lm_forward: 103 | with model.disable_adapter(): 104 | outputs = model(**model_inputs, 105 | return_lm_only=True, 106 | output_attentions=False, 107 | output_hidden_states=False) 108 | if v2: 109 | # with adapter 110 | _, disc_outputs = model(**model_inputs, 111 | output_attentions=False, 112 | output_hidden_states=False) 113 | class_logit = disc_outputs.logits 114 | else: 115 | outputs = model(**model_inputs, 116 | output_attentions=False, 117 | output_hidden_states=False) 118 | if v2: 119 | outputs, disc_out = outputs 120 | class_logit = disc_out.logits # [B, V] 121 | 122 | # select the next logit 123 | if isinstance(outputs, tuple) and len(outputs) == 2: 124 | outputs, _ = outputs # ignore token classifier output 125 | 126 | next_key_values = outputs.past_key_values 127 | logits = outputs.logits 128 | 129 | if v2: 130 | class_logp = torch.nn.LogSigmoid()(class_logit) # [B, V] 131 | logits = logits[:, -1] 132 | logits = top_k_top_p_filtering(logits, top_k=k) 133 | next_logit = torch.log_softmax(logits, dim=-1) 134 | next_logit = next_logit + class_logp * alpha 135 | _, max_logit = next_logit.topk(1, dim=-1) # [B, 1] 136 | score = class_logp[range(class_logp.shape[0]), max_logit.squeeze(-1)] 137 | return max_logit, score, next_key_values 138 | 139 | 140 | next_logit = torch.softmax(logits[:, -1, :], dim=-1) # [B, V] 141 | values, indices = next_logit.topk(k, dim=-1) # [B, k] 142 | topk_logits = indices.T # [k, B] 143 | 144 | class_prob = discriminator(decoded_ids, 145 | topk_logits, 146 | gen_inst_ids=gen_inst_ids, 147 | attention_mask=attention_mask, 148 | encoder_outputs=encoder_outputs) # [K * B,] 149 | class_prob = class_prob.view(k, -1).T # [K, B] => [B, K] 150 | class_prob = class_prob / current_ids_score 151 | 152 | fudge_topk = values * class_prob # [B, K,] 153 | max_idx = fudge_topk.argmax(dim=-1) # [B,] 154 | max_logit = indices[range(indices.shape[0]), max_idx].unsqueeze(-1) # [B, 1] 155 | score = class_prob[range(indices.shape[0]), max_idx].unsqueeze(-1) # [B, 1] 156 | return max_logit, score, next_key_values 157 | -------------------------------------------------------------------------------- /kcd/classifier_guidance/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def zero_out_after_eos(sequence, eos_token_id): 5 | """In the batched greedy decode, the sequence generation is not stopped 6 | at eos. Manually zero-out the tokens after eos with padding tokens. 
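In this implementation, positions after the first eos are overwritten with eos_token_id itself (not the pad token), so they are dropped downstream when decoding with skip_special_tokens=True.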
7 | """ 8 | row, col = torch.where(sequence == eos_token_id) 9 | for i, j in zip(row, col): 10 | col_idx = j + 1 11 | sequence[i][col_idx:] = eos_token_id 12 | 13 | return sequence 14 | -------------------------------------------------------------------------------- /kcd/configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class GenerationConfig: 6 | max_new_tokens: int = 32 7 | do_sample: bool = True 8 | num_beams: int = 1 9 | temperature: float = 1.0 10 | top_p: float = 0.95 11 | num_return_sequences: int = 1 12 | top_k: int = 200 13 | -------------------------------------------------------------------------------- /kcd/dstc11_task5/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_walker import load_dstc_data 2 | -------------------------------------------------------------------------------- /kcd/dstc11_task5/dataset_walker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from datasets import Dataset 5 | 6 | from .knowledge_reader import KnowledgeReader 7 | 8 | 9 | def load_dstc_data(data_path): 10 | data = DatasetWalker(data_path, labels=True, incl_knowledge=True, do_format=True) 11 | dataset = Dataset.from_list([x for x in data]) 12 | dataset = dataset.filter(lambda x: x['answers'] is not None) 13 | config = { 14 | 'input_columns': ['ctxs', 'question'], 15 | 'instruction': "### History:\n{}\n\n### Knowledge:\n{}" 16 | "\n\nGiven the dialog history and a relevant knowledge above," 17 | " generate a knowledgeable, usefule, and helpful answer." 18 | } 19 | return dataset, config 20 | 21 | 22 | class DatasetWalker: 23 | """ 24 | Copied from https://github.com/alexa/dstc11-track5/blob/main/scripts/dataset_walker.py 25 | Adjusted by Sehyun Choi, 2023 26 | """ 27 | EOT = '' 28 | 29 | def __init__(self, 30 | data_path, 31 | labels=True, 32 | labels_file=None, 33 | incl_knowledge=True, 34 | do_format=True): 35 | dataset = os.path.basename(data_path) 36 | dataroot = os.path.dirname(data_path) 37 | 38 | path = dataroot 39 | 40 | if dataset not in ['train', 'val', 'test']: 41 | raise ValueError('Wrong dataset name: %s' % (dataset)) 42 | 43 | logs_file = os.path.join(path, dataset, 'logs.json') 44 | with open(logs_file, 'r') as f: 45 | self.logs = json.load(f) 46 | 47 | self.labels = None 48 | 49 | if labels is True: 50 | if labels_file is None: 51 | labels_file = os.path.join(path, dataset, 'labels.json') 52 | 53 | with open(labels_file, 'r') as f: 54 | self.labels = json.load(f) 55 | 56 | self._incl_knowledge = incl_knowledge 57 | if self._incl_knowledge is True: 58 | self._knowledge = KnowledgeReader(dataroot) 59 | 60 | self.do_format = do_format 61 | 62 | def __getitem__(self, idx): 63 | log = self.logs[idx] 64 | if self.labels is not None: 65 | label = self.labels[idx] 66 | if self._incl_knowledge is True and label['target'] is True: 67 | for idx, snippet in enumerate(label['knowledge']): 68 | domain = snippet['domain'] 69 | entity_id = snippet['entity_id'] 70 | doc_type = snippet['doc_type'] 71 | doc_id = snippet['doc_id'] 72 | 73 | if doc_type == 'review': 74 | sent_id = snippet['sent_id'] 75 | sent = self._knowledge.get_review_sent(domain, entity_id, doc_id, sent_id) 76 | label['knowledge'][idx]['sent'] = sent 77 | 78 | elif doc_type == 'faq': 79 | doc = self._knowledge.get_faq_doc(domain, entity_id, doc_id) 80 | question = doc['question'] 81 | answer 
= doc['answer'] 82 | 83 | label['knowledge'][idx]['question'] = question 84 | label['knowledge'][idx]['answer'] = answer 85 | else: 86 | label = None 87 | 88 | if self.do_format: 89 | return self.format_log(log, label=label) 90 | return log, label 91 | 92 | def format_log(self, log, label=None): 93 | data = {} 94 | history = [] 95 | speakers = [] 96 | for turn in log: 97 | speaker = 'User' if turn['speaker'] == 'U' else 'System' 98 | turn_text = f"{speaker}: {turn['text']}" 99 | history.append(turn_text) 100 | speakers.append('0' if speaker == 'User' else '1') 101 | data['question'] = self.EOT.join(history) 102 | data['user'] = ','.join(speakers) 103 | 104 | if label is None or not label['target']: 105 | data['ctxs'] = None 106 | data['answers'] = None 107 | else: 108 | knowledges = [] 109 | for knowledge in label['knowledge']: 110 | if 'sent' in knowledge: 111 | knowledges.append(knowledge['sent']) 112 | else: 113 | knowledges.append(f"Q: {knowledge['question']}\nA: {knowledge['answer']}") 114 | data['ctxs'] = '\n'.join(knowledges) 115 | data['answers'] = label['response'] 116 | data['label'] = 1 117 | 118 | return data 119 | 120 | def __len__(self): 121 | return len(self.logs) 122 | -------------------------------------------------------------------------------- /kcd/dstc11_task5/knowledge_reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/alexa/dstc11-track5/blob/main/scripts/knowledge_reader.py 3 | """ 4 | import os 5 | import json 6 | 7 | class KnowledgeReader(object): 8 | def __init__(self, dataroot, knowledge_file='knowledge.json'): 9 | path = os.path.join(os.path.abspath(dataroot)) 10 | 11 | with open(os.path.join(path, knowledge_file), 'r') as f: 12 | self.knowledge = json.load(f) 13 | 14 | def get_domain_list(self): 15 | return list(self.knowledge.keys()) 16 | 17 | def get_entity_list(self, domain): 18 | if domain not in self.get_domain_list(): 19 | raise ValueError("invalid domain name") 20 | 21 | entity_ids = [] 22 | for entity_id in self.knowledge[domain].keys(): 23 | entity_ids.append(int(entity_id)) 24 | 25 | result = [] 26 | for entity_id in sorted(entity_ids): 27 | entity_name = self.knowledge[domain][str(entity_id)]['name'] 28 | result.append({'id': entity_id, 'name': entity_name}) 29 | 30 | return result 31 | 32 | def get_entity_name(self, domain, entity_id): 33 | if domain not in self.get_domain_list(): 34 | raise ValueError("invalid domain name: %s" % domain) 35 | 36 | if str(entity_id) not in self.knowledge[domain]: 37 | raise ValueError("invalid entity id: %s" % str(entity_id)) 38 | 39 | result = self.knowledge[domain][str(entity_id)]['name'] or None 40 | 41 | return result 42 | 43 | def get_faq_doc_ids(self, domain, entity_id): 44 | if domain not in self.get_domain_list(): 45 | raise ValueError("invalid domain name: %s" % domain) 46 | 47 | result = [] 48 | 49 | if str(entity_id) not in self.knowledge[domain]: 50 | raise ValueError("invalid entity id: %s" % str(entity_id)) 51 | 52 | entity_obj = self.knowledge[domain][str(entity_id)] 53 | for doc_id, doc_obj in entity_obj['faqs'].items(): 54 | result.append(doc_id) 55 | 56 | return result 57 | 58 | def get_faq_doc(self, domain, entity_id, doc_id): 59 | if domain not in self.get_domain_list(): 60 | raise ValueError("invalid domain name: %s" % domain) 61 | 62 | if str(entity_id) not in self.knowledge[domain]: 63 | raise ValueError("invalid entity id: %s" % str(entity_id)) 64 | 65 | entity_name = self.get_entity_name(domain, entity_id) 
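# Layout of knowledge.json assumed by this reader (inferred from the accessors in this class):
#   { domain: { entity_id: { "name": str,
#                            "faqs":    { doc_id: { "question": str, "answer": str } },
#                            "reviews": { doc_id: { "sentences": { sent_id: str },
#                                                   plus optional "traveler_type", "dishes", "drinks" } } } } }
# entity_id, doc_id, and sent_id are stored as string keys.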
66 | 67 | if str(doc_id) not in self.knowledge[domain][str(entity_id)]['faqs']: 68 | raise ValueError("invalid doc id: %s" % str(doc_id)) 69 | 70 | doc_obj = self.knowledge[domain][str(entity_id)]['faqs'][str(doc_id)] 71 | result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'question': doc_obj['question'], 'answer': doc_obj['answer']} 72 | 73 | return result 74 | 75 | def get_review_doc_ids(self, domain, entity_id): 76 | if domain not in self.get_domain_list(): 77 | raise ValueError("invalid domain name: %s" % domain) 78 | 79 | if str(entity_id) not in self.knowledge[domain]: 80 | raise ValueError("invalid entity id: %s" % str(entity_id)) 81 | 82 | result = [] 83 | 84 | entity_obj = self.knowledge[domain][str(entity_id)] 85 | for doc_id, doc_obj in entity_obj['reviews'].items(): 86 | result.append(doc_id) 87 | 88 | return result 89 | 90 | def get_review_doc(self, domain, entity_id, doc_id): 91 | if domain not in self.get_domain_list(): 92 | raise ValueError("invalid domain name: %s" % domain) 93 | 94 | if str(entity_id) not in self.knowledge[domain]: 95 | raise ValueError("invalid entity id: %s" % str(entity_id)) 96 | 97 | entity_name = self.get_entity_name(domain, entity_id) 98 | 99 | if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']: 100 | raise ValueError("invalid doc id: %s" % str(doc_id)) 101 | 102 | doc_obj = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)] 103 | 104 | result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'sentences': doc_obj['sentences']} 105 | if 'traveler_type' in doc_obj: 106 | result['traveler_type'] = doc_obj['traveler_type'] 107 | 108 | if 'dishes' in doc_obj: 109 | result['dishes'] = doc_obj['dishes'] 110 | 111 | if 'drinks' in doc_obj: 112 | result['drinks'] = doc_obj['drinks'] 113 | 114 | return result 115 | 116 | def get_review_sent(self, domain, entity_id, doc_id, sent_id): 117 | if domain not in self.get_domain_list(): 118 | raise ValueError("invalid domain name: %s" % domain) 119 | 120 | if str(entity_id) not in self.knowledge[domain]: 121 | raise ValueError("invalid entity id: %s" % str(entity_id)) 122 | 123 | if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']: 124 | raise ValueError("invalid doc id: %s" % str(doc_id)) 125 | 126 | if str(sent_id) not in self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences']: 127 | raise ValueError("invalid sentence id: %s" % str(sent_id)) 128 | 129 | result = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences'][str(sent_id)] 130 | 131 | return result 132 | -------------------------------------------------------------------------------- /kcd/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .token_f1_score import TokenF1Score 2 | from .auto_evaluation import evaluate, evaluate_per_sent 3 | -------------------------------------------------------------------------------- /kcd/evaluation/auto_evaluation.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | from torchmetrics import BLEUScore 4 | from torchmetrics.text.rouge import ROUGEScore 5 | 6 | 7 | def evaluate_per_sent(preds: list[str], targets: list[list[str]], metric: str, bleu_weights=None): 8 | all_scores = [evaluate(p, t, metrics=(metric,), bleu_weights=bleu_weights)[metric].item() 9 | for p, t in zip(preds, targets)] 10 | return all_scores 11 | 
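# Example usage (mirrors the __main__ block at the bottom of this file):
#   preds = ["hello there", "general kenobi"]
#   targets = [["hello there", "hi there"], ["master kenobi", "general canopi"]]
#   corpus_scores = evaluate(preds, targets)                     # {'bleu': tensor, 'rougeL': dict of rougeL tensors}
#   per_sample_bleu = evaluate_per_sent(preds, targets, "bleu")  # list of floats, one per prediction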
12 | 13 | def evaluate(preds, target, metrics=('bleu', 'rougeL'), bleu_weights=None): 14 | metric_fn = {} 15 | if 'bleu' in metrics: 16 | if bleu_weights is not None: 17 | bleu = BLEUScore(weights=bleu_weights) 18 | else: 19 | bleu = BLEUScore() 20 | metric_fn['bleu'] = bleu 21 | if 'rougeL' in metrics: 22 | rouge = ROUGEScore(rouge_keys="rougeL") 23 | metric_fn['rougeL'] = rouge 24 | 25 | scores = {} 26 | for metric, func in metric_fn.items(): 27 | scores[metric] = func(preds, target) 28 | 29 | return scores 30 | 31 | 32 | if __name__ == '__main__': 33 | preds = ["hello there", "general kenobi"] 34 | target = [["hello there", "hi there"], ["master kenobi", "general canopi"]] 35 | scores = evaluate(preds, target) 36 | pprint(scores) 37 | -------------------------------------------------------------------------------- /kcd/evaluation/token_f1_score.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from parlai. 3 | 4 | Sehyun Choi, 2023 5 | """ 6 | import re 7 | from collections import Counter 8 | from typing import List 9 | 10 | re_art = re.compile(r'\b(a|an|the)\b') 11 | re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']') 12 | 13 | 14 | def normalize_answer(s): 15 | """ 16 | Lower text and remove punctuation, articles and extra whitespace. 17 | """ 18 | s = s.lower() 19 | s = re_punc.sub(' ', s) 20 | s = re_art.sub(' ', s) 21 | s = ' '.join(s.split()) 22 | return s 23 | 24 | 25 | class TokenF1Score: 26 | """ 27 | Helper class which computes token-level F1. 28 | """ 29 | 30 | @staticmethod 31 | def prec_recall_f1_score(pred_items, gold_items): 32 | """ 33 | Compute precision, recall and f1 given a set of gold and prediction items. 34 | 35 | :param pred_items: iterable of predicted values 36 | :param gold_items: iterable of gold values 37 | 38 | :return: tuple (p, r, f1) for precision, recall, f1 39 | """ 40 | common = Counter(gold_items) & Counter(pred_items) 41 | num_same = sum(common.values()) 42 | if num_same == 0: 43 | return 0, 0, 0 44 | precision = 1.0 * num_same / len(pred_items) 45 | recall = 1.0 * num_same / len(gold_items) 46 | f1 = (2 * precision * recall) / (precision + recall) 47 | return precision, recall, f1 48 | 49 | @staticmethod 50 | def compute(guess: str, answers: List[str], expose_p_and_r: bool = False): 51 | g_tokens = normalize_answer(guess).split() 52 | scores = [ 53 | TokenF1Score.prec_recall_f1_score(g_tokens, 54 | normalize_answer(a).split()) for a in answers 55 | ] 56 | max_p, max_r, max_f1 = 0, 0, 0 57 | for p, r, f1 in scores: 58 | max_p, max_r, max_f1 = max(max_p, p), max(max_r, r), max(f1, max_f1) 59 | if expose_p_and_r: 60 | return max_p, max_r, max_f1 61 | return max_f1 62 | 63 | @staticmethod 64 | def batch_compute(guesses: List[str], answers: List[List[str]], expose_p_and_r: bool = False): 65 | assert len(guesses) == len(answers) 66 | return [ 67 | TokenF1Score.compute(guess, answer, expose_p_and_r=expose_p_and_r) 68 | for guess, answer in zip(guesses, answers) 69 | ] 70 | -------------------------------------------------------------------------------- /kcd/instructions.py: -------------------------------------------------------------------------------- 1 | BASE_INSTRUCTION_TEMPLATE = "{}\n\n{}" 2 | ALPACA_INSTRUCTION_TEMPLATE = ("Below is an instruction that describes a task, " 3 | "paired with an input that provides further context. 
" 4 | "Write a response that appropriately completes the request.\n\n" 5 | "### Instruction:\n{}\n\n" 6 | "### Input:\n{}\n\n" 7 | "### Response:") 8 | 9 | OPENAI_INSTRUCTION = ("Use the following knowledge, but not directly copy, " 10 | "to generate a concise response: \"{}\"") 11 | 12 | TASK_INSTRUCTIONS = { 13 | 'wow': { 14 | 'instruction': "Given the dialog history and a relevant knowledge above," 15 | " generate a knowledgeable, useful, and helpful answer.", 16 | 'input': "History:\n{}\n\nKnowledge:\n{}", 17 | 'param': ['question', 'knowledge'] 18 | }, 19 | 'fever': { 20 | 'instruction': "Generate a claim that is entirely supported by the evidences above.", 21 | 'input': "Evidences:\n{}", 22 | 'param': ['knowledge'] 23 | }, 24 | 'dstc11_task5': { 25 | 'instruction': "Given the dialog history and a relevant knowledge above," 26 | " generate a knowledgeable, useful, and helpful answer.", 27 | 'input': "History:\n{}\n\nKnowledge:\n{}", 28 | 'param': ['question', 'knowledge'] 29 | }, 30 | 'summarization': { 31 | 'instruction': "Given the article above, generate a faithful summary.", 32 | 'input': "### Document:\n{}", 33 | 'param': ['knowledge'] 34 | } 35 | } 36 | 37 | CLAIM_CLASSFICATION_INSTRUCTION = ("Given some evidences, determine whether the claim " 38 | "is supported by the evidences or not.\n\n" 39 | "### Claim:\n{}\n\n" 40 | "### Evidences:\n{}\n\n" 41 | "### Choices:\n- {}\n- {}") 42 | 43 | WOW_CLASSFICATION_INSTRUCTION = ("Given the dialog history and a relevant knowledge, " 44 | "determine whether the response is supported by " 45 | "the knowledge or not." 46 | "### Knowledge:\n{}\n\n" 47 | "### History:\n{}\n\n" 48 | "### Response:\n{}\n\n" 49 | "### Choices:\n- Yes\n- No") 50 | 51 | def get_instruction(model, task, **kwargs): 52 | if model == 'openai': 53 | if 'knowledge' not in kwargs: 54 | raise ValueError('Missing parameter: knowledge') 55 | return OPENAI_INSTRUCTION.format(kwargs.get('knowledge')) 56 | 57 | if task in ('cnn_dailymail', 'xsum'): 58 | task = 'summarization' 59 | components = TASK_INSTRUCTIONS[task] 60 | instruction_params = [] 61 | for param in components['param']: 62 | if param not in kwargs: 63 | raise ValueError(f'Missing parameter: {param}') 64 | instruction_params.append(kwargs.get(param)) 65 | 66 | inst_template = get_model_base_instruction(model, 67 | instruction=components['instruction'], 68 | input_text=components['input']) 69 | return inst_template.format(*instruction_params) 70 | 71 | 72 | def get_model_base_instruction(model, instruction=None, input_text=None): 73 | if model == 'openai': 74 | return OPENAI_INSTRUCTION 75 | assert instruction is not None and input_text is not None 76 | if model == 'alpaca': 77 | return ALPACA_INSTRUCTION_TEMPLATE.format(instruction, input_text) 78 | return BASE_INSTRUCTION_TEMPLATE.format(input_text, instruction) 79 | -------------------------------------------------------------------------------- /kcd/kilt/knowledge_source.py: -------------------------------------------------------------------------------- 1 | from pymongo import MongoClient 2 | import requests 3 | from urllib.parse import unquote 4 | import urllib.request 5 | from bs4 import BeautifulSoup 6 | import urllib.parse as urlparse 7 | from urllib.parse import parse_qs 8 | 9 | DEFAULT_MONGO_CONNECTION_STRING = "mongodb://127.0.0.1:27017/admin" 10 | 11 | 12 | def _get_pageid_from_api(title, client=None): 13 | pageid = None 14 | 15 | title_html = title.strip().replace(" ", "%20") 16 | url = ( 17 | 
"https://en.wikipedia.org/w/api.php?action=query&titles={}&format=json".format(title_html)) 18 | 19 | try: 20 | # Package the request, send the request and catch the response: r 21 | r = requests.get(url) 22 | 23 | # Decode the JSON data into a dictionary: json_data 24 | json_data = r.json() 25 | 26 | if len(json_data["query"]["pages"]) > 1: 27 | print("WARNING: more than one result returned from wikipedia api") 28 | 29 | for _, v in json_data["query"]["pages"].items(): 30 | pageid = v["pageid"] 31 | 32 | except Exception as e: 33 | # print("Exception: {}".format(e)) 34 | pass 35 | 36 | return pageid 37 | 38 | 39 | def _read_url(url): 40 | with urllib.request.urlopen(url) as response: 41 | html = response.read() 42 | soup = BeautifulSoup(html, features="html.parser") 43 | title = soup.title.string.replace(" - Wikipedia", "").strip() 44 | return title 45 | 46 | 47 | def _get_title_from_wikipedia_url(url, client=None): 48 | title = None 49 | try: 50 | title = _read_url(url) 51 | except Exception: 52 | try: 53 | # try adding https 54 | title = _read_url("https://" + url) 55 | except Exception: 56 | # print("Exception: {}".format(e)) 57 | pass 58 | return title 59 | 60 | 61 | class KnowledgeSource: 62 | 63 | def __init__( 64 | self, 65 | mongo_connection_string=None, 66 | database="kilt", 67 | collection="knowledgesource", 68 | ): 69 | if not mongo_connection_string: 70 | mongo_connection_string = DEFAULT_MONGO_CONNECTION_STRING 71 | self.client = MongoClient(mongo_connection_string) 72 | self.db = self.client[database][collection] 73 | 74 | def get_all_pages_cursor(self): 75 | cursor = self.db.find({}) 76 | return cursor 77 | 78 | def get_num_pages(self): 79 | return self.db.estimated_document_count() 80 | 81 | def get_page_by_id(self, wikipedia_id): 82 | page = self.db.find_one({"id": str(wikipedia_id)}) 83 | return page 84 | 85 | def get_page_by_title(self, wikipedia_title, attempt=0): 86 | page = self.db.find_one({"id": str(wikipedia_title)}) 87 | return page 88 | 89 | def get_page_from_url(self, url): 90 | page = None 91 | 92 | # 1. try to look for title in the url 93 | parsed = urlparse.urlparse(url) 94 | record = parse_qs(parsed.query) 95 | if "title" in record: 96 | title = record["title"][0].replace("_", " ") 97 | page = self.get_page_by_title(title) 98 | 99 | # 2. try another way to look for title in the url 100 | if page == None: 101 | title = url.split("/")[-1].replace("_", " ") 102 | page = self.get_page_by_title(title) 103 | 104 | # 3. try to retrieve the current wikipedia_id from the url 105 | if page == None: 106 | title = _get_title_from_wikipedia_url(url, client=self.client) 107 | if title: 108 | pageid = _get_pageid_from_api(title, client=self.client) 109 | if pageid: 110 | page = self.get_page_by_id(pageid) 111 | 112 | return page 113 | -------------------------------------------------------------------------------- /kcd/kilt/load_kilt_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import Dataset, load_from_disk 4 | import pandas as pd 5 | 6 | 7 | def load_wow(path): 8 | wow_config = { 9 | 'input_columns': ['ctxs', 'question'], 10 | 'instruction': "History:\n{}\n\nKnowledge:\n{}" 11 | "\n\nGiven the dialog history and a relevant knowledge above," 12 | " generate a knowledgeable, usefule, and helpful answer." 
13 | } 14 | if os.path.isdir(path): 15 | dataset = load_from_disk(path) 16 | return dataset, wow_config 17 | df = pd.read_json(path, lines=True) 18 | df['ctxs'] = df['ctxs'].apply(lambda x: x[0]['text'] if x is not None else None) 19 | df['answers'] = df['answers'].apply(lambda x: x[0]) 20 | df['label'] = 1 21 | dataset = Dataset.from_pandas(df) 22 | 23 | return dataset, wow_config 24 | 25 | 26 | def load_fever(path): 27 | fever_config = { 28 | 'input_columns': ['ctxs'], 29 | 'instruction': "Evidences:\n{}\n\nGenerate a claim that is" 30 | " entirely supported by the evidences above." 31 | } 32 | if os.path.isdir(path): 33 | dataset = load_from_disk(path) 34 | return dataset, fever_config 35 | df = pd.read_json(path, lines=True) 36 | df['answers'] = df['answers'].apply(lambda x: x[0]) 37 | df.loc[df['ctxs'].isna(), 'answers'] = 'NOT ENOUGH INFO' 38 | 39 | def _answer2label(x): 40 | if x == 'NOT ENOUGH INFO': 41 | return 0 42 | elif x == 'SUPPORTS': 43 | return 1 44 | elif x == 'REFUTES': 45 | return 2 46 | else: 47 | raise ValueError 48 | 49 | df['label'] = df['answers'].apply(_answer2label) 50 | # process ctxs 51 | evidences = [] 52 | for ctx in df['ctxs']: 53 | if ctx is None: 54 | evidences.append(None) 55 | continue 56 | evid = 0 57 | evidence = [] 58 | for ev in ctx: 59 | if ev is None: 60 | continue 61 | evidence.append(f'Knowledge {evid}: {ev["text"].strip()}') 62 | evid += 1 63 | evidence = '\n'.join(evidence) 64 | evidences.append(evidence) 65 | df['ctxs'] = evidences 66 | dataset = Dataset.from_pandas(df) 67 | return dataset, fever_config 68 | 69 | 70 | def main(path, dataset): 71 | if dataset == 'wow': 72 | df = load_wow(path) 73 | elif dataset == 'fever': 74 | df = load_fever(path) 75 | else: 76 | raise ValueError 77 | 78 | print(df) 79 | 80 | 81 | if __name__ == '__main__': 82 | import fire 83 | fire.Fire(main) 84 | -------------------------------------------------------------------------------- /kcd/kilt/preprocess_fever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import sys 9 | from tqdm.auto import tqdm 10 | from knowledge_source import KnowledgeSource 11 | 12 | ks = KnowledgeSource() 13 | 14 | 15 | def convert_kilt(inputpath, outputpath): 16 | data = [] 17 | inputdata = open(inputpath, "r") 18 | for example in tqdm(inputdata): 19 | d = {} 20 | ex = json.loads(example) 21 | d["question"] = ex["claim"] 22 | d["answers"] = ex["label"] 23 | d["id"] = ex["id"] 24 | 25 | if ex['label'] == 'NOT ENOUGH INFO': 26 | d['evidence'] = None 27 | continue 28 | 29 | evidence_ids = {} 30 | for ev in ex["evidence"]: 31 | for ann in ev: 32 | _, _, wikipedia_title, sentence_id = ann 33 | if wikipedia_title is None: 34 | continue 35 | if wikipedia_title not in evidence_ids: 36 | evidence_ids[wikipedia_title] = {sentence_id} 37 | else: 38 | evidence_ids[wikipedia_title].add(sentence_id) 39 | 40 | evidences = {} 41 | for title, sent_ids in evidence_ids.items(): 42 | page = ks.get_page_by_id(title) 43 | if page is None: 44 | continue 45 | sentence = [] 46 | sents = [t.strip() + '.' for t in page['text'].split(' . 
')] 47 | for sid in sent_ids: 48 | try: 49 | sentence.append(sents[sid]) 50 | except: 51 | pass 52 | evidences[title] = sentence 53 | 54 | if evidences: 55 | d["evidence"] = evidences 56 | else: 57 | d["evidence"] = None 58 | data.append(d) 59 | with open(outputpath, "w") as fout: 60 | json.dump(data, fout) 61 | 62 | 63 | if __name__ == "__main__": 64 | inputpath = sys.argv[1] 65 | outputpath = sys.argv[2] 66 | convert_kilt(inputpath, outputpath) 67 | -------------------------------------------------------------------------------- /kcd/kilt/preprocess_kilt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import sys 9 | from tqdm.auto import tqdm 10 | from knowledge_source import KnowledgeSource 11 | 12 | # get the knowledge souce 13 | ks = KnowledgeSource() 14 | 15 | 16 | def convert_kilt(inputpath, outputpath): 17 | data = [] 18 | inputdata = open(inputpath, "r") 19 | for example in tqdm(inputdata): 20 | d = {} 21 | ex = json.loads(example) 22 | d["question"] = ex["input"] 23 | answers = set() 24 | for a in ex["output"]: 25 | if "answer" in a: 26 | answers.add(a["answer"]) 27 | d["answers"] = list(answers) 28 | d["id"] = ex["id"] 29 | passages = [] 30 | 31 | if 'provenance' not in ex['output'][0]: 32 | d["ctxs"] = None 33 | data.append(d) 34 | continue 35 | 36 | for c in ex["output"][0]["provenance"]: 37 | page = ks.get_page_by_id(c["wikipedia_id"]) 38 | text = [] 39 | if c['start_paragraph_id'] == c['end_paragraph_id']: 40 | # single paragraph 41 | pid = c['start_paragraph_id'] 42 | text = page['text'][pid][c['start_character']:c['end_character'] + 1] 43 | else: 44 | for pid in range(c['start_paragraph_id'], c['end_paragraph_id'] + 1): 45 | if pid == c['start_paragraph_id']: # start 46 | t = page['text'][pid][c['start_character']:] 47 | elif pid == c['end_paragraph_id']: # end 48 | t = page['text'][pid][:c['end_character'] + 1] 49 | else: # inbetween 50 | t = page['text'][pid] 51 | text.append(t) 52 | 53 | p = { 54 | "text": text, 55 | "title": page["wikipedia_title"], 56 | "wikipedia_id": page["wikipedia_id"] 57 | } 58 | passages.append(p) 59 | d["ctxs"] = passages 60 | data.append(d) 61 | with open(outputpath, "w") as fout: 62 | for entry in data: 63 | json.dump(entry, fout) 64 | fout.write('\n') 65 | 66 | 67 | if __name__ == "__main__": 68 | inputpath = sys.argv[1] 69 | outputpath = sys.argv[2] 70 | convert_kilt(inputpath, outputpath) 71 | -------------------------------------------------------------------------------- /kcd/openai_module.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | import os 4 | import openai 5 | 6 | 7 | @dataclass 8 | class OpenAIAPIParameters: 9 | max_tokens: int = 16 10 | temperature: float = 1 # between (0, 2) 11 | top_p: float = 1 12 | n: int = 1 # number of completions for each prompt) 13 | logprobs: int = 0 # max 5, only for Completion API 14 | stop: Union[str, list[str]] = None # stop token, max 4 15 | best_of: int = 1 16 | logit_bias: dict[str, int] = None 17 | presence_penalty: float = 0 # between (-2, 2), 18 | frequency_penalty: float = 0 # between (-2, 2), 19 | 20 | def __post_init__(self): 21 | if self.logit_bias is None: 22 | self.logit_bias = dict() 23 | 24 | 25 | class 
OpenAIModel: 26 | 27 | def __init__(self, model_name: str = "text-davinci-003", task: str = 'completion'): 28 | self.model_name = model_name 29 | self.openai = openai 30 | self.openai.organization = os.getenv("OPENAI_ORG_ID") 31 | self.openai.api_key = os.getenv("OPENAI_API_KEY") 32 | if self.model_name == 'gpt-3.5-turbo': 33 | self.task = 'chat' 34 | else: 35 | self.task = task 36 | 37 | def __call__(self, 38 | prompt: str | list[dict[str, str]], 39 | parameters: OpenAIAPIParameters, 40 | suffix: str = None): 41 | return self.get_response(prompt, parameters, suffix=suffix) 42 | 43 | def get_response(self, 44 | prompt: str | list[dict[str, str]], 45 | parameters: OpenAIAPIParameters, 46 | suffix: str = None): 47 | if self.task == 'completion': 48 | response = self.openai.Completion.create(prompt=prompt, 49 | model=self.model_name, 50 | **parameters.__dict__, 51 | suffix=suffix) 52 | text = response['choices'][0]['text'] 53 | else: # chat 54 | assert isinstance(prompt, list) 55 | kwargs = parameters.__dict__ 56 | # not used for chat API 57 | kwargs.pop('logprobs', None) 58 | kwargs.pop('best_of', None) 59 | response = self.openai.ChatCompletion.create(messages=prompt, 60 | model=self.model_name, 61 | **kwargs) 62 | text = response['choices'][0]['message']['content'] 63 | return text, response 64 | 65 | 66 | class MockOpenAIModel: 67 | 68 | def __init__(self, model_name: str = "text-davinci-003", task: str = 'completion') -> None: 69 | from transformers import GPT2Tokenizer 70 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 71 | self.model_name = model_name 72 | self.task = task 73 | 74 | def __call__(self, 75 | prompt: str | list[dict[str, str]], 76 | parameters: OpenAIAPIParameters, 77 | suffix: str = None): 78 | return self.get_response(prompt, parameters, suffix=suffix) 79 | 80 | def get_response(self, 81 | prompt: str | list[dict[str, str]], 82 | parameters: OpenAIAPIParameters, 83 | suffix: str = None): 84 | mock_tokens = self.tokenizer.convert_ids_to_tokens([1, 2, 3, 4, 5]) 85 | mock_logprobs = [-1, -2, -3, -4, -5] 86 | n_prompt_tokens = len(self.tokenizer.encode(prompt)) 87 | total_usage = n_prompt_tokens + 1 # 1 for generate 1 token 88 | mock_response = { 89 | 'choices': [{ 90 | 'text': 'mock response', 91 | 'logprobs': { 92 | 'tokens': ['!'], 93 | 'token_logprobs': [-0.9], 94 | 'top_logprobs': [dict(zip(mock_tokens, mock_logprobs))] 95 | } 96 | }], 97 | "usage": { 98 | "prompt_tokens": n_prompt_tokens, 99 | "completion_tokens": 1, 100 | "total_tokens": total_usage 101 | } 102 | } 103 | return 'mock response', mock_response -------------------------------------------------------------------------------- /kcd/partial_negative.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import random 4 | 5 | from datasets import concatenate_datasets, Dataset 6 | import torch 7 | from transformers import (HfArgumentParser, Seq2SeqTrainingArguments) 8 | from kcd.classifier_guidance.guided_generation_predictor import GuidedGenerationPredictor 9 | 10 | from kcd.instructions import get_instruction 11 | from kcd.token_classifier.dataloader import DataCollatorForSeq2SeqTokenClassification 12 | from kcd.util import load_transformer_LM_tokenizer 13 | from kcd.wizard_of_wikipedia import load_wow 14 | from kcd.summarization import load_summary_data 15 | from kcd.configs import GenerationConfig 16 | 17 | @dataclass 18 | class ExperimentArgs: 19 | dataset_name: str = 'wow' 20 | 
dataset_path: str = 'data/wow.jsonl' 21 | model_name: str = "google/flan-t5-xl" 22 | instruction_model: str = 'basic' 23 | max_neg_samples: int = 100000 24 | load_8bit: bool = True 25 | first_n: int = 3 26 | last_n: int = 5 27 | num_workers: int = 16 28 | 29 | 30 | def main(): 31 | parser = HfArgumentParser([ExperimentArgs, GenerationConfig, Seq2SeqTrainingArguments]) 32 | args, gen_parameters, train_args = parser.parse_args_into_dataclasses() 33 | train_args.predict_with_generate = True 34 | train_args.remove_unused_columns = False # keep to False 35 | 36 | if args.dataset_name == 'wow': 37 | data_load_fn = partial(load_wow, 38 | max_samples=args.max_neg_samples, 39 | random_sample=True) 40 | elif args.dataset_name in ['cnn_dailymail', 'xsum']: 41 | data_load_fn = partial(load_summary_data, 42 | split='train', 43 | max_train_samples=args.max_neg_samples, 44 | random_sample=True) 45 | else: 46 | raise NotImplementedError 47 | 48 | with train_args.main_process_first(desc="train dataset map pre-processing"): 49 | dataset, config = data_load_fn(args.dataset_path) 50 | print(len(dataset)) 51 | 52 | load_kwargs = { 53 | 'device_map': 'auto' if args.load_8bit else None, 54 | 'load_in_8bit': args.load_8bit, 55 | 'torch_dtype': torch.float16 if args.load_8bit else torch.bfloat16, 56 | } 57 | model, tokenizer = load_transformer_LM_tokenizer(args.model_name, **load_kwargs) 58 | tokenizer.truncation_side = 'left' 59 | is_encoder_decoder = model.config.is_encoder_decoder 60 | if is_encoder_decoder: 61 | decoder_start_token_id = model.config.decoder_start_token_id 62 | 63 | dataset = dataset.add_column('index', range(len(dataset))) 64 | 65 | def preprocess(example): 66 | # randomly select knowledge 67 | # TODO: non-random selection 68 | non_batch_indices = list(filter(lambda x: x != example['index'], range(len(dataset)))) 69 | 70 | answer = example['answers'].strip() 71 | question = example['question'].strip() 72 | idx = random.choice(non_batch_indices) 73 | knowledge = dataset[idx]['ctxs'] 74 | # randomly perturb the answer 75 | answer_tokens = tokenizer.encode(answer, return_tensors='pt')[0] 76 | first = args.first_n if args.first_n < len(answer_tokens) else 1 77 | if len(answer_tokens) - args.last_n > first: 78 | last = len(answer_tokens) - args.last_n 79 | else: 80 | last = len(answer_tokens) - 1 81 | perturb_idx = random.randint(first, last) 82 | pert = answer_tokens[:perturb_idx] 83 | pert_txt = tokenizer.decode(pert, skip_special_tokens=True) 84 | 85 | if 'question' in config['input_columns']: 86 | input_text = get_instruction(args.instruction_model, 87 | args.dataset_name, 88 | question=question, 89 | knowledge=knowledge) 90 | else: 91 | input_text = get_instruction(args.instruction_model, 92 | args.dataset_name, 93 | knowledge=knowledge) 94 | 95 | if is_encoder_decoder: 96 | tokenized = tokenizer(input_text, 97 | truncation=True, 98 | max_length=tokenizer.model_max_length, 99 | return_tensors='pt') 100 | tokenized['decoder_input_ids'] = torch.cat( 101 | [torch.full((1, 1), 102 | decoder_start_token_id, 103 | dtype=torch.long), 104 | pert.unsqueeze(0)], 105 | dim=1) 106 | else: 107 | tokenized = tokenizer(input_text + ' ' + pert_txt, 108 | truncation=True, 109 | max_length=tokenizer.model_max_length, 110 | return_tensors='pt') 111 | tokenized.pop('token_type_ids', None) # unused 112 | 113 | tokenized = {k: v[0] for k, v in tokenized.items()} 114 | tokenized.update(**dict(perturb_idx=perturb_idx, pert_txt=pert_txt)) 115 | return tokenized 116 | 117 | with 
train_args.main_process_first(desc="train dataset map pre-processing"): 118 | tokenized_dataset = dataset.map(preprocess, 119 | num_proc=args.num_workers, 120 | remove_columns=dataset.column_names) 121 | perturb_indices = tokenized_dataset['perturb_idx'] 122 | perturb_prompt = tokenized_dataset['pert_txt'] 123 | tokenized_dataset = tokenized_dataset.remove_columns(['perturb_idx', 'pert_txt']) 124 | 125 | 126 | def generate_fn(_model, _tokenizer, inputs): 127 | generated = _model.generate(**inputs, **gen_parameters.__dict__) 128 | # get rid of tokens that were already in the input 129 | if _model.config.is_encoder_decoder: 130 | generated = generated[:, len(inputs['decoder_input_ids'][0]):] 131 | else: 132 | generated = generated[:, len(inputs['input_ids'][0]):] 133 | return generated 134 | 135 | trainer = GuidedGenerationPredictor( 136 | generate_fn=generate_fn, 137 | model=model, 138 | args=train_args, 139 | data_collator=DataCollatorForSeq2SeqTokenClassification(tokenizer), 140 | tokenizer=tokenizer, 141 | ) 142 | preds = trainer.predict(tokenized_dataset, **gen_parameters.__dict__) 143 | preds.predictions[preds.predictions == -100] = tokenizer.pad_token_id 144 | responses = tokenizer.batch_decode(preds.predictions, skip_special_tokens=True) 145 | 146 | neg_data = [] 147 | for i, (example, p_txt, p_idx, neg) in enumerate(zip(dataset, 148 | perturb_prompt, 149 | perturb_indices, 150 | responses)): 151 | target = f'{p_txt} {neg}'.strip() 152 | label = tokenizer.encode(target, return_tensors='pt')[0] 153 | label[:p_idx] = 1 154 | label[p_idx:] = 0 155 | neg_data.append({ 156 | 'answers': target, 157 | 'ctxs': example['ctxs'], # correct knowledge 158 | 'question': example['question'], 159 | 'label': label.tolist(), 160 | }) 161 | neg_dataset = Dataset.from_list(neg_data) 162 | neg_dataset = neg_dataset.filter(lambda x: '\n' in x['question']) 163 | 164 | def _listed_label(example): 165 | example['label'] = [example['label']] 166 | return example 167 | dataset = dataset.map(_listed_label) 168 | 169 | full_data = concatenate_datasets([dataset, neg_dataset]) 170 | full_data.save_to_disk( 171 | f'data/cached/{args.dataset_name}_train_augmented_neg_{args.model_name.replace("/", "-")}') 172 | 173 | 174 | if __name__ == '__main__': 175 | main() 176 | -------------------------------------------------------------------------------- /kcd/sample_negative.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | import random 4 | 5 | from transformers import HfArgumentParser 6 | 7 | from kcd.kilt.load_kilt_data import load_fever 8 | from datasets import concatenate_datasets, Dataset 9 | 10 | 11 | @dataclass 12 | class ExperimentArgs: 13 | dataset_name: str = 'wow' 14 | dataset_path: str = 'data/wow.jsonl' 15 | use_kilt_format: bool = False 16 | sample_method: str = 'random' # choices: [random] 17 | 18 | 19 | def main(): 20 | parser = HfArgumentParser([ExperimentArgs]) 21 | args = parser.parse_args_into_dataclasses()[0] 22 | 23 | if args.dataset_name == 'wow': 24 | if args.use_kilt_format: 25 | from kcd.kilt.load_kilt_data import load_wow 26 | else: 27 | from kcd.wizard_of_wikipedia import load_wow 28 | data_load_fn = load_wow 29 | else: 30 | data_load_fn = load_fever 31 | 32 | dataset, config = data_load_fn(args.dataset_path) 33 | 34 | ans = dataset['answers'] 35 | knowledge = dataset['ctxs'] 36 | if args.sample_method == 'random': 37 | random.shuffle(knowledge) 38 | neg_dataset = Dataset.from_dict({ 39 | 'answers': ans, 
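# `knowledge` was shuffled in place above, so each original answer is paired with a
# randomly drawn (almost surely mismatched) knowledge passage and labeled 0 below.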
40 | 'ctxs': knowledge, 41 | 'question': dataset['question'], 42 | 'label': [0 for _ in range(len(ans))], 43 | }) 44 | else: 45 | raise ValueError 46 | full_data = concatenate_datasets([dataset, neg_dataset]) 47 | basename = os.path.basename(args.dataset_path).split('.')[0] 48 | dataset_name = f'{args.dataset_name}_{basename}_augmented_neg_{args.sample_method}' 49 | full_data.save_to_disk(f'data/cached/{dataset_name}') 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /kcd/summarization/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import load_summary_data 2 | -------------------------------------------------------------------------------- /kcd/summarization/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from datasets import load_dataset, load_from_disk 5 | 6 | from transformers import AutoTokenizer 7 | 8 | MAIN_MODEL = 'google/flan-t5-xl' 9 | 10 | 11 | def load_summary_data(path: str, tokenizer=None, split='test', max_train_samples=100000, random_sample=False): 12 | if tokenizer is None: 13 | tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL) 14 | config = { 15 | 'input_columns': ['ctxs'], 16 | 'instruction': "### Document:\n{}" 17 | "\n\nGiven the article above, generate a faithful summary." 18 | } 19 | if os.path.isdir(path): 20 | dataset = load_from_disk(path) 21 | return dataset, config 22 | if path == 'cnn_dailymail': 23 | dataset = load_dataset(path, '3.0.0')[split] 24 | if split == 'train': 25 | indices = list(range(len(dataset))) 26 | if random_sample: 27 | random.shuffle(indices) 28 | dataset = dataset.select(indices[:max_train_samples]) 29 | dataset = dataset.rename_column('article', 'ctxs') 30 | dataset = dataset.rename_column('highlights', 'question') 31 | dataset = dataset.add_column('answers', dataset['question']) 32 | elif path == 'xsum': 33 | dataset = load_dataset(path)[split] 34 | dataset = dataset.rename_column('document', 'ctxs') 35 | dataset = dataset.rename_column('summary', 'question') 36 | dataset = dataset.add_column('answers', dataset['question']) 37 | else: 38 | raise ValueError(f'Unknown dataset: {path}') 39 | dataset = dataset.add_column('label', [1] * len(dataset)) 40 | 41 | # Tokenize the dataset and filter out samples that are too long 42 | def get_doc_len(examples): 43 | return {'doc_len': len(tokenizer.encode(examples['ctxs']))} 44 | 45 | dataset = dataset.map(get_doc_len) 46 | # -25 for instructions 47 | dataset = dataset.filter(lambda x: x['doc_len'] <= tokenizer.model_max_length - 25) 48 | 49 | return dataset, config 50 | -------------------------------------------------------------------------------- /kcd/text_data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from kcd.instructions import OPENAI_INSTRUCTION, get_instruction 4 | from kcd.kilt.load_kilt_data import load_fever 5 | from kcd.dstc11_task5 import load_dstc_data 6 | from kcd.summarization import load_summary_data 7 | 8 | 9 | def load_text_data(path='data/fever-dev-kilt-processed.jsonl', 10 | use_kilt_format=True, 11 | instruction_model='openai', 12 | task='chat', 13 | tokenize=False, 14 | tokenizer=None, 15 | add_trailing_newline=False, 16 | no_label=False): 17 | if 'fever' in path: 18 | load_fn = load_fever 19 | data_task = 'fever' 20 | elif 'wow' in path: 21 | if 
use_kilt_format: 22 | from kcd.kilt.load_kilt_data import load_wow 23 | else: 24 | from kcd.wizard_of_wikipedia import load_wow 25 | load_fn = load_wow 26 | data_task = 'wow' 27 | elif 'dstc11_task5' in path: 28 | load_fn = load_dstc_data 29 | data_task = 'dstc11_task5' 30 | elif 'cnn_dailymail' in path or 'xsum' in path: 31 | load_fn = load_summary_data 32 | data_task = 'summarization' 33 | else: 34 | raise ValueError(f'Unknown dataset: {path}') 35 | 36 | def prepare(config, example): 37 | # ['question', 'answers', 'id', 'ctxs', 'label'] 38 | knowledge = example['ctxs'].strip() 39 | question = example['question'].strip() 40 | answer = example['answers'].strip() 41 | if task != 'chat': 42 | if 'question' in config['input_columns']: 43 | input_text = get_instruction(instruction_model, 44 | data_task, 45 | question=question, 46 | knowledge=knowledge) 47 | target = answer 48 | else: 49 | input_text = get_instruction(instruction_model, data_task, knowledge=knowledge) 50 | target = question 51 | 52 | if tokenize: 53 | assert tokenizer is not None 54 | if no_label: 55 | target = None 56 | if add_trailing_newline: 57 | input_text = input_text + '\n\n' 58 | tokenized = tokenizer(input_text, 59 | text_target=target, 60 | return_tensors='pt', 61 | truncation=True, 62 | max_length=tokenizer.model_max_length) 63 | return {k: v[0] for k, v in tokenized.items()} 64 | else: 65 | if data_task == 'summarization': 66 | # summarization 67 | template = "Summarize the following text:\n\n{}" 68 | input_text = [{'role': 'user', 'content': template.format(knowledge)}] 69 | target = answer 70 | return {'prompt': input_text, 'completion': target} 71 | # wow for openai chatCompletion API 72 | if '' in question: # dstc11_task5 73 | conversation_history = question.split('') 74 | conversation_history = [ 75 | txt.replace('User: ', '').replace('System: ', '') 76 | for txt in conversation_history 77 | ] 78 | else: 79 | conversation_history = question.split('\n') 80 | # 1: wizard = assistant, 0: user 81 | user = list(map(int, example['user'].split(','))) 82 | if len(conversation_history) > len(user): 83 | assert len(conversation_history) == len(user) + 1 84 | first_utt = '\n'.join(conversation_history[:2]) 85 | conversation_history = [first_utt] + conversation_history[2:] 86 | 87 | messages = [{ 88 | "role": "user" if i == 0 else "assistant", 89 | "content": chat 90 | } for i, chat in zip(user, conversation_history)] 91 | messages.append({"role": "system", "content": OPENAI_INSTRUCTION.format(knowledge)}) 92 | 93 | input_text = messages 94 | target = answer 95 | 96 | return {'prompt': input_text, 'completion': target} 97 | 98 | dataset, config = load_fn(path) 99 | filtered = dataset.filter(lambda x: x['ctxs'] is not None) 100 | mapped = filtered.map(partial(prepare, config), remove_columns=dataset.column_names) 101 | return mapped 102 | -------------------------------------------------------------------------------- /kcd/token_classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-KnowComp/Knowledge-Constrained-Decoding/707f4de017c12ec6145b08249362e247ba8aa486/kcd/token_classifier/__init__.py -------------------------------------------------------------------------------- /kcd/token_classifier/dataloader.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from functools import partial 3 | from typing import Union, Optional, List 4 | 5 | import torch 6 | 
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy 7 | from transformers.data.data_collator import DataCollatorMixin 8 | 9 | from kcd.instructions import get_instruction, WOW_CLASSFICATION_INSTRUCTION 10 | from kcd.kilt.load_kilt_data import load_fever 11 | from kcd.dstc11_task5 import load_dstc_data 12 | from kcd.summarization import load_summary_data 13 | from kcd.util import shift_right 14 | 15 | @dataclass 16 | class DataCollatorForSeq2SeqTokenClassification(DataCollatorMixin): 17 | 18 | tokenizer: PreTrainedTokenizerBase 19 | other_features_to_pad: List[str] = field(default_factory=list) 20 | padding: Union[bool, str, PaddingStrategy] = True 21 | max_length: Optional[int] = None 22 | pad_to_multiple_of: Optional[int] = None 23 | label_pad_token_id: int = -100 24 | return_tensors: str = "pt" 25 | 26 | def torch_call(self, features): 27 | import torch 28 | 29 | label_name = "label" if "label" in features[0].keys() else "labels" 30 | labels = [feature[label_name] for feature in features 31 | ] if label_name in features[0].keys() else None 32 | other_features = { 33 | k : [feature[k] for feature in features] for k in self.other_features_to_pad 34 | } 35 | decoder_input_ids = [feature['decoder_input_ids'] for feature in features] 36 | 37 | no_labels_features = [{ 38 | k: v for k, v in feature.items() 39 | if k not in (label_name, 'decoder_input_ids', *self.other_features_to_pad) 40 | } for feature in features] 41 | 42 | batch = self.tokenizer.pad( 43 | no_labels_features, 44 | padding=self.padding, 45 | max_length=self.max_length, 46 | pad_to_multiple_of=self.pad_to_multiple_of, 47 | return_tensors="pt", 48 | ) 49 | 50 | sequence_length = max([len(ids) for ids in decoder_input_ids]) 51 | padding_side = self.tokenizer.padding_side 52 | 53 | def to_list(tensor_or_iterable): 54 | if isinstance(tensor_or_iterable, torch.Tensor): 55 | return tensor_or_iterable.tolist() 56 | return list(tensor_or_iterable) 57 | 58 | def pad_tensor(tensor, pad_id, seqlen): 59 | if padding_side == "right": 60 | return [to_list(x) + [pad_id] * (seqlen - len(x)) for x in tensor] 61 | return [[pad_id] * (seqlen - len(x)) + to_list(x) for x in tensor] 62 | 63 | batch['decoder_input_ids'] = pad_tensor(decoder_input_ids, self.tokenizer.pad_token_id, sequence_length) 64 | batch['decoder_input_ids'] = torch.tensor(batch['decoder_input_ids'], dtype=torch.int64) 65 | 66 | if labels is None: 67 | return batch 68 | batch[label_name] = pad_tensor(labels, self.label_pad_token_id, sequence_length) 69 | batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64) 70 | 71 | if not other_features: 72 | return batch 73 | 74 | for k, v in other_features.items(): 75 | seqlen = max([len(ids) for ids in v]) 76 | padded = pad_tensor(v, self.tokenizer.pad_token_id, seqlen) 77 | batch[k] = torch.tensor(padded, dtype=torch.int64) 78 | 79 | return batch 80 | 81 | 82 | def load_data(args, 83 | tokenizer, 84 | is_encoder_decoder=False, 85 | decoder_start_token_id=0, 86 | instruction_model='basic', 87 | zeroshot_classification=False, 88 | get_knowledge_ids=False): 89 | if args.dataset == 'fever': 90 | load_fn = load_fever 91 | elif args.dataset == 'wow': 92 | if args.use_kilt_format: 93 | from kcd.kilt.load_kilt_data import load_wow 94 | else: 95 | from kcd.wizard_of_wikipedia import load_wow 96 | load_fn = load_wow 97 | elif args.dataset == 'dstc11_task5': 98 | load_fn = load_dstc_data 99 | elif args.dataset in ('cnn_dailymail', 'xsum'): 100 | load_fn = load_summary_data 101 | else: 102 | raise 
ValueError 103 | 104 | def _tokenize(config: dict, example): 105 | # ['question', 'answers', 'id', 'ctxs', 'label'] 106 | knowledge = example['ctxs'].strip() 107 | question = example['question'].strip() 108 | answer = example['answers'].strip() 109 | if zeroshot_classification: 110 | classi_text = WOW_CLASSFICATION_INSTRUCTION.format(question, knowledge, answer) 111 | classi_inputs = tokenizer(classi_text, 112 | return_tensors='pt', 113 | max_length=tokenizer.model_max_length, 114 | return_token_type_ids=False, 115 | truncation=True) 116 | classi_inputs['labels'] = torch.LongTensor([example['label']]) 117 | classi_inputs = {k: v[0] for k, v in classi_inputs.items()} 118 | return classi_inputs 119 | 120 | if 'question' in config['input_columns']: 121 | input_text = get_instruction(instruction_model, 122 | args.dataset, 123 | question=question, 124 | knowledge=knowledge) 125 | target = answer 126 | else: 127 | input_text = get_instruction(instruction_model, args.dataset, knowledge=knowledge) 128 | target = answer 129 | 130 | if is_encoder_decoder: 131 | tokenized = tokenizer(input_text, 132 | text_target=target, 133 | max_length=tokenizer.model_max_length, 134 | return_tensors='pt', 135 | truncation=True) 136 | if not hasattr(args, 'sft') or not args.sft: 137 | # TODO: need to add decoder start token... 138 | tokenized['decoder_input_ids'] = shift_right(tokenized['labels'], 139 | decoder_start_token_id) 140 | if isinstance(example['label'], list): 141 | if len(example['label']) == 1: 142 | tokenized['labels'] = torch.full_like(tokenized['decoder_input_ids'], example['label'][0]) 143 | elif hasattr(args, 'sequence_label') and args.sequence_label: 144 | tokenized['labels'] = torch.full_like(tokenized['decoder_input_ids'], example['label'][-1]) 145 | else: 146 | label = torch.LongTensor([example['label']]) 147 | if not tokenized['decoder_input_ids'].shape[1] == label.shape[1]: 148 | # The indexing of label is wrong because of .strip() 149 | label = label[:, :-1] 150 | assert tokenized['decoder_input_ids'].shape[1] == label.shape[1] 151 | tokenized['labels'] = shift_right(label, 1) 152 | else: 153 | tokenized['labels'] = torch.full_like(tokenized['decoder_input_ids'], example['label']) 154 | else: 155 | target_len = len(tokenizer.encode(target)) 156 | tokenized = tokenizer(f"{input_text}\n\n{target}", 157 | max_length=tokenizer.model_max_length, 158 | return_tensors='pt', 159 | truncation=True) 160 | if hasattr(args, 'sft') and args.sft: 161 | label = tokenized['input_ids'].clone() 162 | label[:, :-target_len] = -100 163 | tokenized['labels'] = label 164 | else: 165 | full_seqlen = tokenized['attention_mask'].sum(-1) 166 | if isinstance(example['label'], list): 167 | if len(example['label']) == 1: 168 | label = torch.full((target_len,), example['label'][0], dtype=int).tolist() 169 | elif hasattr(args, 'sequence_label') and args.sequence_label: 170 | label = torch.full((target_len,), example['label'][-1], dtype=int).tolist() 171 | else: 172 | label = example['label'] 173 | if not target_len == len(label): 174 | # The indexing of label is wrong because of .strip() 175 | label = label[:-1] 176 | else: 177 | label = torch.full((target_len,), example['label'], dtype=int).tolist() 178 | label = [[-100] * (full_seqlen - len(label)) + label] 179 | tokenized['labels'] = label 180 | if get_knowledge_ids: 181 | tokenized['knowledge_ids'] = tokenizer(knowledge, return_tensors='pt').input_ids 182 | tokenized = {k: v[0] for k, v in tokenized.items()} 183 | 184 | return tokenized 185 | 186 | datasets = {} 187 
| for split, path in zip(['train', 'validation', 'test'], 188 | [args.train_data_path, args.validation_data_path, args.test_data_path]): 189 | if not path: 190 | datasets[split] = None 191 | continue 192 | dataset, config = load_fn(path) 193 | filtered = dataset.filter(lambda x: x['ctxs'] is not None) 194 | mapped = filtered.map(partial(_tokenize, config), remove_columns=dataset.column_names) 195 | datasets[split] = mapped 196 | return datasets 197 | -------------------------------------------------------------------------------- /kcd/token_classifier/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers.trainer import Trainer 8 | from transformers.trainer_callback import TrainerCallback 9 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 10 | 11 | from sklearn.metrics import classification_report 12 | 13 | class SavePeftModelCallback(TrainerCallback): 14 | 15 | def on_save(self, args, state, control, **kwargs): 16 | checkpoint_folder = os.path.join(args.output_dir, 17 | f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") 18 | 19 | peft_model_path = os.path.join(checkpoint_folder, "adapter_model") 20 | kwargs["model"].save_pretrained(peft_model_path) 21 | 22 | # pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin") 23 | # if os.path.exists(pytorch_model_path): 24 | # os.remove(pytorch_model_path) 25 | return control 26 | 27 | 28 | class MyTrainer(Trainer): 29 | 30 | def compute_loss(self, model, inputs, return_outputs=False): 31 | lm_logits = None 32 | if hasattr(model, 'v2_regularization') and model.v2_regularization > 0: 33 | disable_adapter = hasattr(model, 'base_model') # peft 34 | if disable_adapter: 35 | with model.disable_adapter(): 36 | outputs = model(**inputs, return_lm_only=True) 37 | lm_logits = outputs.logits 38 | else: 39 | outputs = model(**inputs) 40 | lm_logits = outputs.logits 41 | output = model(lm_logits=lm_logits, **inputs) 42 | if isinstance(output, tuple): 43 | _, token_classification_output = output 44 | loss = token_classification_output.loss 45 | else: 46 | loss = output.loss 47 | 48 | if return_outputs: 49 | return loss, [None, None] # fake outputs 50 | return loss 51 | 52 | def prediction_step( 53 | self, 54 | model: nn.Module, 55 | inputs: Dict[str, Union[torch.Tensor, Any]], 56 | prediction_loss_only: bool, 57 | ignore_keys: Optional[List[str]] = None, 58 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 59 | inputs = self._prepare_inputs(inputs) 60 | model.eval() 61 | model.config.use_cache = True # faster 62 | labels = inputs.pop('labels') 63 | with torch.no_grad(): 64 | lm_logits = None 65 | if hasattr(model, 'v2_regularization') and model.v2_regularization > 0: # regularization for nado 66 | disable_adapter = hasattr(model, 'base_model') # peft 67 | if disable_adapter: 68 | with model.disable_adapter(): 69 | outputs = model(**inputs, return_lm_only=True) 70 | lm_logits = outputs.logits 71 | else: 72 | outputs, _ = model(**inputs) 73 | lm_logits = outputs.logits 74 | output = model(lm_logits=lm_logits, **inputs) 75 | if isinstance(output, tuple): 76 | _, token_classification_output = output 77 | logits = token_classification_output.logits 78 | else: 79 | logits = output.logits 80 | # get pool_method 81 | if hasattr(model, 'base_model'): # peft 82 | is_peft = True 83 | 
pool_method = getattr(model.base_model, 'pool_method', False) 84 | else: 85 | is_peft = False 86 | pool_method = getattr(model, 'pool_method', False) 87 | 88 | if pool_method == 'inflection': 89 | if is_peft: 90 | labels, pool_idx, _ = model.base_model._get_inflection_position(labels) 91 | else: 92 | labels, pool_idx, _ = model._get_inflection_position(labels) 93 | assert len(logits.shape) == 3 and logits.shape[-1] == 2 94 | 95 | if len(logits.shape) == 2: 96 | # pooling is applied 97 | if model.config.is_encoder_decoder: 98 | seqlen = (labels != -100).sum(dim=1) 99 | else: 100 | seqlen = torch.ne(inputs['input_ids'], model.config.pad_token_id).sum(-1) 101 | labels = labels[range(labels.shape[0]), seqlen - 1] 102 | eval_loss = F.cross_entropy(logits, labels) 103 | # compute metrics later 104 | if hasattr(model, 'v2') and model.v2: 105 | idx = inputs['input_ids'][range(labels.shape[0]), seqlen - 1] 106 | preds = torch.sigmoid(logits[range(logits.shape[0]), idx]) > 0.5 107 | preds = preds.long() 108 | else: 109 | preds = logits.argmax(-1) 110 | return eval_loss, preds, labels 111 | 112 | if logits.shape[1] == labels.shape[1] - 1: # happens for v2 113 | labels = labels[:, 1:] 114 | if logits.shape[2] == 1: # binary classification # v2 115 | eval_loss = token_classification_output.loss 116 | if pool_method == 'last': 117 | if model.config.is_encoder_decoder: 118 | seqlen = (labels != -100).sum(dim=1) 119 | else: 120 | seqlen = torch.ne(inputs['input_ids'], model.config.pad_token_id).sum(-1) 121 | labels = labels[range(labels.shape[0]), seqlen - 1] 122 | breakpoint() 123 | return eval_loss, logits[:, -1], labels 124 | else: 125 | eval_loss = F.cross_entropy(logits.permute(0, 2, 1), labels) 126 | if prediction_loss_only: 127 | return eval_loss 128 | # loss, logit, label 129 | # NOTE: for efficiency, compute accuracy here! 130 | label_mask = labels != -100 # TODO -100 is pad idx by default; 131 | if logits.shape[2] == 1: # binary classification 132 | preds = torch.sigmoid(logits.squeeze(2)) > 0.5 133 | else: 134 | preds = logits.argmax(-1) 135 | # 1. accuracy 136 | # NOTE: no need to care about padding since preds cannot be -100 137 | correct = preds == labels 138 | true_acc = correct.sum(-1) / label_mask.sum(-1) # [B,] 139 | # 2. 
prec, recall, f1 140 | tp, fp, fn = [], [], [] 141 | for pred, label, mask in zip(preds, labels, label_mask): 142 | pred = pred[mask] 143 | label = label[mask] 144 | tp.append(((pred == 1) & (label == 1)).sum()) 145 | fp.append(((pred == 1) & (label != 1)).sum()) 146 | fn.append(((pred != 1) & (label == 1)).sum()) 147 | tp = torch.stack(tp) 148 | fp = torch.stack(fp) 149 | fn = torch.stack(fn) 150 | precision = tp / (tp + fp) 151 | recall = tp / (tp + fn) 152 | # nan handling before f1 153 | precision = torch.nan_to_num(precision, nan=0.0) 154 | recall = torch.nan_to_num(recall, nan=0.0) 155 | f1 = (2 * precision * recall) / (precision + recall) 156 | # handle nan for f1 once again 157 | f1 = torch.nan_to_num(f1, nan=0.0) 158 | 159 | metrics = torch.cat([met.unsqueeze(1) for met in (true_acc, precision, recall, f1)], dim=1) 160 | 161 | dummy_label = torch.full((labels.shape[0],), -100) 162 | return (eval_loss, metrics, dummy_label) 163 | 164 | 165 | def compute_metrics(eval_preds): 166 | """Compute accuracy""" 167 | if not eval_preds.label_ids[0] == -100: 168 | # compute metrics here 169 | metrics = classification_report(eval_preds.label_ids, eval_preds.predictions, output_dict=True) 170 | return dict(accuracy=metrics['accuracy'], 171 | precision=metrics['1']['precision'], 172 | recall=metrics['1']['recall'], 173 | f1=metrics['1']['f1-score']) 174 | metrics = eval_preds.predictions # [N, 4] 175 | final_metrics = metrics.mean(0).tolist() 176 | 177 | return dict(accuracy=final_metrics[0], 178 | precision=final_metrics[1], 179 | recall=final_metrics[2], 180 | f1=final_metrics[3]) 181 | -------------------------------------------------------------------------------- /kcd/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM 8 | 9 | from kcd.token_classifier.model import T5DoubleHeadModel 10 | 11 | ENCODER_DECODER_ARCH_NAMES = ['t5', 't0', 'ul2', 'bart'] 12 | 13 | 14 | def load_transformer_LM_tokenizer(model_name_or_path, 15 | tokenizer_name_or_path=None, 16 | load_t5_doublehead=False, 17 | **kwargs): 18 | if tokenizer_name_or_path is None: 19 | tokenizer_name_or_path = model_name_or_path 20 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) 21 | if load_t5_doublehead: 22 | model = T5DoubleHeadModel.from_pretrained(model_name_or_path, **kwargs) 23 | elif any(name in model_name_or_path.lower() for name in ENCODER_DECODER_ARCH_NAMES): 24 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **kwargs) 25 | else: 26 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) 27 | # open-ended generation 28 | tokenizer.pad_token = tokenizer.eos_token 29 | model.config.pad_token_id = model.config.eos_token_id 30 | # this is needed since we are using batched generation for causal LM 31 | tokenizer.padding_side = 'left' 32 | return model, tokenizer 33 | 34 | 35 | def shift_right(input_ids, decoder_start_token_id): 36 | # shift inputs to the right 37 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 38 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 39 | shifted_input_ids[..., 0] = decoder_start_token_id 40 | 41 | return shifted_input_ids 42 | 43 | 44 | def get_logger(fname="logs/fever_test.log"): 45 | logFormatter = logging.Formatter( 46 | "%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s") 47 | logger 
= logging.getLogger(__name__) 48 | 49 | fileHandler = logging.FileHandler(fname) 50 | fileHandler.setFormatter(logFormatter) 51 | logger.addHandler(fileHandler) 52 | 53 | consoleHandler = logging.StreamHandler() 54 | consoleHandler.setFormatter(logFormatter) 55 | logger.addHandler(consoleHandler) 56 | return logger 57 | 58 | 59 | def logsumexp(tensor, dim=-1, mask=None): 60 | if mask is None: 61 | return torch.logsumexp(tensor, dim=dim) 62 | 63 | assert mask.shape == tensor.shape, 'The factors tensor should have the same shape as the original' 64 | # a = torch.cat([torch.max(tensor, dim, keepdim=True) for _ in range(tensor.shape[dim])], dim) 65 | a = tensor.max(dim, keepdim=True) 66 | return a + torch.sum((tensor - a).exp() * mask, dim).log() 67 | 68 | 69 | def in_notebook(): 70 | try: 71 | from IPython import get_ipython 72 | if 'IPKernelApp' not in get_ipython().config: # pragma: no cover 73 | return False 74 | except ImportError: 75 | return False 76 | except AttributeError: 77 | return False 78 | return True 79 | 80 | 81 | def set_random_seeds(seed): 82 | """ 83 | set the random seed of all related libraries 84 | """ 85 | random.seed(seed) 86 | os.environ['PYTHONHASHSEED'] = str(seed) 87 | np.random.seed(seed) 88 | torch.manual_seed(seed) 89 | torch.cuda.manual_seed(seed) 90 | torch.cuda.manual_seed_all(seed) 91 | torch.backends.cudnn.deterministic = True 92 | 93 | 94 | def freeze_module(module): 95 | for param in module.parameters(): 96 | param.requires_grad = False 97 | -------------------------------------------------------------------------------- /kcd/wizard_of_wikipedia/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import load_wow 2 | -------------------------------------------------------------------------------- /kcd/wizard_of_wikipedia/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from datasets import Dataset, load_from_disk 5 | import pandas as pd 6 | 7 | 8 | def load_wow(path: str, max_samples=None, random_sample=False): 9 | wow_config = { 10 | 'input_columns': ['ctxs', 'question'], 11 | 'instruction': "History:\n{}\n\nKnowledge:\n{}" 12 | "\n\nGiven the dialog history and a relevant knowledge above," 13 | " generate a knowledgeable, usefule, and helpful answer." 14 | } 15 | if os.path.isdir(path): 16 | dataset = load_from_disk(path) 17 | return dataset, wow_config 18 | 19 | df = pd.read_json(path, lines=True) 20 | df['question'] = df['history'].apply(lambda x: '\n'.join([_x.strip() for _x in x])) 21 | df['user'] = df['user'].apply(lambda x: ','.join(map(str, x))) 22 | df['answers'] = df['response'] 23 | df['ctxs'] = df['knowledge'].apply(lambda x: x[0].split('__knowledge__')[1].strip()) 24 | df['label'] = 1 25 | df = df[['question', 'ctxs', 'answers', 'label', 'user']] 26 | dataset = Dataset.from_pandas(df) 27 | if max_samples is not None: 28 | indices = list(range(len(dataset))) 29 | if random_sample: 30 | random.shuffle(indices) 31 | dataset = dataset.select(indices[:max_samples]) 32 | 33 | return dataset, wow_config 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # basic 2 | nltk 3 | tqdm 4 | pandas 5 | matplotlib 6 | jupyter 7 | scikit-learn 8 | # torch # NOTE: Typically, you should install torch with cuda beforehand. 
9 | transformers 10 | accelerate 11 | datasets 12 | bitsandbytes 13 | loralib 14 | sentencepiece 15 | torchmetrics 16 | wandb 17 | fire 18 | git+https://github.com/huggingface/peft.git 19 | openai 20 | # unieval 21 | protobuf 22 | rouge-score 23 | py7zr 24 | evaluate 25 | prettytable 26 | editdistance 27 | # MFMA 28 | spacy 29 | spacy-legacy 30 | spacy-loggers 31 | # kilt 32 | pymongo 33 | beautifulsoup4 34 | tiktoken 35 | evaluate 36 | sacrebleu 37 | bert_score 38 | git+https://github.com/google-research/bleurt.git 39 | -------------------------------------------------------------------------------- /scripts/analyze_partial_hallucination_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import datasets 3 | import random 4 | 5 | from termcolor import colored 6 | from transformers import AutoTokenizer 7 | import fire 8 | 9 | def get_neg_data_only(dataset): 10 | for i, l in enumerate(dataset['label']): 11 | if len(l) > 1: 12 | break 13 | return dataset[i + 1:] 14 | 15 | def detokenize(tokens): 16 | return ''.join([' ' if tok == '▁' else tok.replace('▁', ' ') for tok in tokens]) 17 | 18 | 19 | def analyze(neg_dataset, idx, task, tokenizer, full_data, verbose=False): 20 | def verbose_print(*args): 21 | if verbose: 22 | print(*args) 23 | if task == 'wow': 24 | verbose_print('history\n\n', neg_dataset['question'][idx]) 25 | verbose_print('\n\nknowledge\n\n', neg_dataset['ctxs'][idx]) 26 | 27 | last_idx = sum(neg_dataset['label'][idx]) - 1 28 | tokenized = tokenizer.tokenize(neg_dataset['answers'][idx]) 29 | original_tokens = detokenize(tokenized[:last_idx]).strip() 30 | hallucinated_tokens = detokenize(tokenized[last_idx:]).strip() 31 | 32 | verbose_print("\n\nanswer\n\n", original_tokens + colored(hallucinated_tokens, 'red')) 33 | 34 | verbose_print('\n\noriginal answer\n\n', full_data[full_data['history'] == neg_dataset['question'][idx]]['response'].tolist()[0]) 35 | elif task == 'cnn': 36 | verbose_print('document\n\n', neg_dataset['ctxs'][idx]) 37 | 38 | last_idx = sum(neg_dataset['label'][idx]) - 1 39 | tokenized = tokenizer.tokenize(neg_dataset['answers'][idx]) 40 | original_tokens = detokenize(tokenized[:last_idx]) 41 | hallucinated_tokens = detokenize(tokenized[last_idx:]) 42 | 43 | verbose_print("\n\nsummary\n\n", original_tokens + colored(hallucinated_tokens, 'red')) 44 | 45 | for i, art in enumerate(full_data['article']): 46 | if art == neg_dataset['ctxs'][idx]: 47 | verbose_print('\n\noriginal summary\n\n', full_data[i]['highlights']) 48 | break 49 | else: 50 | raise ValueError 51 | 52 | return hallucinated_tokens 53 | 54 | 55 | def main(task: str='wow', verbose: bool=False): 56 | """ 57 | task: wow or cnn 58 | """ 59 | tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl') 60 | 61 | if task == 'wow': 62 | wow_df = pd.read_json('data/cached/wow/train.jsonl', lines=True) 63 | wow_df['history'] = wow_df['history'].apply(lambda x: '\n'.join(x)) 64 | partial_data = datasets.load_from_disk('data/cached/wow_train_augmented_neg_google-flan-t5-xl') 65 | partial_negative = get_neg_data_only(partial_data) 66 | full_data = wow_df 67 | elif task == 'cnn': 68 | cnn_data = datasets.load_dataset('cnn_dailymail', '3.0.0')['train'] 69 | partial_data = datasets.load_from_disk('data/cached/cnn_dailymail_train_augmented_neg_google-flan-t5-xl') 70 | partial_negative = get_neg_data_only(partial_data) 71 | full_data = cnn_data 72 | else: 73 | raise ValueError 74 | 75 | for i in random.sample(range(len(partial_negative['label'])), 
10): 76 | analyze(partial_negative, i, task, tokenizer, full_data, verbose=verbose) 77 | 78 | if __name__ == '__main__': 79 | fire.Fire(main) 80 | -------------------------------------------------------------------------------- /scripts/evaluate_generations_with_classifier.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import glob 3 | import json 4 | import os 5 | 6 | from datasets import Dataset 7 | import torch 8 | import pandas as pd 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, HfArgumentParser, BitsAndBytesConfig 12 | from peft import LoraConfig, PeftModel, get_peft_model 13 | 14 | from kcd.token_classifier.dataloader import DataCollatorForSeq2SeqTokenClassification, load_data 15 | from kcd.token_classifier.model import T5DoubleHeadModel 16 | from kcd.util import shift_right 17 | 18 | @dataclass 19 | class ExperimentArgs: 20 | model_name: str = field(default="google/flan-t5-xl") 21 | num_labels: int = field(default=2) 22 | attr_idx: int = 1 23 | load_8bit: bool = True 24 | instruction_model: str = 'basic' # 'basic' or 'alpaca' 25 | dataset: str = 'wow' # 'wow' or 'fever' 26 | use_kilt_format: bool = False 27 | test_data_path: str = "data/cached/wow/test_unseen.jsonl" 28 | generations_path: str = "generations/pplm_prompts.jsonl" 29 | causal_lm_generations: bool = False 30 | load_checkpoint: str = None 31 | load_peft_checkpoint: str = None 32 | load_classifier: str = field(default=None) 33 | use_mlp_classifier: bool = False 34 | batch_size: int = 1 35 | skip_no_knowledge: bool = False 36 | 37 | 38 | def main(): 39 | parser = HfArgumentParser([ExperimentArgs]) 40 | args = parser.parse_args_into_dataclasses()[0] 41 | args.train_data_path = None 42 | args.validation_data_path = None 43 | 44 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | load_kwargs = { 46 | 'device_map': 'auto' if args.load_8bit else None, 47 | 'load_in_8bit': args.load_8bit, 48 | 'torch_dtype': torch.float16 if args.load_8bit else torch.bfloat16, 49 | } 50 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, **load_kwargs) 51 | tokenizer.truncation_side = 'left' 52 | 53 | if args.use_mlp_classifier: 54 | load_kwargs.pop('load_in_8bit') 55 | load_kwargs['use_mlp_classifier'] = True 56 | load_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=args.load_8bit, 57 | llm_int8_skip_modules=['lm_head', 'mlp_layer1', 'mlp_layer2']) 58 | model = T5DoubleHeadModel.from_pretrained(args.model_name, 59 | output_hidden_states=True, 60 | use_cache=True, 61 | num_labels=args.num_labels, 62 | pool_method='last', 63 | **load_kwargs) 64 | 65 | if args.load_peft_checkpoint: 66 | model = PeftModel.from_pretrained(model, args.load_peft_checkpoint) 67 | 68 | if args.load_checkpoint: 69 | peft_config_path = os.path.join(os.path.dirname(args.load_checkpoint), 'adapter_model') 70 | peft_config = LoraConfig.from_pretrained(peft_config_path) 71 | model = get_peft_model(model, peft_config) 72 | incompatible = model.load_state_dict(torch.load(args.load_checkpoint), strict=False) 73 | assert (len(incompatible.missing_keys) == 1 74 | and incompatible.missing_keys[0].endswith('lm_head.weight')) 75 | 76 | if args.load_classifier: 77 | ckpt = torch.load(args.load_classifier) 78 | ckpt = {k.replace('classifier.', ''): v for k, v in ckpt.items() if 'classifier' in k} 79 | model.classifier.load_state_dict(ckpt, strict=True) 80 | 81 | dataset = load_data(args, 82 | 
tokenizer, 83 | is_encoder_decoder=model.config.is_encoder_decoder, 84 | instruction_model=args.instruction_model)['test'] 85 | 86 | if '*' in args.generations_path: 87 | paths = glob.glob(args.generations_path) 88 | elif os.path.isdir(args.generations_path): 89 | paths = glob.glob(os.path.join(args.generations_path, '*.jsonl')) 90 | else: 91 | paths = [args.generations_path] 92 | 93 | outfile = open('class_prob.jsonl', 'a') 94 | for path in paths: 95 | print(f"loading generations at {path}...") 96 | gen_dataset = load_generations(path, 97 | dataset, 98 | tokenizer, 99 | causal_lm_generations=args.causal_lm_generations) 100 | if args.skip_no_knowledge: 101 | gen_dataset = gen_dataset.filter( 102 | lambda x: 'no_passages_used' not in tokenizer.decode(x['input_ids'])) 103 | collator = DataCollatorForSeq2SeqTokenClassification(tokenizer) 104 | dataloader = DataLoader(gen_dataset, 105 | batch_size=args.batch_size, 106 | collate_fn=collator, 107 | shuffle=False) 108 | print("generations loaded") 109 | samples_pbar = tqdm(enumerate(dataloader), total=len(dataloader)) 110 | 111 | sum_probs = 0 112 | for i, batch in samples_pbar: 113 | batch = batch.to(device) 114 | _, output = model(**batch) 115 | probs = torch.softmax(output.logits, dim=-1)[:, args.attr_idx] # [B,] 116 | sum_probs += probs.sum().item() 117 | 118 | mean_score = sum_probs / len(gen_dataset) 119 | print(f"Evaluation of {path}: {mean_score}") 120 | outfile.write(json.dumps({'path': path, 'score': mean_score}) + '\n') 121 | outfile.close() 122 | 123 | 124 | def load_generations(generations_path, dataset, tokenizer, causal_lm_generations=False): 125 | df = pd.read_json(generations_path, lines=True) 126 | data_df = dataset.to_pandas() 127 | data_df = data_df.iloc[:len(df)] # truncate if generations are shorter 128 | df['labels'] = data_df['labels'].apply(lambda x: x[0]) 129 | gen_dataset = Dataset.from_pandas(df) 130 | def _tokenize(example): 131 | if isinstance(example['response'], dict): 132 | # for chat GPT 133 | if 'text' in example['response']['choices'][0]: 134 | gen = example['response']['choices'][0]['text'] 135 | else: 136 | gen = example['response']['choices'][0]['message']['content'] 137 | else: 138 | gen = example['response'] 139 | if causal_lm_generations: 140 | try: 141 | gen = gen.split('### Response:')[1] 142 | except: 143 | print("No response found in generation.") 144 | gen = "no response" 145 | gen_ids = tokenizer(gen, 146 | truncation=True, 147 | max_length=tokenizer.model_max_length, 148 | return_tensors='pt').input_ids[0] 149 | gen_ids = shift_right(gen_ids, tokenizer.pad_token_id) 150 | labels = torch.full_like(gen_ids, example['labels']) 151 | return {'decoder_input_ids': gen_ids, 'labels': labels} 152 | 153 | tokenized = gen_dataset.map(_tokenize, remove_columns=gen_dataset.column_names) 154 | data_df['decoder_input_ids'] = tokenized['decoder_input_ids'] 155 | data_df['labels'] = tokenized['labels'] 156 | dataset = Dataset.from_pandas(data_df) 157 | return dataset 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /scripts/evaluate_summary_mfma.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import glob 3 | import os 4 | 5 | from datasets import Dataset 6 | import torch 7 | import pandas as pd 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, 
AutoModelForSequenceClassification, HfArgumentParser, DataCollatorWithPadding 11 | 12 | from kcd.token_classifier.dataloader import load_data 13 | 14 | @dataclass 15 | class ExperimentArgs: 16 | instruction_model: str = 'basic' # 'basic' or 'alpaca' 17 | dataset: str = 'wow' # 'wow' or 'fever' 18 | use_kilt_format: bool = False 19 | test_data_path: str = "data/cached/wow/test_unseen.jsonl" 20 | generations_path: str = "generations/pplm_prompts.jsonl" 21 | causal_lm_generations: bool = False 22 | batch_size: int = 1 23 | 24 | 25 | def main(): 26 | parser = HfArgumentParser([ExperimentArgs]) 27 | args = parser.parse_args_into_dataclasses()[0] 28 | args.train_data_path = None 29 | args.validation_data_path = None 30 | 31 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 32 | tokenizer = AutoTokenizer.from_pretrained("henry931007/mfma") 33 | model = AutoModelForSequenceClassification.from_pretrained("henry931007/mfma").to(device) 34 | 35 | dataset = load_data(args, 36 | tokenizer, 37 | is_encoder_decoder=model.config.is_encoder_decoder, 38 | instruction_model=args.instruction_model, 39 | get_knowledge_ids=True)['test'] 40 | 41 | if '*' in args.generations_path: 42 | paths = glob.glob(args.generations_path) 43 | elif os.path.isdir(args.generations_path): 44 | paths = glob.glob(os.path.join(args.generations_path, '*.jsonl')) 45 | else: 46 | paths = [args.generations_path] 47 | for path in paths: 48 | print(f"loading generations at {path}...") 49 | gen_dataset = load_generations(path, 50 | dataset, 51 | tokenizer, 52 | causal_lm_generations=args.causal_lm_generations) 53 | collator = DataCollatorWithPadding(tokenizer) 54 | dataloader = DataLoader(gen_dataset, 55 | batch_size=args.batch_size, 56 | collate_fn=collator, 57 | shuffle=False) 58 | print("generations loaded") 59 | samples_pbar = tqdm(enumerate(dataloader), total=len(dataloader)) 60 | 61 | sum_probs = 0 62 | for i, batch in samples_pbar: 63 | batch = batch.to(device) 64 | output = model(**batch) 65 | probs = torch.softmax(output.logits, dim=-1)[:, 0] # [B,] 66 | sum_probs += probs.sum().item() 67 | 68 | mean_score = sum_probs / len(gen_dataset) 69 | print(f"Evaluation of {path}: {mean_score}") 70 | 71 | 72 | def load_generations(generations_path, dataset, tokenizer, causal_lm_generations=False): 73 | df = pd.read_json(generations_path, lines=True) 74 | data_df = dataset.to_pandas() 75 | data_df = data_df.iloc[:len(df)] # truncate if generations are shorter 76 | df['labels'] = data_df['labels'].apply(lambda x: x[0]) 77 | df['knowledge_ids'] = data_df['knowledge_ids'].apply(lambda x: tokenizer.decode(x, skip_special_tokens=True)) 78 | gen_dataset = Dataset.from_pandas(df) 79 | def _tokenize(example): 80 | if isinstance(example['response'], dict): 81 | # for chat GPT 82 | if 'text' in example['response']['choices'][0]: 83 | gen = example['response']['choices'][0]['text'] 84 | else: 85 | gen = example['response']['choices'][0]['message']['content'] 86 | else: 87 | gen = example['response'] 88 | if causal_lm_generations: 89 | try: 90 | gen = gen.split('### Response:')[1] 91 | except: 92 | print("No response found in generation.") 93 | gen = "no response" 94 | 95 | inputs = tokenizer(example['knowledge_ids'], 96 | gen, 97 | truncation=True, 98 | max_length=tokenizer.model_max_length, 99 | return_tensors='pt') 100 | inputs['labels'] = torch.LongTensor([example['labels']]) 101 | return {k: v[0] for k, v in inputs.items()} 102 | 103 | tokenized = gen_dataset.map(_tokenize, remove_columns=gen_dataset.column_names) 104 | return tokenized 
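# A minimal, self-contained sketch (not part of the original script; the document and
# summary strings below are hypothetical placeholders) of how the MFMA consistency model
# loaded in main() scores a single (document, summary) pair. It mirrors the calls above:
# the tokenizer takes the source document as the first segment and the generated summary
# as the second, and class index 0 of the softmax is read as the consistency score,
# matching `torch.softmax(output.logits, dim=-1)[:, 0]` in the evaluation loop.
#
#   import torch
#   from transformers import AutoTokenizer, AutoModelForSequenceClassification
#
#   tok = AutoTokenizer.from_pretrained("henry931007/mfma")
#   mfma = AutoModelForSequenceClassification.from_pretrained("henry931007/mfma")
#   enc = tok("<source document text>", "<generated summary text>",
#             truncation=True, max_length=tok.model_max_length, return_tensors="pt")
#   with torch.no_grad():
#       score = torch.softmax(mfma(**enc).logits, dim=-1)[0, 0].item()
#   print(f"MFMA consistency score: {score:.3f}")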
105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /scripts/evaluate_zeroshot_wow_classification.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | from transformers import HfArgumentParser, DataCollatorWithPadding 7 | from sklearn.metrics import classification_report 8 | 9 | from kcd.token_classifier.dataloader import load_data 10 | from kcd.util import load_transformer_LM_tokenizer 11 | 12 | @dataclass 13 | class ExperimentArgs: 14 | model_name: str = field(default="google/flan-t5-xl") 15 | load_8bit: bool = True 16 | instruction_model: str = 'basic' # 'basic' or 'alpaca' 17 | dataset: str = 'wow' # 'wow' or 'fever' 18 | use_kilt_format: bool = False 19 | test_data_path: str = "data/cached/wow/test_unseen.jsonl" 20 | batch_size: int = 1 21 | 22 | 23 | def main(): 24 | parser = HfArgumentParser([ExperimentArgs]) 25 | args = parser.parse_args_into_dataclasses()[0] 26 | args.train_data_path = None 27 | args.validation_data_path = None 28 | 29 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 30 | load_kwargs = { 31 | 'device_map': 'auto' if args.load_8bit else None, 32 | 'load_in_8bit': args.load_8bit, 33 | 'torch_dtype': torch.float16 if args.load_8bit else torch.bfloat16, 34 | } 35 | model, tokenizer = load_transformer_LM_tokenizer(args.model_name, **load_kwargs) 36 | model.eval() 37 | tokenizer.truncation_side = 'left' 38 | 39 | dataset = load_data(args, 40 | tokenizer, 41 | zeroshot_classification=True, 42 | is_encoder_decoder=model.config.is_encoder_decoder, 43 | instruction_model=args.instruction_model)['test'] 44 | collator = DataCollatorWithPadding(tokenizer) 45 | dataloader = DataLoader(dataset, 46 | batch_size=args.batch_size, 47 | collate_fn=collator, 48 | shuffle=False) 49 | 50 | pos_idx = tokenizer.encode('Yes')[0] 51 | neg_idx = tokenizer.encode('No')[0] 52 | 53 | preds = [] 54 | labels = [] 55 | for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)): 56 | batch = batch.to(device) 57 | label = batch.pop('labels') 58 | if model.config.is_encoder_decoder: 59 | batch['decoder_input_ids'] = torch.full((batch['input_ids'].shape[0], 1), 60 | model.config.decoder_start_token_id, 61 | dtype=torch.long, 62 | device=device) 63 | with torch.inference_mode(): 64 | logits = model(**batch).logits 65 | probs = torch.softmax(logits[:, -1, [neg_idx, pos_idx]], dim=-1) 66 | pred = torch.argmax(probs, dim=-1) 67 | 68 | preds.append(pred.cpu()) 69 | labels.append(label.cpu()) 70 | 71 | preds = torch.cat(preds).numpy() 72 | labels = torch.cat(labels).numpy() 73 | print(classification_report(labels, preds, digits=4)) 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /scripts/run_guided_generation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import json 3 | import os 4 | from dataclasses import dataclass 5 | from pprint import pprint 6 | 7 | import torch 8 | from tqdm.auto import tqdm 9 | from transformers import HfArgumentParser, Seq2SeqTrainingArguments, BitsAndBytesConfig 10 | from peft import PeftModel, LoraConfig, get_peft_model 11 | 12 | from kcd.token_classifier.dataloader import load_data, DataCollatorForSeq2SeqTokenClassification 13 | from 
kcd.classifier_guidance import GuidedGenerationPredictor, load_generate_fn 14 | from kcd.evaluation import TokenF1Score, evaluate_per_sent 15 | from kcd.util import load_transformer_LM_tokenizer 16 | from kcd.configs import GenerationConfig 17 | 18 | 19 | @dataclass 20 | class ExperimentArgs: 21 | test_data_path: str = 'data/wow-dev-kilt-processed.jsonl' 22 | output_path: str = 'generations/fudge' 23 | model_name: str = "google/flan-t5-xl" 24 | dataset: str = 'wow' 25 | use_kilt_format: bool = True 26 | load_8bit: bool = True 27 | print_output: bool = False 28 | instruction_model: str = 'basic' # choices=['basic', 'openai', 'alpaca'] 29 | guidance_method: str = 'fudge' # choices=['metric_guidance', 'fudge', 'nado'] 30 | metric: str = 'token_f1' # choices=['token_f1'] 31 | disc_name: str = '' 32 | 33 | num_labels: int = 2 34 | load_checkpoint: str = None 35 | load_peft_checkpoint: str = None 36 | load_classifier: str = None 37 | use_mlp_classifier: bool = False 38 | continue_from: int = 0 39 | v2: bool = False 40 | human_indices: str = None 41 | alpha: float = 1.0 # how much grounded for nado 42 | 43 | complete_after: int = 0 44 | 45 | 46 | def main(): 47 | parser = HfArgumentParser((ExperimentArgs, GenerationConfig, Seq2SeqTrainingArguments)) 48 | args, gen_cfg, train_args = parser.parse_args_into_dataclasses() 49 | args.output_path = train_args.output_dir 50 | train_args.predict_with_generate = True 51 | train_args.remove_unused_columns = False # keep to False 52 | args.train_data_path = None 53 | args.validation_data_path = None 54 | 55 | load_kwargs = { 56 | 'device_map': 'auto' if args.load_8bit else None, 57 | 'load_in_8bit': args.load_8bit, 58 | 'torch_dtype': torch.float16 if args.load_8bit else torch.bfloat16, 59 | } 60 | if args.guidance_method in ('fudge', 'nado', 'astar'): 61 | load_kwargs['num_labels'] = args.num_labels 62 | load_kwargs['pool_method'] = 'last' # for efficiency 63 | load_kwargs['load_t5_doublehead'] = True 64 | load_kwargs['v2'] = args.v2 65 | print('cuda available:', torch.cuda.is_available()) 66 | 67 | if args.use_mlp_classifier: 68 | load_kwargs.pop('load_in_8bit') 69 | load_kwargs['use_mlp_classifier'] = True 70 | load_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=args.load_8bit, 71 | llm_int8_skip_modules=['lm_head', 'mlp_layer1', 'mlp_layer2']) 72 | 73 | model, tokenizer = load_transformer_LM_tokenizer(args.model_name, **load_kwargs) 74 | 75 | if args.load_peft_checkpoint: 76 | model = PeftModel.from_pretrained(model, args.load_peft_checkpoint) 77 | if args.load_checkpoint: 78 | peft_config_path = os.path.join(os.path.dirname(args.load_checkpoint), 'adapter_model') 79 | peft_config = LoraConfig.from_pretrained(peft_config_path) 80 | model = get_peft_model(model, peft_config) 81 | incompatible = model.load_state_dict(torch.load(args.load_checkpoint), strict=False) 82 | assert (len(incompatible.missing_keys) == 1 and 83 | incompatible.missing_keys[0].endswith('lm_head.weight')) 84 | if args.load_classifier: 85 | ckpt = torch.load(args.load_classifier) 86 | ckpt = {k.replace('classifier.', ''): v for k, v in ckpt.items() if 'classifier' in k} 87 | model.classifier.load_state_dict(ckpt, strict=True) 88 | 89 | if not model.config.is_encoder_decoder: 90 | raise NotImplementedError( 91 | 'The dataloading for non-encoder-decoder is not set yet for inference.' 
92 | 'Take a look at kcd.token_classifier.dataloader.') 93 | 94 | dataset = load_data(args, 95 | tokenizer, 96 | is_encoder_decoder=model.config.is_encoder_decoder, 97 | instruction_model=args.instruction_model, 98 | get_knowledge_ids=True)['test'] 99 | indices = None 100 | if args.human_indices: 101 | with open(args.human_indices) as f: 102 | indices = [int(i) for i in f.readlines()] 103 | dataset = dataset.select(indices) 104 | if args.continue_from > 0: 105 | dataset = dataset.select(range(args.continue_from, len(dataset))) 106 | # load guidance criteria 107 | if args.guidance_method == 'metric_guidance': 108 | if args.metric == 'token_f1': 109 | metric = TokenF1Score.batch_compute 110 | elif args.metric in ('bleu', 'bertscore', 'rougeL', 'weighted_bleu'): 111 | if args.metric == 'weighted_bleu': 112 | # NOTE: idea - negative weight for higher n-gram -> less copy 113 | # bleu_weights = (2, 1, -0.5, -1.5) 114 | bleu_weights = (0.5, 0.25, 0.2, 0.05) 115 | metric = partial(evaluate_per_sent, metric='bleu', bleu_weights=bleu_weights) 116 | else: 117 | metric = partial(evaluate_per_sent, metric=args.metric) 118 | else: 119 | raise NotImplementedError(f"Metric {args.metric} not implemented") 120 | generate_fn = load_generate_fn(args.guidance_method, 121 | metric=metric, 122 | metric_name=args.metric, 123 | max_new_tokens=gen_cfg.max_new_tokens, 124 | k=gen_cfg.top_k) 125 | 126 | elif args.guidance_method == 'fudge': 127 | generate_fn = load_generate_fn(args.guidance_method, 128 | model=model, 129 | max_new_tokens=gen_cfg.max_new_tokens, 130 | k=gen_cfg.top_k, 131 | complete_after=args.complete_after, 132 | disable_adapter_lm_forward=args.load_classifier is None) 133 | elif args.guidance_method == 'nado': 134 | generate_fn = load_generate_fn(args.guidance_method, 135 | model=model, 136 | max_new_tokens=gen_cfg.max_new_tokens, 137 | k=gen_cfg.top_k, 138 | alpha=args.alpha, 139 | disable_adapter_lm_forward=args.load_classifier is None) 140 | elif args.guidance_method == 'astar': 141 | generate_fn = load_generate_fn(args.guidance_method, 142 | model=model, 143 | max_new_tokens=gen_cfg.max_new_tokens, 144 | k=gen_cfg.top_k, 145 | disable_adapter_lm_forward=args.load_classifier is None, 146 | future_steps=5, 147 | lambda_weight=0.25, 148 | soft_forward=False) 149 | 150 | else: 151 | raise NotImplementedError(f"Guidance method {args.guidance_method} not implemented") 152 | 153 | trainer = GuidedGenerationPredictor( 154 | generate_fn=generate_fn, 155 | model=model, 156 | args=train_args, 157 | data_collator=DataCollatorForSeq2SeqTokenClassification( 158 | tokenizer, 159 | other_features_to_pad=['knowledge_ids'], 160 | ), 161 | tokenizer=tokenizer, 162 | ) 163 | preds = trainer.predict(dataset, **gen_cfg.__dict__) 164 | preds.predictions[preds.predictions == -100] = tokenizer.pad_token_id 165 | responses = tokenizer.batch_decode(preds.predictions, skip_special_tokens=True) 166 | 167 | os.makedirs(args.output_path, exist_ok=True) 168 | dataset_name = args.dataset 169 | if args.use_kilt_format: 170 | dataset_name = f'{dataset_name}-kilt' 171 | 172 | base_fname = args.model_name.replace("/", "-") 173 | if args.guidance_method == 'metric_guidance': 174 | base_fname += f'-{args.metric}' 175 | else: 176 | base_fname += f'-{args.guidance_method}-{args.disc_name}' 177 | out_fname = os.path.join(args.output_path, f'{dataset_name}-{base_fname}.jsonl') 178 | fout = open(out_fname, 'a') 179 | for i, (example, response) in tqdm(enumerate(zip(dataset, responses)), total=len(responses)): 180 | prompt = 
tokenizer.decode(example['input_ids'], skip_special_tokens=True) 181 | completion = tokenizer.decode(example['decoder_input_ids'], skip_special_tokens=True) 182 | result = { 183 | 'prompt': prompt, 184 | 'response': response, 185 | 'gold': completion, 186 | 'index': i + args.continue_from, 187 | } 188 | if indices is not None: 189 | result['index'] = indices[i] 190 | fout.write(json.dumps(result) + '\n') 191 | if args.print_output: 192 | print("prompt:\n") 193 | pprint(prompt) 194 | print() 195 | print(f"{args.model_name} Response:\n") 196 | print(response + '\n\n') 197 | print(f"Gold Response:\n") 198 | print(completion + '\n\n') 199 | fout.close() 200 | 201 | 202 | if __name__ == "__main__": 203 | main() 204 | -------------------------------------------------------------------------------- /scripts/run_openai_guided_generation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm.auto import tqdm 8 | from transformers import HfArgumentParser, AutoTokenizer, DataCollatorWithPadding 9 | from peft import PeftModel, LoraConfig, get_peft_model 10 | 11 | from kcd.attribute_classifier.attribute_classifier_model import DoubleHeadModel 12 | from kcd.token_classifier.model import T5DoubleHeadModel 13 | from kcd.text_data import load_text_data 14 | from kcd.classifier_guidance import load_generate_fn 15 | from kcd.openai_module import OpenAIModel, OpenAIAPIParameters, MockOpenAIModel 16 | from kcd.configs import GenerationConfig 17 | 18 | 19 | @dataclass 20 | class ExperimentArgs: 21 | test_data_path: str = 'data/wow-dev-kilt-processed.jsonl' 22 | output_path: str = 'generations/fudge' 23 | model_name: str = "google/flan-t5-xl" 24 | openai_model_name: str = 'text-davinci-003' 25 | dataset: str = 'wow' 26 | use_kilt_format: bool = False 27 | instruction_model: str = 'basic' # choices=['basic', 'openai', 'alpaca'] 28 | guidance_method: str = 'openai_fudge' # choices=['fudge', 'nado'] 29 | disc_name: str = '' 30 | use_logit_bias: bool = False 31 | pre_post_guidance: bool = False 32 | propose_topk: int = 50 33 | 34 | num_labels: int = 2 35 | load_checkpoint: str = None 36 | load_peft_checkpoint: str = None 37 | load_classifier: str = None 38 | human_indices: str = None 39 | 40 | batch_size: int = 1 41 | continue_from: int = 0 42 | max_num_gen: int = -1 43 | 44 | mock_debug: bool = False 45 | 46 | 47 | def main(): 48 | parser = HfArgumentParser((ExperimentArgs, GenerationConfig)) 49 | args, gen_cfg = parser.parse_args_into_dataclasses() 50 | 51 | print('cuda available:', torch.cuda.is_available()) 52 | # TODO: token space / wordpiece issue with understanding chatGPT output. 
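# (Presumably the issue is that the guidance classifier's Hugging Face tokenizer and the
# OpenAI tokenization segment text differently, so per-token logprobs from the chat API
# cannot be aligned with the classifier's token-level scores; only completion-style
# models are handled below.)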
53 | if args.openai_model_name == 'gpt-3.5-turbo': 54 | raise NotImplementedError("chatGPT not implemented yet") 55 | ################### load discriminator #################################### 56 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 57 | load_kwargs = { 58 | 'device_map': 'auto', 59 | 'torch_dtype': torch.float16, 60 | } 61 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 62 | tokenizer.padding_side = "left" # only for causal LM 63 | tokenizer.pad_token = tokenizer.eos_token 64 | 65 | if 't5' in args.model_name: 66 | model = T5DoubleHeadModel.from_pretrained(args.model_name, 67 | output_hidden_states=True, 68 | use_cache=True, 69 | num_labels=args.num_labels, 70 | pool_method='last', 71 | load_in_8bit=True, 72 | **load_kwargs) 73 | else: 74 | model = DoubleHeadModel.from_pretrained(args.model_name, 75 | output_hidden_states=True, 76 | use_cache=True, 77 | num_labels=args.num_labels, 78 | pool_method='last', 79 | **load_kwargs) 80 | 81 | if args.load_peft_checkpoint: 82 | model = PeftModel.from_pretrained(model, args.load_peft_checkpoint) 83 | if args.load_checkpoint: 84 | peft_config_path = os.path.join(os.path.dirname(args.load_checkpoint), 'adapter_model') 85 | peft_config = LoraConfig.from_pretrained(peft_config_path) 86 | model = get_peft_model(model, peft_config) 87 | incompatible = model.load_state_dict(torch.load(args.load_checkpoint), strict=False) 88 | assert (len(incompatible.missing_keys) == 1 and 89 | incompatible.missing_keys[0].endswith('lm_head.weight')) 90 | if args.load_classifier: 91 | ckpt = torch.load(args.load_classifier) 92 | ckpt = {k.replace('score.', ''): v for k, v in ckpt.items() if 'score' in k} 93 | model.score.load_state_dict(ckpt, strict=True) 94 | 95 | task = 'chat' if args.openai_model_name == 'gpt-3.5-turbo' else 'completion' 96 | if args.mock_debug: 97 | openai_model = MockOpenAIModel() 98 | else: 99 | openai_model = OpenAIModel(args.openai_model_name, task=task) 100 | 101 | # load guidance criteria 102 | if args.guidance_method == 'openai_fudge': 103 | parameters = OpenAIAPIParameters( 104 | max_tokens=1, # no lookahead 105 | temperature=gen_cfg.temperature, 106 | top_p=gen_cfg.top_p, 107 | logprobs=5) 108 | generate_fn = load_generate_fn(args.guidance_method, 109 | openai_model=openai_model, 110 | model=model, 111 | tokenizer=tokenizer, 112 | max_new_tokens=gen_cfg.max_new_tokens, 113 | k=6, 114 | pre_post_guidance=args.pre_post_guidance, 115 | use_logit_bias=args.use_logit_bias, 116 | propose_topk=args.propose_topk, 117 | parameters=parameters) 118 | else: 119 | raise NotImplementedError(f"Guidance method {args.guidance_method} not implemented") 120 | 121 | ############################## Data ######################################## 122 | print("loading dataset") 123 | dataset = load_text_data(path=args.test_data_path, 124 | instruction_model=args.instruction_model, 125 | task='completion', 126 | use_kilt_format=args.use_kilt_format, 127 | tokenize=True, 128 | add_trailing_newline='t5' not in args.model_name, 129 | tokenizer=tokenizer, 130 | no_label=True) 131 | text_dataset = load_text_data(path=args.test_data_path, 132 | instruction_model=args.instruction_model, 133 | task='completion', 134 | use_kilt_format=args.use_kilt_format, 135 | tokenize=False) 136 | indices = None 137 | if args.max_num_gen > 0: 138 | import random 139 | random.seed(42) 140 | indices = random.sample(range(len(dataset)), args.max_num_gen) 141 | dataset = dataset.select(indices) 142 | text_dataset = text_dataset.select(indices) 143 | 144 | if 
args.human_indices: 145 | with open(args.human_indices) as f: 146 | indices = [int(i) for i in f.readlines()] 147 | dataset = dataset.select(indices) 148 | text_dataset = text_dataset.select(indices) 149 | if args.continue_from > 0: 150 | dataset = dataset.select(range(args.continue_from, len(dataset))) 151 | text_dataset = text_dataset.select(range(args.continue_from, len(text_dataset))) 152 | 153 | collator = DataCollatorWithPadding(tokenizer=tokenizer) 154 | dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collator, shuffle=False) 155 | 156 | ###################### out file configuration ############################## 157 | os.makedirs(args.output_path, exist_ok=True) 158 | dataset_name = args.dataset 159 | if args.use_kilt_format: 160 | dataset_name = f'{dataset_name}-kilt' 161 | out_fname = os.path.join( 162 | args.output_path, 163 | f'{dataset_name}-{args.model_name.replace("/", "-")}-{args.disc_name}.jsonl') 164 | fout = open(out_fname, 'a') 165 | 166 | ####################### Generation Loop #################################### 167 | samples_pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Samples generated") 168 | for i, batch in samples_pbar: 169 | batch = batch.to(device) 170 | 171 | decoded_ids, usage, success = generate_fn(batch) 172 | if not success: 173 | print("the generation failed abruptly.") 174 | 175 | for j, (prompt_ids, response_ids) in enumerate(zip(batch['input_ids'], decoded_ids)): 176 | seqlen = len(prompt_ids) 177 | response_ids = response_ids[seqlen:] 178 | prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True) 179 | response = tokenizer.decode(response_ids, skip_special_tokens=True) 180 | result = { 181 | 'prompt': prompt, 182 | 'response': response, 183 | 'index': i * args.batch_size + j + args.continue_from, 184 | 'token_usage': usage[j].item(), 185 | 'success': success, 186 | } 187 | if indices is not None: 188 | result['index'] = indices[result['index']] 189 | result['gold'] = text_dataset['completion'][i * args.batch_size + j + 190 | args.continue_from] 191 | fout.write(json.dumps(result) + '\n') 192 | fout.close() 193 | 194 | 195 | if __name__ == "__main__": 196 | main() 197 | -------------------------------------------------------------------------------- /scripts/run_openai_ppl_mcts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass, field 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from transformers import AutoTokenizer, HfArgumentParser, DataCollatorWithPadding 9 | from peft import PeftModel, LoraConfig, get_peft_model 10 | 11 | from kcd.text_data import load_text_data 12 | from kcd.attribute_classifier.attribute_classifier_model import DoubleHeadModel 13 | from kcd.openai_module import OpenAIModel, OpenAIAPIParameters, MockOpenAIModel 14 | from kcd.classifier_guidance.ppl_mcts import PplMCTSConfig 15 | from kcd.classifier_guidance.openai_ppl_mcts import OpenAIMCTS 16 | from kcd.configs import GenerationConfig 17 | 18 | 19 | @dataclass 20 | class ExperimentArgs: 21 | lm_name: str = field(default="gpt2-xl") 22 | openai_model_name: str = 'text-davinci-003' 23 | num_labels: int = field(default=2) 24 | attr_idx: int = 1 25 | dataset: str = 'wow' 26 | use_kilt_format: bool = False 27 | test_data_path: str = field(default="data/pplm_prompts.csv") 28 | output_path: str = 'generations/ppl_mcts' 29 | disc_name: str = '' 30 | instruction_model: str = 'basic' # 
choices=['basic', 'openai', 'alpaca'] 31 | load_peft_checkpoint: str = None 32 | load_checkpoint: str = None 33 | load_classifier: str = field(default=None) 34 | batch_size: int = 1 35 | continue_from: int = 0 36 | human_indices: str = None 37 | 38 | max_num_gen: int = -1 39 | 40 | mock_debug: bool = False 41 | 42 | 43 | def main(): 44 | parser = HfArgumentParser([ExperimentArgs, PplMCTSConfig, GenerationConfig]) 45 | args, mcts_args, gen_cfg = parser.parse_args_into_dataclasses() 46 | 47 | print('cuda available:', torch.cuda.is_available()) 48 | ################### load discriminator #################################### 49 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 50 | load_kwargs = { 51 | 'device_map': 'auto', 52 | 'torch_dtype': torch.float16, 53 | } 54 | tokenizer = AutoTokenizer.from_pretrained('gpt2') 55 | tokenizer.padding_side = "left" # only for causal LM 56 | tokenizer.pad_token = tokenizer.eos_token 57 | 58 | model = DoubleHeadModel.from_pretrained(args.lm_name, 59 | output_hidden_states=True, 60 | use_cache=True, 61 | num_labels=args.num_labels, 62 | pool_method='last', 63 | **load_kwargs) 64 | 65 | if args.load_peft_checkpoint: 66 | model = PeftModel.from_pretrained(model, args.load_peft_checkpoint) 67 | if args.load_checkpoint: 68 | peft_config_path = os.path.join(os.path.dirname(args.load_checkpoint), 'adapter_model') 69 | peft_config = LoraConfig.from_pretrained(peft_config_path) 70 | model = get_peft_model(model, peft_config) 71 | incompatible = model.load_state_dict(torch.load(args.load_checkpoint), strict=False) 72 | assert (len(incompatible.missing_keys) == 1 and 73 | incompatible.missing_keys[0].endswith('lm_head.weight')) 74 | if args.load_classifier: 75 | ckpt = torch.load(args.load_classifier) 76 | ckpt = {k.replace('score.', ''): v for k, v in ckpt.items() if 'score' in k} 77 | model.score.load_state_dict(ckpt, strict=True) 78 | 79 | ############################## OPENAI MODEL ############################### 80 | if args.mock_debug: 81 | openai_model = MockOpenAIModel() 82 | else: 83 | openai_model = OpenAIModel(args.openai_model_name) 84 | parameters = OpenAIAPIParameters( 85 | max_tokens=1, # no lookahead 86 | temperature=gen_cfg.temperature, 87 | top_p=gen_cfg.top_p, 88 | logprobs=5) 89 | ############################## Data ######################################## 90 | print("loading dataset") 91 | dataset = load_text_data(path=args.test_data_path, 92 | instruction_model=args.instruction_model, 93 | task='completion', 94 | use_kilt_format=args.use_kilt_format, 95 | tokenize=True, 96 | tokenizer=tokenizer, 97 | no_label=True) 98 | text_dataset = load_text_data(path=args.test_data_path, 99 | instruction_model=args.instruction_model, 100 | task='completion', 101 | use_kilt_format=args.use_kilt_format, 102 | tokenize=False) 103 | indices = None 104 | if args.max_num_gen > 0: 105 | import random 106 | random.seed(42) 107 | indices = random.sample(range(len(dataset)), args.max_num_gen) 108 | dataset = dataset.select(indices) 109 | text_dataset = text_dataset.select(indices) 110 | 111 | if args.human_indices: 112 | with open(args.human_indices) as f: 113 | indices = [int(i) for i in f.readlines()] 114 | dataset = dataset.select(indices) 115 | text_dataset = text_dataset.select(indices) 116 | if args.continue_from > 0: 117 | dataset = dataset.select(range(args.continue_from, len(dataset))) 118 | text_dataset = text_dataset.select(range(args.continue_from, len(text_dataset))) 119 | 120 | collator = DataCollatorWithPadding(tokenizer=tokenizer) 121 | 
dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collator, shuffle=False) 122 | print("dataset loaded") 123 | ############################### MCTS ####################################### 124 | batch_size = args.batch_size 125 | MCTS = OpenAIMCTS(mcts_args, 126 | tokenizer, 127 | openai_model, 128 | parameters, 129 | model, 130 | batch_size=batch_size, 131 | top_k=6, 132 | num_labels=args.num_labels, 133 | unused_token_id=tokenizer.unk_token_id, 134 | device=device) 135 | 136 | samples_pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Samples generated") 137 | 138 | ###################### out file configuration ############################## 139 | os.makedirs(args.output_path, exist_ok=True) 140 | dataset_name = args.dataset 141 | if args.use_kilt_format: 142 | dataset_name = f'{dataset_name}-kilt' 143 | out_fname = os.path.join( 144 | args.output_path, f'{dataset_name}-{args.lm_name.replace("/", "-")}-{args.disc_name}.jsonl') 145 | fout = open(out_fname, 'a') 146 | ####################### Generation Loop #################################### 147 | for i, batch in samples_pbar: 148 | batch = batch.to(device) 149 | labels = torch.zeros(batch['input_ids'].shape[0], args.num_labels, device=device) 150 | labels[:, args.attr_idx] = 1 151 | 152 | MCTS.set_labels(labels) 153 | _, decoded_ids = MCTS.search(batch, tokens_to_generate=gen_cfg.max_new_tokens) 154 | 155 | for j, (prompt_ids, response_ids) in enumerate(zip(batch['input_ids'], decoded_ids)): 156 | seqlen = len(prompt_ids) 157 | response_ids = response_ids[seqlen:] 158 | prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True) 159 | response = tokenizer.decode(response_ids, skip_special_tokens=True) 160 | result = { 161 | 'prompt': prompt, 162 | 'response': response, 163 | 'index': i * batch_size + j + args.continue_from, 164 | 'token_usage': MCTS.lm_step.token_usages[j], 165 | } 166 | if indices is not None: 167 | result['index'] = indices[result['index']] 168 | result['gold'] = text_dataset['completion'][i * batch_size + j + args.continue_from] 169 | fout.write(json.dumps(result) + '\n') 170 | fout.close() 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- /scripts/shell/baselines/baseline_run.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | model=$2 3 | task=$3 # [wow | summarization] 4 | # compute number of gpus 5 | arrIDs=(${ids//,/ }) 6 | GPU_PER_NODE="${#arrIDs[@]}" 7 | 8 | # decide python launcher 9 | if [ $GPU_PER_NODE = 1 ]; then 10 | echo "Using 1 GPU: use simple python launcher..." 11 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 12 | else 13 | echo "Using multi-GPU: using torchrun launcher..." 14 | launcher="CUDA_VISIBLE_DEVICES=$ids WORLD_SIZE=$GPU_PER_NODE torchrun --nproc_per_node $GPU_PER_NODE" 15 | fi 16 | 17 | if [ $task = 'wow' ]; then 18 | task_options="--data_path data/cached/wow/test_unseen.jsonl --dataset wow --max_new_tokens 32" 19 | elif [ $task = 'summarization' ]; then 20 | task_options="--data_path cnn_dailymail --dataset cnn_dailymail --max_new_tokens 64" 21 | else 22 | echo $task not defined. 
23 | exit 24 | fi 25 | 26 | script="$launcher baseline/huggingface_run.py \ 27 | $task_options \ 28 | --use_kilt_format False \ 29 | --task completion --top_p 0.95 \ 30 | --model_name $model \ 31 | --load_8bit \ 32 | --per_device_eval_batch_size 4 \ 33 | --output_dir generations/baseline \ 34 | --predict_with_generate" 35 | 36 | eval $script 37 | -------------------------------------------------------------------------------- /scripts/shell/baselines/openai_run.sh: -------------------------------------------------------------------------------- 1 | python baseline/openai_run.py \ 2 | --data_path ../kcd_data/cached/wow/test_unseen.jsonl --dataset wow \ 3 | --use_kilt_format False \ 4 | --task completion --max_tokens 32 --top_p 0.1 \ 5 | --model_name text-davinci-003 --human_indices generations/wow_human_indices.txt 6 | -------------------------------------------------------------------------------- /scripts/shell/baselines/sft_baseline_run.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | model=$2 3 | task=$3 # [wow | summarization] 4 | # compute number of gpus 5 | arrIDs=(${ids//,/ }) 6 | GPU_PER_NODE="${#arrIDs[@]}" 7 | 8 | # decide python launcher 9 | if [ $GPU_PER_NODE = 1 ]; then 10 | echo "Using 1 GPU: use simple python launcher..." 11 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 12 | else 13 | echo "Using multi-GPU: using torchrun launcher..." 14 | launcher="CUDA_VISIBLE_DEVICES=$ids WORLD_SIZE=$GPU_PER_NODE torchrun --nproc_per_node $GPU_PER_NODE" 15 | fi 16 | 17 | if [ $task = 'wow' ]; then 18 | task_options="--data_path data/cached/wow/test_unseen.jsonl --dataset wow --max_new_tokens 32" 19 | elif [ $task = 'summarization' ]; then 20 | task_options="--data_path cnn_dailymail --dataset cnn_dailymail --max_new_tokens 64" 21 | else 22 | echo $task not defined. 23 | exit 24 | fi 25 | 26 | script="$launcher baseline/huggingface_run.py \ 27 | $task_options \ 28 | --use_kilt_format False \ 29 | --task completion --top_p 0.95 \ 30 | --model_name $model \ 31 | --load_8bit \ 32 | --load_checkpoint saved_models/flan-t5-xl-sft-wow/checkpoint-best/pytorch_model.bin \ 33 | --per_device_eval_batch_size 4 \ 34 | --output_dir generations/baseline \ 35 | --predict_with_generate" 36 | 37 | eval $script 38 | -------------------------------------------------------------------------------- /scripts/shell/data_process/partial_neg_gen.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | data=$2 3 | bs=$3 4 | 5 | # compute number of gpus 6 | arrIDs=(${ids//,/ }) 7 | GPU_PER_NODE="${#arrIDs[@]}" 8 | 9 | # decide python launcher 10 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 11 | 12 | if [ $data = 'wow' ]; then 13 | data_options="--dataset_name wow --dataset_path data/cached/wow/train.jsonl \ 14 | --max_neg_samples 10000 --max_new_tokens 64" 15 | elif [ $data = 'cnn_dailymail' ]; then 16 | data_options="--dataset_path cnn_dailymail --dataset_name cnn_dailymail \ 17 | --max_neg_samples 100000 --max_new_tokens 64" 18 | else 19 | echo $data not recognized. 
20 | exit 21 | fi 22 | 23 | script="$launcher kcd/partial_negative.py \ 24 | $data_options \ 25 | --per_device_eval_batch_size $bs \ 26 | --temperature 1.4 --top_p 1 --output_dir data/cached \ 27 | --eval_accumulation_steps 200" 28 | 29 | eval $script 30 | -------------------------------------------------------------------------------- /scripts/shell/data_process/preprocess_wow.sh: -------------------------------------------------------------------------------- 1 | N=$1 2 | 3 | if [[ $N = '' ]]; then 4 | N=20 5 | fi 6 | 7 | ParlAI=$HOME/ParlAI 8 | 9 | # 1. Train 10 | python kcd/wizard_of_wikipedia/preprocess.py \ 11 | --in_file $ParlAI/data/wizard_of_wikipedia/train.json \ 12 | --out_file data/cached/wow/train.jsonl \ 13 | --keep_last_n $N 14 | 15 | # 2. valid 16 | python kcd/wizard_of_wikipedia/preprocess.py \ 17 | --in_file $ParlAI/data/wizard_of_wikipedia/valid_random_split.json \ 18 | --out_file data/cached/wow/dev_seen.jsonl \ 19 | --keep_last_n $N 20 | python kcd/wizard_of_wikipedia/preprocess.py \ 21 | --in_file $ParlAI/data/wizard_of_wikipedia/valid_topic_split.json \ 22 | --out_file data/cached/wow/dev_unseen.jsonl \ 23 | --keep_last_n $N 24 | 25 | # 3. test 26 | python kcd/wizard_of_wikipedia/preprocess.py \ 27 | --in_file $ParlAI/data/wizard_of_wikipedia/test_random_split.json \ 28 | --out_file data/cached/wow/test_seen.jsonl \ 29 | --keep_last_n $N 30 | python kcd/wizard_of_wikipedia/preprocess.py \ 31 | --in_file $ParlAI/data/wizard_of_wikipedia/test_topic_split.json \ 32 | --out_file data/cached/wow/test_unseen.jsonl \ 33 | --keep_last_n $N 34 | -------------------------------------------------------------------------------- /scripts/shell/data_process/random_neg.sh: -------------------------------------------------------------------------------- 1 | data=$1 2 | 3 | 4 | if [ $data = 'wow' ]; then 5 | data_options="--dataset_name wow --dataset_path data/cached/wow/train.jsonl \ 6 | --use_kilt_format False" 7 | elif [ $data = 'cnn_dailymail' ]; then 8 | data_options="--dataset_path cnn_dailymail --dataset_name cnn_dailymail \ 9 | --use_kilt_format False" 10 | else 11 | echo $data not recognized. 
12 | exit 13 | fi 14 | 15 | script="python kcd/sample_negative.py $data_options" 16 | eval $script 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /scripts/shell/eval/class_prob.sh: -------------------------------------------------------------------------------- 1 | generation_file=$1 2 | 3 | function eval_prob(){ 4 | ids=$1 5 | task=$2 6 | name=$3 7 | skip=$4 8 | script="CUDA_VISIBLE_DEVICES=$ids python scripts/eval/evaluate_generations_with_classifier.py \ 9 | --model_name google/flan-t5-xl \ 10 | --load_8bit \ 11 | --dataset $task \ 12 | --test_data_path data/cached/wow/test_unseen.jsonl \ 13 | --generations_path $name \ 14 | --load_checkpoint saved_models/flan-t5-xl-DecoderDisc-$task-EOS/checkpoint-best/pytorch_model.bin \ 15 | --batch_size 8" 16 | 17 | if [[ $name == *"alpaca"* ]]; then 18 | script="$script --causal_lm_generations" 19 | fi 20 | 21 | if [ $skip ]; then 22 | script="$script --skip_no_knowledge" 23 | fi 24 | 25 | eval $script 26 | } 27 | 28 | eval_prob 5 wow $generation_file 1 29 | -------------------------------------------------------------------------------- /scripts/shell/eval/test_t5_token_classifier.sh: -------------------------------------------------------------------------------- 1 | ##### stdin #### 2 | ids=$1 3 | type=$2 # [EOS, ALL, RAND, RIPA] 4 | dataset=$3 5 | bs=$4 6 | V2=$5 7 | pool=$6 8 | ################ 9 | # compute number of gpus 10 | arrIDs=(${ids//,/ }) 11 | GPU_PER_NODE="${#arrIDs[@]}" 12 | 13 | # decide python launcher 14 | if [ $ids = '' ]; then 15 | echo "no gpu..." 16 | launcher="CUDA_VISIBLE_DEVICES= python" 17 | elif [ $GPU_PER_NODE = 1 ]; then 18 | echo "Using 1 GPU: use simple python launcher..." 19 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 20 | else 21 | echo "Using multi-GPU: using torchrun launcher..." 22 | launcher="CUDA_VISIBLE_DEVICES=$ids WORLD_SIZE=$GPU_PER_NODE torchrun --nproc_per_node $GPU_PER_NODE" 23 | fi 24 | 25 | if [[ $bs = '' ]]; then 26 | bs=64 27 | fi 28 | 29 | use_kilt_format=False 30 | size=xl 31 | if [[ $type = '' ]]; then 32 | echo type was not provided. Defaults to RIPA... 33 | type=RIPA 34 | fi 35 | if [[ $use_kilt_format = True ]]; then 36 | model_name=flan-t5-$size-DecoderDisc-$dataset-kilt-$type 37 | else 38 | model_name=flan-t5-$size-DecoderDisc-$dataset-$type 39 | fi 40 | if [ $V2 = 1 ]; then 41 | model_name=$model_name-v2 42 | fi 43 | 44 | if [ $dataset = wow ]; then 45 | validation_data_path=data/cached/wow_test_unseen_augmented_neg_random 46 | test_data_path=data/cached/wow_test_augmented_neg_google-flan-t5-xl-0.1 47 | elif [ $dataset = cnn_dailymail ]; then 48 | validation_data_path=data/cached/cnn_dailymail_test_augmented_neg_google-flan-t5-xl-0.1 49 | test_data_path=data/cached/cnn_dailymail_test_augmented_neg_google-flan-t5-xl-0.1 50 | else 51 | echo $dataset unknown. 
52 | exit 53 | fi 54 | 55 | script="$launcher kcd/token_classifier/train.py \ 56 | --model_name google/flan-t5-$size \ 57 | --is_decoder \ 58 | --num_labels 2 \ 59 | --pool_method $pool \ 60 | --wandb_project_name knowledge-classifier \ 61 | --wandb_run_name eval_$model_name-$pool \ 62 | --dataset $dataset \ 63 | --use_kilt_format $use_kilt_format \ 64 | --train_data_path $validation_data_path \ 65 | --validation_data_path $validation_data_path \ 66 | --test_data_path $test_data_path \ 67 | --test_only \ 68 | --load_checkpoint saved_models/$model_name/checkpoint-best/pytorch_model.bin \ 69 | --output_dir saved_models/$model_name \ 70 | --use_lora --train_8bit \ 71 | --eval_accumulation_steps 100 \ 72 | --per_device_eval_batch_size $bs" 73 | 74 | if [ $V2 = 1 ]; then 75 | script="$script --v2" 76 | fi 77 | 78 | eval $script 79 | -------------------------------------------------------------------------------- /scripts/shell/eval/unieval.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | 3 | export CUDA_VISIBLE_DEVICES=$ids 4 | 5 | ############################## Summarization ################################## 6 | cnn_model_name=( 7 | baseline/cnn_dailymail-bigscience-T0pp 8 | baseline/cnn_dailymail-google-flan-t5-xl 9 | baseline/cnn_dailymail-google-flan-t5-xxl 10 | baseline/cnn_dailymail-openai_gpt-3.5-turbo 11 | baseline/cnn_dailymail-openai_text-davinci-003 12 | fudge/cnn_dailymail-google-flan-t5-xl-fudge-DecoderDisc-cnn_dailymail-RAND 13 | nado/cnn_dailymail-google-flan-t5-xl-nado-DecoderDisc-cnn_dailymail-ALL-v2-alpha0.25 14 | ppl_mcts/cnn_dailymail-google-flan-t5-xl-DecoderDisc-cnn_dailymail-RAND 15 | fudge/cnn_dailymail-google-flan-t5-xl-fudge-DecoderDisc-cnn_dailymail-RIPA 16 | ppl_mcts/cnn_dailymail-google-flan-t5-xl-DecoderDisc-cnn_dailymail-RIPA 17 | ) 18 | 19 | for name in "${cnn_model_name[@]}"; do 20 | save_name="${name/"/"/-}" 21 | 22 | script="python UniEval/run.py \ 23 | --task summarization \ 24 | --generations_path generations/${name}.jsonl \ 25 | --dataset_path cnn_dailymail \ 26 | --save_name $save_name" 27 | if [[ $name == *"alpaca"* ]]; then 28 | script="$script --causal_lm_generations" 29 | fi 30 | 31 | eval $script 32 | 33 | # MFMA score 34 | mfma_script="python scripts/evaluate_summary_mfma.py \ 35 | --dataset cnn_dailymail \ 36 | --test_data_path cnn_dailymail \ 37 | --batch_size 8 \ 38 | --generations_path generations/${name}.jsonl" 39 | if [[ $name == *"alpaca"* ]]; then 40 | mfma_script="$mfma_script --causal_lm_generations" 41 | fi 42 | 43 | eval $mfma_script 44 | 45 | done 46 | 47 | ############################### WoW Dialogue ################################### 48 | wow_model_name=( 49 | baseline/wow-openai_gpt-3.5-turbo 50 | baseline/wow-openai_text-davinci-003 51 | baseline/wow-bigscience-T0pp 52 | baseline/wow-google-flan-t5-xl 53 | baseline/wow-google-flan-t5-xxl 54 | baseline/wow-google-flan-t5-xl-sft 55 | fudge/wow-google-flan-t5-xl-fudge-DecoderDisc-wow-RAND 56 | nado/wow-google-flan-t5-xl-nado-DecoderDisc-wow-ALL-v2-alpha0.25 57 | ppl_mcts/wow-google-flan-t5-xl-DecoderDisc-wow-RAND 58 | ppl_mcts/wow-google-flan-t5-xl-DecoderDisc-wow-RIPA 59 | fudge/wow-google-flan-t5-xl-fudge-DecoderDisc-wow-RIPA 60 | ) 61 | 62 | function run_dialog_eval (){ 63 | ids=$1 64 | name=$2 65 | skip=$3 66 | human=$4 67 | 68 | save_name="${name/"/"/-}" 69 | 70 | script="CUDA_VISIBLE_DEVICES=$ids python UniEval/run.py \ 71 | --task dialogue \ 72 | --generations_path generations/${name}.jsonl \ 73 | --dataset_path 
data/cached/wow/test_unseen.jsonl \ 74 | --save_name $save_name" 75 | if [[ $name == *"alpaca"* ]]; then 76 | script="$script --causal_lm_generations" 77 | fi 78 | 79 | if [[ $skip = 1 ]]; then 80 | script="$script --skip_no_knowledge" 81 | fi 82 | 83 | if [[ $human = 1 ]]; then 84 | script="$script --human_indices generations/wow_human_indices.txt" 85 | fi 86 | 87 | eval $script 88 | } 89 | 90 | for model in "${wow_model_name[@]}"; do 91 | run_dialog_eval $ids $model 1 0 92 | done 93 | -------------------------------------------------------------------------------- /scripts/shell/guided_run.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | exp=$2 3 | metric=$3 4 | task=$4 5 | bs=$5 6 | V2=$6 7 | human=$7 8 | quick=$8 9 | cont=$9 10 | # compute number of gpus 11 | arrIDs=(${ids//,/ }) 12 | GPU_PER_NODE="${#arrIDs[@]}" 13 | 14 | # decide python launcher 15 | if [ $GPU_PER_NODE = 1 ]; then 16 | echo "Using 1 GPU: use simple python launcher..." 17 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 18 | else 19 | echo "Using multi-GPU: using torchrun launcher..." 20 | launcher="CUDA_VISIBLE_DEVICES=$ids WORLD_SIZE=$GPU_PER_NODE torchrun --nproc_per_node $GPU_PER_NODE" 21 | fi 22 | 23 | DATADIR=data/cached 24 | OUTDIR=generations 25 | CKPTDIR=saved_models 26 | CKPTNAME=best 27 | 28 | if [[ $bs = '' ]]; then 29 | bs=16 30 | fi 31 | 32 | model_name=flan-t5-xl-DecoderDisc-$task-$metric 33 | disc_name=DecoderDisc-$task-$metric 34 | 35 | if [ $V2 = 1 ]; then 36 | model_name=$model_name-v2 37 | disc_name=$disc_name-v2 38 | fi 39 | 40 | if [ $task = 'wow' ]; then 41 | task_options="--test_data_path $DATADIR/wow/test_unseen.jsonl --dataset wow --max_new_tokens 32" 42 | elif [ $task = 'cnn_dailymail' ]; then 43 | task_options="--test_data_path cnn_dailymail --dataset cnn_dailymail --max_new_tokens 64" 44 | else 45 | echo $task not defined. 46 | exit 47 | fi 48 | 49 | if [[ $exp = 'fudge' ]]; then 50 | 51 | script="$launcher scripts/run_guided_generation.py \ 52 | $task_options \ 53 | --use_kilt_format False \ 54 | --top_p 0.95 --top_k 50 --temperature 1.0 \ 55 | --model_name google/flan-t5-xl \ 56 | --guidance_method fudge \ 57 | --load_8bit \ 58 | --load_checkpoint $CKPTDIR/$model_name/checkpoint-$CKPTNAME/pytorch_model.bin \ 59 | --per_device_eval_batch_size $bs \ 60 | --output_dir $OUTDIR/fudge \ 61 | --disc_name $disc_name" 62 | 63 | elif [[ $exp = 'nado' ]]; then 64 | 65 | script="$launcher scripts/run_guided_generation.py \ 66 | $task_options \ 67 | --use_kilt_format False \ 68 | --top_p 0.95 --top_k 50 --temperature 1.0 \ 69 | --model_name google/flan-t5-xl \ 70 | --guidance_method nado \ 71 | --load_8bit True \ 72 | --alpha 0.25 \ 73 | --load_checkpoint $CKPTDIR/$model_name/checkpoint-$CKPTNAME/pytorch_model.bin \ 74 | --per_device_eval_batch_size $bs \ 75 | --output_dir $OUTDIR/nado \ 76 | --disc_name $disc_name" 77 | 78 | elif [[ $exp = 'astar' ]]; then 79 | # NOTE: this implementation is very slow and infeasible. 80 | script="$launcher scripts/run_guided_generation.py \ 81 | $task_options \ 82 | --use_kilt_format False \ 83 | --top_p 0.95 --top_k 50 --temperature 1.0 \ 84 | --model_name google/flan-t5-xl \ 85 | --guidance_method astar \ 86 | --load_8bit \ 87 | --load_checkpoint $CKPTDIR/$model_name/checkpoint-$CKPTNAME/pytorch_model.bin \ 88 | --per_device_eval_batch_size $bs \ 89 | --output_dir $OUTDIR/astar \ 90 | --disc_name $disc_name" 91 | 92 | else 93 | echo not implemented $exp yet. 
94 | exit 95 | fi 96 | 97 | if [ $V2 = 1 ]; then 98 | script="$script --v2" 99 | fi 100 | 101 | if [ $human = 1 ]; then 102 | script="$script --human_indices generations/${task}_human_indices.txt" 103 | fi 104 | 105 | if [[ $cont != '' ]]; then 106 | script="$script --continue_from $cont" 107 | fi 108 | 109 | if [ $quick = 1 ]; then 110 | script="$script --complete_after 10" 111 | fi 112 | 113 | eval $script 114 | -------------------------------------------------------------------------------- /scripts/shell/openai_guided_run.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | type=$2 # [EOS, RIPA, ALL, RAND] 3 | bs=$3 4 | rootdir=$4 5 | CONT=$5 6 | human=$6 7 | guidance=$7 8 | use_t5=$8 9 | chatgpt=$9 10 | debug=${10} 11 | 12 | export CUDA_VISIBLE_DEVICES=$ids 13 | 14 | if [ $use_t5 = 1 ]; then 15 | model_name=google/flan-t5-xl 16 | load_name=flan-t5-xl 17 | else 18 | model_name=gpt2-xl 19 | load_name=gpt2-xl 20 | fi 21 | 22 | if [ $chatgpt = 1 ]; then 23 | openai_model_name=gpt-3.5-turbo 24 | else 25 | openai_model_name=text-davinci-003 26 | fi 27 | 28 | script="python scripts/run_openai_guided_generation.py \ 29 | --model_name $model_name \ 30 | --openai_model_name $openai_model_name \ 31 | --num_labels 2 \ 32 | --dataset wow \ 33 | --test_data_path $rootdir/data/cached/wow/test_unseen.jsonl \ 34 | --output_path $rootdir/generations/fudge \ 35 | --guidance_method openai_fudge \ 36 | --instruction_model basic \ 37 | --load_checkpoint $rootdir/saved_models/$load_name-DecoderDisc-wow-$type/checkpoint-best/pytorch_model.bin \ 38 | --batch_size $bs" 39 | 40 | disc_name=$openai_model_name-fudge-$type 41 | 42 | if [ $debug = 1 ]; then 43 | script="$script --mock_debug" 44 | fi 45 | 46 | if [ $human = 1 ]; then 47 | script="$script --human_indices $rootdir/generations/wow_human_indices.txt" 48 | fi 49 | 50 | if [[ $CONT != 0 ]]; then 51 | script="$script --continue_from $CONT" 52 | fi 53 | 54 | if [ $guidance = 1 ]; then 55 | script="$script --use_logit_bias True --propose_topk 50" 56 | disc_name=$disc_name-logit_bias 57 | elif [ $guidance = 2 ]; then 58 | script="$script --use_logit_bias False" 59 | disc_name=$disc_name-post_guidance 60 | elif [ $guidance = 3 ]; then 61 | script="$script --use_logit_bias True --propose_topk 50 --pre_post_guidance" 62 | disc_name=$disc_name-pre_post_guidance 63 | fi 64 | 65 | script="$script --disc_name $disc_name" 66 | 67 | eval $script 68 | -------------------------------------------------------------------------------- /scripts/shell/openai_mcts_run.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | type=$2 # [EOS, RIPA, RIPA, ALL, RAND] 3 | bs=$3 4 | 5 | export CUDA_VISIBLE_DEVICES=$ids 6 | 7 | python scripts/run_openai_ppl_mcts.py \ 8 | --lm_name gpt2-xl \ 9 | --openai_model_name text-davinci-003 \ 10 | --num_labels 2 \ 11 | --attr_idx 1 \ 12 | --dataset wow \ 13 | --test_data_path data/cached/wow/test_unseen.jsonl \ 14 | --output_path generations/ppl_mcts \ 15 | --disc_name text-davinci-003-$type \ 16 | --instruction_model basic \ 17 | --load_checkpoint saved_models/gpt2-xl-DecoderDisc-wow-$type/checkpoint-best/pytorch_model.bin \ 18 | --batch_size $bs \ 19 | --max_num_gen 512 \ 20 | --num_simulations 20 \ 21 | --mock_debug 22 | -------------------------------------------------------------------------------- /scripts/shell/ppl_mcts_run.sh: -------------------------------------------------------------------------------- 1 | ids=$1 2 | typ=$2 3 | metric=$3 4 | 
task=$4 5 | bs=$5 6 | CLASS_ONLY=$6 7 | V2=$7 8 | CONT=$8 9 | human=$9 10 | quick_option=${10} 11 | 12 | if [[ $bs = '' ]]; then 13 | bs=16 14 | fi 15 | 16 | if [[ $CONT = '' ]]; then 17 | CONT=0 18 | fi 19 | 20 | DATADIR=data/cached 21 | OUTDIR=generations 22 | CKPTDIR=saved_models 23 | CKPTNAME=best 24 | 25 | if [ $task = 'wow' ]; then 26 | task_options="--test_data_path $DATADIR/wow/test_unseen.jsonl --dataset wow --max_new_tokens 32" 27 | elif [ $task = 'cnn_dailymail' ]; then 28 | task_options="--test_data_path cnn_dailymail --dataset cnn_dailymail --max_new_tokens 64" 29 | else 30 | echo $task not defined. 31 | exit 32 | fi 33 | 34 | model_name=flan-t5-xl-DecoderDisc-$task-$typ 35 | disc_name=DecoderDisc-$task-$typ 36 | 37 | if [ $V2 = 1 ]; then 38 | model_name=$model_name-v2 39 | disc_name=$disc_name-v2 40 | fi 41 | 42 | 43 | if [ $CLASS_ONLY = 1 ]; then 44 | ckpt_options="--use_mlp_classifier --load_classifier $CKPTDIR/$model_name-only_classifier/checkpoint-$CKPTNAME/pytorch_model.bin" 45 | else 46 | ckpt_options="--load_checkpoint $CKPTDIR/$model_name/checkpoint-$CKPTNAME/pytorch_model.bin" 47 | fi 48 | 49 | copy_penalty=1.0 50 | script="CUDA_VISIBLE_DEVICES=$ids python scripts/run_ppl_mcts.py \ 51 | $task_options \ 52 | --use_kilt_format False \ 53 | --lm_name google/flan-t5-xl \ 54 | --num_labels 2 --attr_idx 1 \ 55 | --load_8bit \ 56 | $ckpt_options \ 57 | --output_path $OUTDIR/ppl_mcts \ 58 | --batch_size $bs \ 59 | --num_simulations 50 \ 60 | --knowledge_copy_penalty $copy_penalty \ 61 | --top_k 50 \ 62 | --temperature 1.0 \ 63 | --continue_from $CONT" 64 | 65 | 66 | if [[ $metric != '' ]]; then 67 | script="$script --guide_using_metric True --metric_name $metric --disc_name $metric" 68 | else 69 | script="$script --disc_name $disc_name" 70 | fi 71 | 72 | if [ $V2 = 1 ]; then 73 | script="$script --v2" 74 | fi 75 | 76 | if [ $human = 1 ]; then 77 | script="$script --human_indices generations/${task}_human_indices.txt" 78 | fi 79 | 80 | if [[ $quick_option != 0 ]]; then 81 | # this will change every time 82 | script="$script --complete_after $quick_option" 83 | fi 84 | 85 | eval $script 86 | -------------------------------------------------------------------------------- /scripts/shell/train/sft_t5.sh: -------------------------------------------------------------------------------- 1 | ##### stdin #### 2 | ids=$1 3 | ################ 4 | # compute number of gpus 5 | arrIDs=(${ids//,/ }) 6 | GPU_PER_NODE="${#arrIDs[@]}" 7 | 8 | # decide python launcher 9 | if [ $GPU_PER_NODE = 1 ]; then 10 | echo "Using 1 GPU: use simple python launcher..." 11 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 12 | else 13 | echo "Using multi-GPU: using torchrun launcher..." 
14 | launcher="CUDA_VISIBLE_DEVICES=$ids WORLD_SIZE=$GPU_PER_NODE torchrun --nproc_per_node $GPU_PER_NODE" 15 | fi 16 | 17 | dataset=wow 18 | use_kilt_format=False 19 | size=xl 20 | 21 | if [[ $use_kilt_format = True ]]; then 22 | model_name=flan-t5-$size-sft-$dataset-kilt 23 | else 24 | model_name=flan-t5-$size-sft-$dataset 25 | fi 26 | lr=1e-5 27 | bs=16 28 | grad_accum=2 29 | 30 | 31 | script="$launcher kcd/token_classifier/train.py \ 32 | --sft \ 33 | --model_name google/flan-t5-$size \ 34 | --is_decoder \ 35 | --wandb_project_name knowledge-sft \ 36 | --wandb_run_name $model_name \ 37 | --dataset $dataset \ 38 | --use_kilt_format $use_kilt_format \ 39 | --train_data_path data/cached/wow/train.jsonl \ 40 | --validation_data_path data/cached/wow/dev_unseen.jsonl \ 41 | --output_dir saved_models/$model_name \ 42 | --use_lora --bf16 --train_8bit \ 43 | --learning_rate $lr \ 44 | --warmup_steps 0 \ 45 | --weight_decay 0.01 \ 46 | --num_train_epochs 5 \ 47 | --max_steps 2000 \ 48 | --logging_steps 10 \ 49 | --eval_accumulation_steps 100 \ 50 | --eval_steps 500 \ 51 | --save_steps 500 \ 52 | --save_total_limit 2 \ 53 | --load_best_model_at_end False \ 54 | --per_device_train_batch_size $bs \ 55 | --per_device_eval_batch_size $bs \ 56 | --gradient_accumulation_steps $grad_accum" 57 | 58 | eval $script -------------------------------------------------------------------------------- /scripts/shell/train/train_t5_token_classifier.sh: -------------------------------------------------------------------------------- 1 | ##### stdin #### 2 | ids=$1 3 | type=$2 # [EOS, ALL, RAND, RIPA] 4 | ONLY_CLASS=$3 5 | V2=$4 6 | V2_REG=$5 7 | FINETUNE=$6 8 | ################ 9 | # compute number of gpus 10 | arrIDs=(${ids//,/ }) 11 | GPU_PER_NODE="${#arrIDs[@]}" 12 | 13 | # decide python launcher 14 | if [ $GPU_PER_NODE = 1 ]; then 15 | echo "Using 1 GPU: use simple python launcher..." 16 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 17 | else 18 | echo "Using multi-GPU: using torchrun launcher..." 19 | launcher="CUDA_VISIBLE_DEVICES=$ids WORLD_SIZE=$GPU_PER_NODE torchrun --nproc_per_node $GPU_PER_NODE" 20 | fi 21 | 22 | DATADIR=data/cached 23 | CKPTDIR=saved_models 24 | CKPTNAME=best 25 | 26 | 27 | dataset=wow 28 | use_kilt_format=False 29 | size=xl 30 | 31 | train_data_path=$DATADIR/wow_train_augmented 32 | validation_data_path=$DATADIR/wow_dev_unseen_augmented 33 | 34 | if [[ $type = '' ]]; then 35 | echo type was not provided. Defaults to ALL... 
36 | type=ALL 37 | fi 38 | if [[ $type = EOS ]]; then 39 | pool='last' 40 | elif [[ $type = RAND ]]; then 41 | pool='random' 42 | elif [[ $type = RIPA ]]; then 43 | pool='none' 44 | else 45 | # ALL 46 | pool='none' 47 | fi 48 | if [[ $use_kilt_format = True ]]; then 49 | model_name=flan-t5-$size-DecoderDisc-$dataset-kilt-$type 50 | else 51 | model_name=flan-t5-$size-DecoderDisc-$dataset-$type 52 | fi 53 | if [ $ONLY_CLASS = 1 ]; then 54 | model_name=$model_name-only_classifier 55 | fi 56 | if [ $V2 = 1 ]; then 57 | model_name=$model_name-v2 58 | if [[ $V2_REG != 0 ]]; then 59 | model_name=$model_name-v2reg$V2_REG 60 | fi 61 | fi 62 | 63 | lr=1e-5 64 | bs=8 65 | grad_check=False 66 | grad_accum=4 67 | 68 | 69 | script="$launcher kcd/token_classifier/train.py \ 70 | --model_name google/flan-t5-$size \ 71 | --is_decoder \ 72 | --num_labels 2 \ 73 | --pool_method $pool \ 74 | --wandb_project_name knowledge-classifier \ 75 | --wandb_run_name $model_name \ 76 | --dataset $dataset \ 77 | --use_kilt_format $use_kilt_format \ 78 | --train_data_path $train_data_path \ 79 | --validation_data_path $validation_data_path \ 80 | --output_dir $CKPTDIR/$model_name \ 81 | --bf16 --train_8bit \ 82 | --learning_rate $lr \ 83 | --warmup_steps 0 \ 84 | --weight_decay 0.01 \ 85 | --num_train_epochs 5 \ 86 | --max_steps 2000 \ 87 | --logging_steps 10 \ 88 | --eval_accumulation_steps 100 \ 89 | --eval_steps 500 \ 90 | --save_steps 500 \ 91 | --save_total_limit 2 \ 92 | --load_best_model_at_end False \ 93 | --per_device_train_batch_size $bs \ 94 | --per_device_eval_batch_size $bs \ 95 | --gradient_accumulation_steps $grad_accum \ 96 | --gradient_checkpointing $grad_check" 97 | 98 | if [[ $type = ALL ]]; then 99 | script="$script --sequence_label" 100 | fi 101 | 102 | 103 | if [ $ONLY_CLASS = 1 ]; then 104 | script="$script --only_classifier --use_mlp_classifier" 105 | else 106 | script="$script --use_lora" 107 | fi 108 | 109 | if [ $V2 = 1 ]; then 110 | # regularization default in NADO = 0.5 111 | script="$script --v2 --nado_reg $V2_REG" 112 | fi 113 | 114 | if [ $FINETUNE = 1 ]; then 115 | script="$script --load_checkpoint $CKPTDIR/flan-t5-$size-DecoderDisc-$dataset-EOS/checkpoint-$CKPTNAME/pytorch_model.bin" 116 | fi 117 | 118 | eval $script 119 | -------------------------------------------------------------------------------- /scripts/shell/train/train_t5_token_classifier_cnn.sh: -------------------------------------------------------------------------------- 1 | ##### stdin #### 2 | ids=$1 3 | type=$2 # [EOS, ALL, RAND, RIPA] 4 | ONLY_CLASS=$3 5 | V2=$4 6 | V2_REG=$5 7 | FINETUNE=$6 8 | ################ 9 | # compute number of gpus 10 | arrIDs=(${ids//,/ }) 11 | GPU_PER_NODE="${#arrIDs[@]}" 12 | 13 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 14 | 15 | dataset=cnn_dailymail 16 | size=xl 17 | if [[ $type = '' ]]; then 18 | echo type was not provided. Defaults to ALL... 
19 | type=ALL 20 | fi 21 | if [[ $type = EOS ]]; then 22 | pool='last' 23 | elif [[ $type = RAND ]]; then 24 | pool='random' 25 | else 26 | pool='none' 27 | fi 28 | 29 | 30 | model_name=flan-t5-$size-DecoderDisc-$dataset-$type 31 | train_data_path=data/cached/cnn_dailymail_train_augmented_neg_google-flan-t5-xl-0.9 32 | validation_data_path=data/cached/cnn_dailymail_test_augmented_neg_google-flan-t5-xl-0.1 33 | 34 | if [ $ONLY_CLASS = 1 ]; then 35 | model_name=$model_name-only_classifier 36 | fi 37 | 38 | if [ $V2 = 1 ]; then 39 | model_name=$model_name-v2 40 | if [[ $V2_REG != 0 ]]; then 41 | model_name=$model_name-v2reg$V2_REG 42 | fi 43 | fi 44 | 45 | lr=5e-6 46 | bs=8 47 | grad_check=False 48 | grad_accum=4 49 | 50 | 51 | script="$launcher kcd/token_classifier/train.py \ 52 | --model_name google/flan-t5-$size \ 53 | --is_decoder \ 54 | --num_labels 2 \ 55 | --pool_method $pool \ 56 | --wandb_project_name knowledge-classifier \ 57 | --wandb_run_name $model_name \ 58 | --dataset $dataset \ 59 | --use_kilt_format False \ 60 | --train_data_path $train_data_path \ 61 | --validation_data_path $validation_data_path \ 62 | --output_dir saved_models/$model_name \ 63 | --bf16 --train_8bit \ 64 | --learning_rate $lr \ 65 | --warmup_steps 0 \ 66 | --weight_decay 0.01 \ 67 | --num_train_epochs 5 \ 68 | --max_steps 2000 \ 69 | --logging_steps 10 \ 70 | --eval_accumulation_steps 50 \ 71 | --eval_steps 500 \ 72 | --save_steps 500 \ 73 | --save_total_limit 2 \ 74 | --load_best_model_at_end False \ 75 | --per_device_train_batch_size $bs \ 76 | --per_device_eval_batch_size $bs \ 77 | --gradient_accumulation_steps $grad_accum \ 78 | --gradient_checkpointing $grad_check" 79 | 80 | if [[ $type = ALL ]]; then 81 | script="$script --sequence_label" 82 | fi 83 | 84 | if [ $ONLY_CLASS = 1 ]; then 85 | script="$script --only_classifier --use_mlp_classifier" 86 | else 87 | script="$script --use_lora" 88 | fi 89 | 90 | if [ $V2 = 1 ]; then 91 | # regularization default in NADO = 0.5 92 | script="$script --v2 --nado_reg $V2_REG" 93 | fi 94 | 95 | if [ $FINETUNE = 1 ]; then 96 | script="$script --load_checkpoint saved_models/flan-t5-$size-DecoderDisc-$dataset-EOS/checkpoint-best/pytorch_model.bin" 97 | fi 98 | 99 | eval $script 100 | -------------------------------------------------------------------------------- /scripts/shell/train/train_token_classifier_gpt.sh: -------------------------------------------------------------------------------- 1 | ##### stdin #### 2 | ids=$1 3 | type=$2 # [EOS, ALL, RAND, RIPA] 4 | ONLY_CLASS=$3 5 | FINETUNE=$4 6 | ################ 7 | # compute number of gpus 8 | arrIDs=(${ids//,/ }) 9 | GPU_PER_NODE="${#arrIDs[@]}" 10 | 11 | # decide python launcher 12 | if [ $GPU_PER_NODE = 1 ]; then 13 | echo "Using 1 GPU: use simple python launcher..." 14 | launcher="CUDA_VISIBLE_DEVICES=$ids python" 15 | else 16 | echo "Using multi-GPU: using torchrun launcher..." 17 | launcher="CUDA_VISIBLE_DEVICES=$ids WORLD_SIZE=$GPU_PER_NODE torchrun --nproc_per_node $GPU_PER_NODE" 18 | fi 19 | 20 | DATADIR=data/cached 21 | CKPTDIR=saved_models 22 | CKPTNAME=best 23 | 24 | 25 | dataset=wow 26 | use_kilt_format=False 27 | size=xl 28 | 29 | train_data_path=$DATADIR/wow_train_augmented_neg_google-flan-t5-xl-0.9+random 30 | validation_data_path=$DATADIR/wow_test_augmented_neg_google-flan-t5-xl-0.1+random 31 | 32 | if [[ $type = '' ]]; then 33 | echo type was not provided. Defaults to ALL... 
34 | type=ALL 35 | fi 36 | if [[ $type = EOS ]]; then 37 | pool='last' 38 | elif [[ $type = RAND ]]; then 39 | pool='random' 40 | elif [[ $type = RIPA ]]; then 41 | pool='none' 42 | else 43 | # ALL 44 | pool='none' 45 | fi 46 | if [[ $use_kilt_format = True ]]; then 47 | model_name=gpt2-$size-DecoderDisc-$dataset-kilt-$type 48 | else 49 | model_name=gpt2-$size-DecoderDisc-$dataset-$type 50 | fi 51 | if [ $ONLY_CLASS = 1 ]; then 52 | model_name=$model_name-only_classifier 53 | fi 54 | 55 | lr=1e-5 56 | bs=32 57 | grad_check=True 58 | grad_accum=1 59 | 60 | 61 | script="$launcher kcd/token_classifier/train.py \ 62 | --model_name gpt2-$size \ 63 | --num_labels 2 \ 64 | --is_decoder False \ 65 | --pool_method $pool \ 66 | --wandb_project_name knowledge-classifier \ 67 | --wandb_run_name $model_name \ 68 | --dataset $dataset \ 69 | --use_kilt_format $use_kilt_format \ 70 | --train_data_path $train_data_path \ 71 | --validation_data_path $validation_data_path \ 72 | --output_dir $CKPTDIR/$model_name \ 73 | --learning_rate $lr \ 74 | --warmup_steps 0 \ 75 | --weight_decay 0.01 \ 76 | --num_train_epochs 5 \ 77 | --max_steps 2000 \ 78 | --logging_steps 10 \ 79 | --eval_accumulation_steps 100 \ 80 | --eval_steps 500 \ 81 | --save_steps 500 \ 82 | --save_total_limit 2 \ 83 | --load_best_model_at_end False \ 84 | --per_device_train_batch_size $bs \ 85 | --per_device_eval_batch_size $bs \ 86 | --gradient_accumulation_steps $grad_accum \ 87 | --gradient_checkpointing $grad_check --bf16" 88 | 89 | if [[ $type = ALL ]]; then 90 | script="$script --sequence_label" 91 | fi 92 | 93 | 94 | if [ $ONLY_CLASS = 1 ]; then 95 | script="$script --only_classifier --use_mlp_classifier" 96 | else 97 | script="$script --use_lora" 98 | fi 99 | 100 | if [ $FINETUNE = 1 ]; then 101 | script="$script --load_checkpoint $CKPTDIR/gpt2-$size-DecoderDisc-$dataset-EOS/checkpoint-$CKPTNAME/pytorch_model.bin" 102 | fi 103 | 104 | eval $script 105 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import setuptools 9 | 10 | with open("README.md", "r") as fh: 11 | long_description = fh.read() 12 | 13 | setuptools.setup( 14 | name="kcd", 15 | version="0.1.0", 16 | description="Knowledge Constraint Decoding", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | packages=setuptools.find_packages(), 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | ], 25 | python_requires=">=3.9", 26 | ) 27 | --------------------------------------------------------------------------------
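
The shell scripts above are driven entirely by positional arguments, which are easy to misread. Below is a minimal, illustrative sketch of how the WoW pipeline might be chained together end to end; it assumes the cached data under `data/cached/` and the checkpoints under `saved_models/` referenced by the scripts already exist, the argument order follows each script's `$1, $2, ...` header, and the concrete values (GPU id 0, the RIPA discriminator, the batch sizes) are example choices rather than requirements.

```bash
# Illustrative invocations only; argument order follows the positional-variable
# headers at the top of each script, and GPU id, discriminator type, and batch
# sizes are example values.

# 1. Train a RIPA token-level discriminator on WoW (ONLY_CLASS=0, V2=0, V2_REG=0, FINETUNE=0):
bash scripts/shell/train/train_t5_token_classifier.sh 0 RIPA 0 0 0 0

# 2. Evaluate the trained discriminator (dataset wow, batch size 64, V2=0, pool none):
bash scripts/shell/eval/test_t5_token_classifier.sh 0 RIPA wow 64 0 none

# 3. Knowledge-constrained decoding with PPL-MCTS
#    (no metric guidance, batch size 16, CLASS_ONLY=0, V2=0, CONT=0, human=0, quick=0):
bash scripts/shell/ppl_mcts_run.sh 0 RIPA '' wow 16 0 0 0 0 0

# 4. FUDGE-guided decoding with the same discriminator (V2=0, human=0, quick=0):
bash scripts/shell/guided_run.sh 0 fudge RIPA wow 16 0 0 0

# 5. Score the resulting generations with UniEval (and MFMA for summaries):
bash scripts/shell/eval/unieval.sh 0
```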