├── README.md
├── assets
    ├── cognacgen.png
    └── teaser.png
├── requirements.txt
└── src
    ├── __init__.py
    ├── data_utils.py
    ├── diverse_instructions.csv
    ├── diverse_instructions.py
    ├── experiment.py
    ├── guidance_models.py
    ├── guidance_utils.py
    ├── guide.py
    ├── guide_with_offset.py
    ├── lm_scorer.py
    ├── main.py
    ├── metric_tracker.py
    ├── metrics.py
    └── utils.py


/README.md:
--------------------------------------------------------------------------------
1 | # Cognac
2 | Repo for paper: [Cognac: Controllable Text Generation with Language Constraints](https://arxiv.org/abs/2212.10466)
3 |
4 |
5 | ## Overview
6 |
7 |
8 |
9 |
10 | We propose the Cognac task to stress-test LMs' ability to follow constraints.
11 | Green highlight specifies the topic to be covered. Red highlight specifies the constraint to conform to. GPT-3 generates a continuation that mentions a politician, thus violating the constraint. Our method, CognacGen, generates a continuation that satisfies both the topic requirement and the constraint.
12 |
13 | ## Setup
14 | We use Python version 3.8. Install the dependencies with pip:
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 |
19 | Download necessary resources:
20 | ```bash
21 | python -m spacy download en_core_web_lg
22 | python -m nltk.downloader wordnet
23 | ```
24 |
25 | ## Cognac Benchmark
26 |
27 | ### WordNet
28 | Download the WordNet data [here](https://drive.google.com/file/d/17mpi7fufaKVEvGdNMAsY_ZBit1FwER7z/view?usp=drive_link). The folder contains files `train.jsonl`, `dev.jsonl`, and `test.jsonl` that include instances of instructions with topics and constraints. The file `topic_to_leafs.json` contains the WordNet hierarchy (used to verify whether the generation is conformant). The data is loaded in the code [here](https://github.com/princeton-nlp/Cognac/blob/main/src/utils.py#L239).
29 |
30 | ### Wikidata
31 | Coming soon...
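For the WordNet files described above, the following is a minimal sketch of how a split and the hierarchy might be combined. It mirrors the filtering done in `src/data_utils.py` (underscores replaced by spaces, and only instances whose constraint is listed under the topic are kept). The function names `normalize_hierarchy` and `filter_instances` and the tiny in-memory sample are illustrative assumptions, not part of the repository; the real files may contain additional fields.

```python
import json


def normalize_hierarchy(raw):
    """Replace underscores with spaces in topics and leaves."""
    return {
        topic.replace('_', ' '): [leaf.replace('_', ' ') for leaf in leafs]
        for topic, leafs in raw.items()
    }


def filter_instances(lines, hierarchy):
    """Parse JSONL lines; keep instances whose constraint lies under the topic."""
    kept = []
    for line in lines:
        d = json.loads(line.strip())
        # Unlike the repository code, unknown topics are skipped instead of
        # raising a KeyError (an assumption made for this sketch).
        leafs = hierarchy.get(d['topic'], [])
        if d['constraint'] in leafs and d['constraint'] != d['topic']:
            kept.append(d)
    return kept


# Hypothetical sample standing in for `topic_to_leafs.json` and `dev.jsonl`.
raw_hierarchy = {'animal': ['dog', 'polar_bear'], 'dog': ['corgi']}
sample_lines = [
    '{"id": 0, "topic": "animal", "constraint": "dog"}',
    '{"id": 1, "topic": "animal", "constraint": "corgi"}',
    '{"id": 2, "topic": "animal", "constraint": "polar bear"}',
]

hierarchy = normalize_hierarchy(raw_hierarchy)
kept = filter_instances(sample_lines, hierarchy)
print([d['id'] for d in kept])
```

Instance 1 is dropped in this sample because `corgi` is not listed directly under `animal` in the toy hierarchy.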
32 |
33 | ## CognacGen
34 |
35 |
36 |
37 | The figure (`assets/cognacgen.png`) shows the step-by-step procedure CognacGen uses to handle natural language instructions.
38 |
39 | Stage 1: The LM generates a list of guidance examples from the queries that specify the topic and constraint. During self-guidance distillation, the topic and constraint prefixes are tuned using the guidance examples as targets and the instruction with demonstrations as input.
40 |
41 | Stage 2: The guidance model (blue LM and the tuned prefixes) generates guidance examples from the test instance. The guidance examples are used to construct tries for both the topic (green) and the constraint (red). The generation (blue) LM’s next-token probability is modified by the tries.
42 |
43 | ### Overall Structure
44 | The main run script is `main.py`.
45 | Some important hyperparameters are described below:
46 | - `eval_version`: determines whether a control code or a natural language instruction is used as context
47 | - `guidance`: the combination of guidances to use. `in` means the inclusion of topic is applied. `ex` means the exclusion of constraint is applied. `wd` (weighted decoding) is used in CognacGen.
Other options are also available.
48 | - `guidance_model_type`: guidance type to use for constraint exclusion; `discrete` is "Textual Guidance", `full` is "Top-K Token", and `binary` is "Binary Verifier", as described in the paper
49 | - `guidance_model_type_2`: guidance type to use for topic inclusion
50 | - `alpha`: strength of inclusion to apply to the logits during inference
51 | - `beta`: strength of exclusion to apply to the logits during inference
52 |
53 | ### Run Control Code Setting on WordNet
54 |
55 | ```bash
56 | python -m src.main \
57 |     --name "your_run_name" \
58 |     --dataset_split "dev" \
59 |     --dev_path "./data/wordnet/dev.jsonl" \
60 |     --hierarchy_path "./data/wordnet/topic_to_leafs.json" \
61 |     --eval_version -2 \
62 |     --guidance "wd+ex+in" \
63 |     --guidance_model_name "gpt2-xl" \
64 |     --guidance_model_type "discrete" \
65 |     --guidance_model_type_2 "discrete" \
66 |     --discrete_max_length 200 \
67 |     --discrete_guidance_use_trie \
68 |     --alpha 100.0 \
69 |     --beta 5.0 \
70 |     --top_p 0.92 \
71 |     --temperature 0.7
72 | ```
73 | Note that the control code setting applies only stage 2 and does not fine-tune the guidance model.
74 |
75 | ### Run Natural Language Instruction Setting on WordNet
76 |
77 | ```bash
78 | python -m src.main \
79 |     --name "your_run_name" \
80 |     --discrete_guidance_instruct2guide_model_dir "path/to/your/prefix/tuned/model/folder" \
81 |     --dataset_split "dev" \
82 |     --dev_path "./data/wordnet/dev.jsonl" \
83 |     --hierarchy_path "./data/wordnet/topic_to_leafs.json" \
84 |     --eval_version -1 \
85 |     --guidance "wd+ex+in" \
86 |     --guidance_model_name "gpt2-xl" \
87 |     --guidance_model_type "discrete" \
88 |     --guidance_model_type_2 "discrete" \
89 |     --discrete_max_length 200 \
90 |     --discrete_guidance_use_trie \
91 |     --alpha 100.0 \
92 |     --beta 5.0
93 | ```
94 | The prefix-tuned model in stage 1 can be downloaded [here](https://drive.google.com/file/d/1gTxwetdyK3X-IkUnw0FFdfPw1KyLI5DK/view?usp=drive_link).
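To make the role of the tries and of `alpha` and `beta` more concrete, here is a minimal sketch under stated assumptions; it is not the repository's implementation. It assumes the guidance is applied as an additive shift on next-token logits (up by `alpha` for tokens that continue a topic example, down by `beta` for tokens that continue a constraint example). `TokenTrie`, `apply_guidance`, and the toy vocabulary are hypothetical names introduced only for illustration.

```python
import math


class TokenTrie:
    """Toy trie over token sequences (hypothetical, for illustration only)."""

    def __init__(self):
        self.root = {}

    def insert(self, tokens):
        node = self.root
        for tok in tokens:
            node = node.setdefault(tok, {})

    def next_tokens(self, prefix):
        """Tokens that may follow `prefix`; unknown prefixes fall back to start tokens."""
        node = self.root
        for tok in prefix:
            if tok not in node:
                return set(self.root)
            node = node[tok]
        return set(node)


def apply_guidance(logits, in_tokens, ex_tokens, alpha, beta):
    """Assumed additive scheme: raise topic tokens by alpha, lower constraint tokens by beta."""
    shifted = dict(logits)
    for tok in in_tokens:
        if tok in shifted:
            shifted[tok] += alpha
    for tok in ex_tokens:
        if tok in shifted:
            shifted[tok] -= beta
    return shifted


def softmax(logits):
    """Numerically stable softmax over a token -> logit mapping."""
    m = max(logits.values())
    exps = {t: math.exp(v - m) for t, v in logits.items()}
    z = sum(exps.values())
    return {t: e / z for t, e in exps.items()}


# Guidance examples for the topic (inclusion) and the constraint (exclusion).
topic_trie = TokenTrie()
topic_trie.insert([' polar', ' bear'])
topic_trie.insert([' red', ' fox'])
constraint_trie = TokenTrie()
constraint_trie.insert([' grey', ' wolf'])

# Toy next-token logits from the generation LM.
logits = {' polar': 0.3, ' red': 0.0, ' grey': 1.0, ' the': 0.5}
shifted = apply_guidance(
    logits,
    in_tokens=topic_trie.next_tokens([]),
    ex_tokens=constraint_trie.next_tokens([]),
    alpha=2.0,
    beta=5.0,
)
probs = softmax(shifted)
print(max(probs, key=probs.get))
```

In this toy setup, a token that starts a constraint example (`' grey'`) loses most of its probability mass, while tokens that start topic examples become the most likely continuations. The values of `alpha` and `beta` here are arbitrary and unrelated to the settings in the commands above.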
95 | 96 | 97 | ### Run Control Code Setting on Wikidata 98 | Coming soon... 99 | 100 | ### Run Natural Language Instruction Setting on Wikidata 101 | Coming soon... 102 | 103 | ## Prefix-Tuning the Guidance Model 104 | Coming soon... 105 | 106 | ## Questions 107 | 108 | Please contact Howard Chen (`howardchen@cs.princeton.edu`) if you have any questions. 109 | 110 | ## Citation 111 | 112 | ```bibtex 113 | @inproceedings{chen2022cognac, 114 | title={{Cognac}: Controllable Text Generation with Language Constraints}, 115 | author={Chen, Howard and Li, Huihan and Chen, Danqi and Narasimhan, Karthik}, 116 | booktitle={arXiv}, 117 | year={2022} 118 | } 119 | ``` 120 | -------------------------------------------------------------------------------- /assets/cognacgen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Cognac/c1b304b884f76b667d2f76b325ecbadfb2d1c90c/assets/cognacgen.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Cognac/c1b304b884f76b667d2f76b325ecbadfb2d1c90c/assets/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | torch 3 | alive-progress 4 | transformers 5 | spacy 6 | nltk 7 | numpy 8 | rich -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Cognac/c1b304b884f76b667d2f76b325ecbadfb2d1c90c/src/__init__.py -------------------------------------------------------------------------------- /src/data_utils.py: 
--------------------------------------------------------------------------------
1 | """Data loading and batching utilities.
2 | """
3 | import json
4 | import random
5 | from collections import defaultdict
6 |
7 | from tqdm import tqdm
8 | from rich import print
9 |
10 | import torch
11 | from nltk.tokenize import sent_tokenize
12 |
13 |
14 | def get_hierarchy(path):
15 |     with open(path) as f:
16 |         hierarchy_ = json.load(f)
17 |     hierarchy = dict()
18 |     for topic, leafs in hierarchy_.items():
19 |         new_topic = topic.replace('_', ' ')
20 |         new_leafs = [l.replace('_', ' ') for l in leafs]
21 |         hierarchy[new_topic] = new_leafs
22 |     return hierarchy
23 |
24 |
25 | def load_datasets(data_paths, hierarchy, args):
26 |     def process(datapoints, data_path):
27 |         if 'pairs' in data_path:
28 |             pass
29 |         else:
30 |             datapoints = [
31 |                 d for d in datapoints
32 |                 if hierarchy is not None and
33 |                 d['constraint'] in hierarchy[d['topic']] and
34 |                 d['constraint'] != d['topic']
35 |             ]
36 |         return datapoints
37 |
38 |     datasets = dict()
39 |     for dataset_name, dataset_path in data_paths.items():
40 |         if dataset_path is not None:
41 |             datapoints = []
42 |             with open(dataset_path) as f:
43 |                 for line in f:
44 |                     datapoint = json.loads(line.strip())
45 |                     datapoints.append(datapoint)
46 |             datapoints = process(datapoints, dataset_path)
47 |
48 |             if args.randomize_dataset:
49 |                 random.seed(1)
50 |                 random.shuffle(datapoints)
51 |             datasets[dataset_name] = datapoints
52 |     return datasets
53 |
54 |
55 | def get_guidance_data(args):
56 |     if args.hierarchy_path is not None:
57 |         hierarchy = get_hierarchy(args.hierarchy_path)
58 |     else:
59 |         hierarchy = None
60 |
61 |     data_paths = {
62 |         'train': args.train_path,
63 |         'dev': args.dev_path,
64 |         'test': args.test_path,
65 |     }
66 |     datasets = load_datasets(data_paths, hierarchy, args)
67 |
68 |     if args.wiki_gold is not None:
69 |         gold = dict()
70 |         with open('data/wordnet_wiki_clean.jsonl') as f:
71 |             for line in f:
72 |                 obj = json.loads(line.strip())
73 |                 gold[obj['node'].replace('_',
' ')] = obj 74 | else: 75 | gold = None 76 | 77 | print(f'Dataset loaded.') 78 | return datasets, hierarchy, gold 79 | 80 | 81 | def get_bow(words, tokenizer): 82 | inds = [] 83 | for word in words: 84 | inds += tokenizer.encode(word, add_prefix_space=True) 85 | inds = list(set(inds)) 86 | 87 | inds = torch.tensor(inds) 88 | bow = torch.zeros(1, len(tokenizer)) 89 | bow[:, inds] = 1.0 90 | 91 | return bow 92 | 93 | 94 | def regressor_batcher(datapoints, batch_size, max_length, **kwargs): 95 | hierarchy = kwargs.get('hierarchy', None) 96 | tokenizer = kwargs.get('tokenizer', None) 97 | use_cuda = kwargs.get('use_cuda', False) 98 | num_batches = len(datapoints) // batch_size + 1 99 | 100 | for i in range(num_batches): 101 | batch = datapoints[i * batch_size:(i + 1) * batch_size] 102 | 103 | if not batch: 104 | continue 105 | 106 | topic_texts = [f"Talk about {b['topic']}" for b in batch] 107 | topic_inputs = tokenizer.batch_encode_plus( 108 | topic_texts, 109 | padding='longest' 110 | ) 111 | constraint_texts = [f"Don't talk about {b['constraint']}" for b in batch] 112 | constraint_inputs = tokenizer.batch_encode_plus( 113 | constraint_texts, 114 | padding='longest' 115 | ) 116 | 117 | topics = [b['topic'] for b in batch] 118 | constraints = [b['constraint'] for b in batch] 119 | topic_targets = torch.cat([ 120 | get_bow(hierarchy.get(topic, []) + [topic], tokenizer) 121 | for topic in topics 122 | ], dim=0) 123 | constraint_targets = torch.cat([ 124 | get_bow(hierarchy.get(constraint, []) + [constraint], tokenizer) 125 | for constraint in constraints 126 | ], dim=0) 127 | 128 | if use_cuda: 129 | topic_targets = topic_targets.cuda() 130 | constraint_targets = constraint_targets.cuda() 131 | 132 | yield dict( 133 | ids=[b['id'] for b in batch], 134 | topic_inputs=topic_inputs, 135 | constraint_inputs=constraint_inputs, 136 | topic_targets=topic_targets, 137 | constraint_targets=constraint_targets, 138 | topics=topics, 139 | constraints=constraints, 140 | ) 141 | 142 
| 143 | def classifier_batcher(datapoints, batch_size, max_length, **kwargs): 144 | hierarchy = kwargs.get('hierarchy', None) 145 | tokenizer = kwargs.get('tokenizer', None) 146 | args = kwargs.get('args', None) 147 | num_batches = len(datapoints) // batch_size + 1 148 | 149 | for i in range(num_batches): 150 | batch = datapoints[i * batch_size:(i + 1) * batch_size] 151 | if not batch: 152 | continue 153 | 154 | if args.model_type == 'classification': 155 | texts = [sent_tokenize(b['text'])[1] for b in batch] 156 | label_texts = [b['label'] for b in batch] 157 | labels = torch.tensor([1 if b['label'] == 'yes' else 0 for b in batch]).long() 158 | elif args.model_type in ('prompt-tune', 'fine-tune'): 159 | _YES = 3763 160 | _NO = 645 161 | texts = [b['text'] for b in batch] 162 | label_texts = [b['label'] for b in batch] 163 | labels = torch.tensor([_YES if b['label'] == 'yes' else _NO for b in batch]).long() 164 | else: 165 | raise ValueError(f'Unknown model type: {args.model_type}') 166 | categories = [b['category'] for b in batch] 167 | 168 | encoded = tokenizer.batch_encode_plus( 169 | texts, 170 | max_length=max_length, 171 | padding='max_length', 172 | ) 173 | input_ids = torch.tensor(encoded['input_ids']).long() 174 | attention_mask = torch.tensor(encoded['attention_mask']).long() 175 | 176 | input_ids = input_ids.cuda() 177 | attention_mask = attention_mask.cuda() 178 | labels = labels.cuda() 179 | 180 | yield dict( 181 | input_ids=input_ids, 182 | attention_mask=attention_mask, 183 | texts=texts, 184 | categories=categories, 185 | labels=labels, 186 | label_texts=label_texts, 187 | ) 188 | 189 | 190 | class Node: 191 | def __init__(self, name, parent, children): 192 | assert isinstance(parent, Node) or parent is None 193 | assert isinstance(children, list) 194 | self.name = name 195 | self.parent = parent 196 | self.children = children 197 | 198 | def __repr__(self): 199 | #parent = self.parent.name if self.parent is not None else None 200 | #children = 
[c.name for c in self.children] 201 | return f'Node(name={self.name})' 202 | 203 | 204 | class Hierarchy: 205 | def __init__(self): 206 | self.name_to_node = dict() 207 | self.load() 208 | 209 | def get_node(self, name): 210 | return self.name_to_node.get(name, None) 211 | 212 | def get_leafs(self, node): 213 | pass 214 | 215 | def attach_children_to_parent(self, parent_name, children_names): 216 | parent = self.get_node(parent_name) 217 | 218 | if parent is None: 219 | parent = Node(parent_name, None, []) 220 | self.name_to_node[parent_name] = parent 221 | 222 | for child_name in children_names: 223 | child = self.get_node(child_name) 224 | if child is None: 225 | child = Node(child_name, parent, []) 226 | self.name_to_node[child_name] = child 227 | elif child.parent is None: 228 | child.parent = parent 229 | else: 230 | continue 231 | parent.children.append(child) 232 | 233 | def load(self): 234 | with open('data/hierarchy_path_to_children.jsonl') as f: 235 | for line in f: 236 | obj = json.loads(line.strip()) 237 | hierarchy_path = [''] + obj['hierarchy_path'] 238 | for p, c in zip(hierarchy_path, hierarchy_path[1:]): 239 | self.attach_children_to_parent(p, [c]) 240 | parent = hierarchy_path[-1] 241 | children = obj['children'] 242 | self.attach_children_to_parent(parent, children) 243 | 244 | 245 | if __name__ == '__main__': 246 | """ 247 | python -m ci.data_utils 248 | """ 249 | from transformers import AutoTokenizer 250 | hierarchy = get_hierarchy('data/topic_to_leafs.json') 251 | tokenizer = AutoTokenizer.from_pretrained('gpt2-xl', use_fast=False) 252 | 253 | datapoints = [] 254 | with open('data/gpt2-xl/wiki_id_ood/sample/train.jsonl') as f: 255 | for line in f: 256 | datapoint = json.loads(line.strip()) 257 | datapoints.append(datapoint) 258 | print(len(datapoints)) 259 | datapoints = add_bow_to_datapoint(datapoints, hierarchy, tokenizer) 260 | print(len(datapoints)) 261 | -------------------------------------------------------------------------------- 
/src/diverse_instructions.csv: -------------------------------------------------------------------------------- 1 | index,topic,constraint,type,mode 2 | 0,Write down examples of {}.,Continue listing them but do not include examples of {}.,begin+end,train 3 | 1,Below we show examples of {1}.,"Following these examples, keep listing but don’t mention {0}.",begin,train 4 | 2,The above showed examples of {1}.,Now list more examples but don’t mention {0}.,end,train 5 | 3,Generate examples that are under the category of {}.,Now keep generating but exclude anything that's in the category of {}.,begin+end,dev 6 | 4,List out examples of {1}:,"Right after these examples, continue listing but avoid mentioning {0}.",begin,dev 7 | 5,The above are sentences describing {1}.,Now write similar sentences as the above while omitting any mention of {0}.,end,dev 8 | 6,Generate examples about {}:,Now generate more examples but refrain from generating {}.,begin+end,test 9 | 7,"Task: write down three examples of {1}.,",and after the three examples keep writing more but don’t refer to any {0}.,begin,test 10 | 8,The above sentences are examples of {1}.,Your task is to list more sentences but exclude mentions of {0}.,end,test 11 | 9,Make a list of {}.,Do not list any {}.,begin+end,test 12 | 10,Give me a list of {}.,Extend the list which does not pertain to {}.,begin+end,test 13 | 11,Listed below are examples of {1}.,"Keep going down this list, but don't include any {0}.",begin,test 14 | 12,Examples of {1} are shown below.,"List more examples, but don't mention {0} this time.",begin,test 15 | 13,There are examples of {1} below.,Write similar sentences after these that omits mention of {0}.,begin,test 16 | 14,The examples provided were of {1}.,provide additional examples without referring to the constraint {0}.,end,test 17 | 15,The examples given were of different instances of {1}.,give more examples that don't go against the original constraint {0}.,end,test 18 | 16,The examples illustrated 
were of {1}.,list more examples that don't bring up {0}.,end,test 19 | 17,The aforementioned showed examples of {1}.,Write similar sentences omitting any mention of {0}.,end,test 20 | 18,Create examples for {}.,"Keep generating ideas, but don't include anything that falls under the category of {}.",begin+end,test 21 | 19,Make up examples for {}.,"Keep coming up with ideas, but don't include anything that would fit under {}.",begin+end,test 22 | 20,Come up with examples for {}.,"Keep brainstorming, but don't write down anything that could be classified as {}.",begin+end,test 23 | 21,Provide a list of examples for {1}.,"After these examples, keep going but don't mention {0}.",begin,test 24 | 22,Give a list of examples for {1}.,"After these examples, continue listing but avoid mentioning {0}.",begin,test 25 | 23,Share a list of examples for {1}.,"Keep listing after these examples, but don't mention {0}.",begin,test 26 | 24,A list of examples for {1} is as follows:,"List more examples after these, but don't mention {0}.",begin,test 27 | 25,The sentences above describe {1}.,Write similar sentences without mentioning {0}.,end,test 28 | 26,The foregoing sentences describe {1}.,Write similar sentences while omitting mention of {0}.,end,test 29 | 27,These sentences describe {1}.,Write similar sentences that does not mention {0}.,end,test 30 | 28,Above are sentences of {1}.,Your task is to list more sentences without mentioning {0}.,end,test 31 | 29,The sentences below show instances of {1}.,Your goal is to list more instances that don't mention {0}.,end,test 32 | 30,The sentences below are examples of {1}.,Your task is to enumerate more of them which don't mention {0}.,end,test 33 | 31,The sentences below are examples of {1}.,Your task is to list more sentences that do not mention {0}.,end,test 34 | 32,Write down examples of {}.,"Generate more examples, but don't create any that include {}.",begin+end,test 35 | 33,"Given the topic {}, write down three examples:","Write down 
more examples, but don't mention {} in any of them.",begin+end,test 36 | 34,Come up with instances of the topic {}.,"Come up with more examples, but keep {} out of them.",begin+end,test 37 | -------------------------------------------------------------------------------- /src/diverse_instructions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Topic/constraint descriptions as the instruction. 3 | """ 4 | import csv 5 | import copy 6 | from src.utils import get_wikidata_p_text 7 | 8 | MAPPING = { 9 | -2: [ 10 | 'begin+end', 11 | 'dummy topic {}', 12 | 'dummy constraint {}', 13 | 'dummy', 14 | ] 15 | } 16 | with open('src/diverse_instructions.csv', 'r') as f: 17 | reader = csv.DictReader(f) 18 | for row in reader: 19 | index = int(row['index']) 20 | topic_inst = row['topic'].strip() 21 | constraint_inst = row['constraint'].strip() 22 | inst_type = row['type'] 23 | mode = row['mode'] 24 | 25 | if inst_type == 'begin+end': 26 | MAPPING[int(row['index'])] = [ 27 | inst_type, 28 | topic_inst, 29 | constraint_inst, 30 | mode, 31 | ] 32 | else: 33 | MAPPING[int(row['index'])] = [ 34 | inst_type, 35 | topic_inst + ' ' + constraint_inst, 36 | mode, 37 | ] 38 | 39 | 40 | def get_instruction(datapoint, args, version): 41 | if args.data == 'wordnet': 42 | topic = datapoint['topic'] 43 | constraint = datapoint['constraint'] 44 | insert_position, *instruction_templates, mode = MAPPING[version] 45 | if insert_position == 'begin+end': 46 | instructions = [ 47 | instruction_templates[0].format(topic), 48 | instruction_templates[1].format(constraint), 49 | ] 50 | elif insert_position == 'begin': 51 | instructions = [instruction_templates[0].format(constraint, topic), ""] 52 | elif insert_position == 'end': 53 | instructions = ["", instruction_templates[0].format(constraint, topic)] 54 | else: 55 | raise ValueError(f'`{insert_position}` not recognized.') 56 | 57 | elif args.data == 'wikidata': 58 | topic = datapoint['topic'] 59 | 
constraint = datapoint['constraint'] 60 | insert_position, *instruction_templates, mode = MAPPING[version] 61 | 62 | topic = get_wikidata_p_text(topic[0], topic[1]) 63 | constraint = get_wikidata_p_text(constraint[0], constraint[1]) 64 | 65 | if insert_position == 'begin+end': 66 | instructions = [ 67 | instruction_templates[0].format(topic), 68 | instruction_templates[1].format(constraint), 69 | ] 70 | elif insert_position == 'begin': 71 | instructions = [instruction_templates[0].format(constraint, topic), ""] 72 | elif insert_position == 'end': 73 | instructions = ["", instruction_templates[0].format(constraint, topic)] 74 | else: 75 | raise ValueError(f'`{insert_position}` not recognized.') 76 | 77 | return dict( 78 | begin=instructions[0], 79 | end=instructions[1], 80 | insert_position=insert_position, 81 | ) 82 | 83 | 84 | def prepare_context(datapoint, args, version): 85 | new_datapoint = copy.deepcopy(datapoint) 86 | blocks = new_datapoint['context'] 87 | blocks = ['== ' + sent.strip('==').strip() + ' ==' for sent in blocks.split('==\n')] 88 | 89 | instruction = get_instruction(new_datapoint, args, version) 90 | begin = instruction['begin'] 91 | end = instruction['end'] 92 | insert_position = instruction['insert_position'] 93 | blocks = [begin] + blocks + [end] 94 | context_with_instructions = '\n'.join(blocks).strip() 95 | 96 | new_datapoint['context_with_instructions'] = context_with_instructions 97 | new_datapoint['version'] = version 98 | new_datapoint['begin'] = begin 99 | new_datapoint['end'] = end 100 | new_datapoint['insert_position'] = insert_position 101 | return new_datapoint 102 | -------------------------------------------------------------------------------- /src/experiment.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import yaml 4 | import uuid 5 | import logging 6 | import argparse 7 | import tempfile 8 | from pathlib import Path 9 | 10 | 11 | class ExperimentManager: 
12 | """ 13 | The only object needed to be instantiated throughout an experiment run. 14 | 15 | Functionalities: 16 | 1. Create run dir 17 | 2. Manage all logging 18 | a. Save command 19 | b. Save config/args 20 | c. Save predictions 21 | d. Save final eval stats 22 | 3. Manage all metric calculations 23 | """ 24 | def __init__(self, 25 | run_dir=None, 26 | name=None, 27 | override=True, 28 | num_runs=1, 29 | **kwargs 30 | ): 31 | """ 32 | Create run dir and setup file paths. 33 | By default, one run folder will be created that contains the predictions file. 34 | If `num_runs` > 1, then the amount of sub dirs will be created. 35 | The sub dirs will be named `seed_{run_id}`. 36 | """ 37 | self.num_runs = num_runs 38 | 39 | # Set up run dir path. 40 | if run_dir is None: 41 | run_dir = '/tmp/unnamed-experiments' 42 | if name is None: 43 | name = str(uuid.uuid4()).split('-')[-1] 44 | 45 | self.run_dir = Path(run_dir) / name 46 | 47 | self.config_path = self.run_dir / 'config.yaml' 48 | 49 | self.prediction_paths = [ 50 | self.run_dir / f'seed_{run_id}' / 'predictions.jsonl' 51 | for run_id in range(num_runs) 52 | ] 53 | 54 | self.agg_result_path = self.run_dir / 'results.yaml' 55 | self.result_paths = [ 56 | self.run_dir / f'seed_{run_id}' / 'results.yaml' 57 | for run_id in range(num_runs) 58 | ] 59 | 60 | # Create run dir and the needed paths. 61 | if not self.run_dir.exists() or override: 62 | self.run_dir.mkdir(parents=True, exist_ok=True) 63 | for run_id in range(num_runs): 64 | seed_run_dir = self.run_dir / f'seed_{run_id}' 65 | seed_run_dir.mkdir(parents=True, exist_ok=True) 66 | print(f'Run dir `{self.run_dir}` created or existed.') 67 | else: 68 | print(f'Run dir `{self.run_dir}` not saved or overriden.') 69 | 70 | # Set up file loggers. 
71 |         self.file_loggers = []
72 |         for run_id in range(self.num_runs):
73 |             self.setup_logger(f'file_logger_{run_id}', run_id)
74 |             self.file_loggers.append(logging.getLogger(f'file_logger_{run_id}'))
75 |
76 |         # Set up other modules.
77 |         self.metric_tracker = kwargs.get('metric_tracker', None)
78 |
79 |     def setup_logger(self, logger_name, run_id=None):
80 |         logger = logging.getLogger(logger_name)
81 |         logger.setLevel(logging.INFO)
82 |         formatter = logging.Formatter('%(message)s')
83 |         handler = logging.FileHandler(self.prediction_paths[run_id], mode='w')
84 |         handler.setFormatter(formatter)
85 |         logger.addHandler(handler)
86 |         logger.propagate = False
87 |
88 |     def console(self, content):
89 |         print(content)
90 |
91 |     def log(self, content, run_id):
92 |         assert run_id is not None, 'With `log=True`, `run_id` must be provided.'
93 |         content_to_log = (
94 |             json.dumps(content) if isinstance(content, dict) else content
95 |         )
96 |         self.file_loggers[run_id].info(content_to_log)
97 |
98 |     def add_metric(self, content, run_id):
99 |         assert self.metric_tracker is not None and run_id is not None, (
100 |             'With `add_metric=True`, the `metric_tracker` cannot be `None` and '
101 |             '`run_id` must be specified.')
102 |         self.metric_tracker.add(content, run_id=run_id)
103 |
104 |     def take(self, content, console=False, log=False, run_id=None, add_metric=False):
105 |         """
106 |         To log to files, a `run_id` must be provided.
107 |         `content` can be of type `str` or in most cases `dict`.
108 |
109 |         With default arguments, nothing will be performed.
110 |         The function is meant to be used when multiple uses are needed at once.
111 | """ 112 | if console: 113 | self.console(content) 114 | if log: 115 | self.log(content, run_id) 116 | if add_metric: 117 | self.add_metric(content, run_id) 118 | 119 | def save_results(self, results: dict, run_id=None): 120 | if run_id is None: 121 | with open(self.agg_result_path, 'w') as f: 122 | yaml.dump(results, f, default_flow_style=False) 123 | else: 124 | with open(self.result_paths[run_id], 'w') as f: 125 | yaml.dump(results, f, default_flow_style=False) 126 | 127 | def save_config(self, config): 128 | """ 129 | Save config and command used to execute the main script. 130 | """ 131 | if isinstance(config, argparse.Namespace): 132 | config = vars(config) 133 | 134 | with open(self.config_path, 'w') as f: 135 | yaml.dump(config, f) 136 | 137 | def compute_metric(self, metric_name, run_id): 138 | assert self.metric_tracker is not None 139 | return self.metric_tracker.compute(metric_name, run_id) 140 | 141 | def aggregate_metrics(self, avg_runs=True): 142 | assert self.metric_tracker is not None 143 | return self.metric_tracker.aggregate(avg_runs=avg_runs) 144 | 145 | -------------------------------------------------------------------------------- /src/guidance_models.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | from pathlib import Path 4 | from ast import literal_eval 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from nltk.tokenize import sent_tokenize 11 | 12 | from src.guidance_utils import ( 13 | remove_parenthesis, 14 | get_wikidata_query, 15 | get_wordnet_query, 16 | ) 17 | from src.instruct2guide.utils import postprocess 18 | 19 | _YES = 3763 # gpt2 tokenizer 20 | _NO = 645 # gpt2 tokenizer 21 | _PAD = 50256 # gpt2 tokenizer 22 | 23 | random.seed(0) 24 | 25 | 26 | class BaseGuidanceModel(nn.Module): 27 | def __init__(self, guidance_lm, tokenizer, args, **kwargs): 28 | super().__init__() 29 | self.args = args 30 | self.tokenizer 
= tokenizer 31 | self.guidance_lm = guidance_lm 32 | self.hierarchy = kwargs.get('hierarchy', None) 33 | self.entity_to_trie = dict() 34 | 35 | # Track generation state and use the next token for guidance. 36 | self.guidance_prefix = { 37 | 'in': [], 38 | 'ex': [], 39 | } 40 | 41 | # Track which `example_id` is done and stop guidance. 42 | self.in_finished = set() 43 | 44 | def forward(self, entity, query_entity=None, **kwargs): 45 | if self.args.data == 'wordnet': 46 | query = get_wordnet_query( 47 | guidance_model_type=self.guidance_model_type, 48 | entity=entity, 49 | query_entity=query_entity, 50 | args=self.args, 51 | **kwargs, 52 | ) 53 | elif self.args.data == 'wikidata': 54 | query = get_wikidata_query( 55 | guidance_model_type=self.guidance_model_type, 56 | entity=entity, 57 | query_entity=query_entity, 58 | args=self.args, 59 | **kwargs, 60 | ) 61 | else: 62 | raise ValueError(f'Unknown data: {self.args.data}') 63 | 64 | query_tokens = self.tokenizer.encode(query) 65 | query_tokens = torch.tensor(query_tokens).long().unsqueeze(0).cuda() 66 | out = self.guidance_lm(query_tokens) 67 | return out.logits[:, -1, :] 68 | 69 | def query_guidance(self, entity, **kwargs): 70 | """ 71 | Returns indicies, loss, and info. 
72 | """ 73 | raise NotImplementedError 74 | 75 | def get_trie_filtered_indicies( 76 | self, 77 | mode, 78 | words, 79 | entity, 80 | last_token, 81 | curr_context, 82 | example_id, 83 | ): 84 | if self.args.data == 'wikidata': 85 | entity = tuple(entity) 86 | # Debug ######################################################## 87 | #print('id:', example_id) 88 | #print('mode:', mode) 89 | #print('gen_examples:', words) 90 | #print('entity:', entity) 91 | #print('in_finished:', self.in_finished) 92 | ################################################################ 93 | if mode == 'in' and example_id in self.in_finished: 94 | return [] 95 | 96 | if entity not in self.entity_to_trie: 97 | trie = Trie() 98 | for word in words: 99 | name_tok_inds = self.tokenizer.encode(' ' + word) 100 | name_toks = [ 101 | self.tokenizer.decode(ind, skip_special_tokens=True) 102 | for ind in name_tok_inds 103 | ] 104 | 105 | trie.start_node_inds.add(name_tok_inds[0]) 106 | trie.insert(name_toks) 107 | trie.node_set.update(name_toks) 108 | self.entity_to_trie[entity] = trie 109 | 110 | # Get the trie corresponding to the current entity. 111 | trie = self.entity_to_trie[entity] 112 | last_token_text = self.tokenizer.decode(last_token.squeeze(0)) 113 | 114 | # Debug ############################################################# 115 | #curr_context_text = self.tokenizer.decode(curr_context.squeeze(0)) 116 | #print('last:', last_token_text) 117 | #print('curr context:', curr_context_text) 118 | #print('node_set:', trie.node_set) 119 | #print(f'gp before [{mode}]: {self.guidance_prefix}') 120 | ##################################################################### 121 | 122 | if last_token_text not in trie.node_set: 123 | # Debug ########################################## 124 | #print(f'[{mode}] here 1 -> [{last_token_text}]') 125 | ################################################## 126 | # Use the the start tokens for guidance. 
127 | inds = list(trie.start_node_inds) 128 | else: 129 | # Debug ########################################## 130 | #print(f'[{mode}] here 2 -> [{last_token_text}]') 131 | ################################################## 132 | self.guidance_prefix[mode].append(last_token_text) 133 | query_prefix = self.guidance_prefix[mode] 134 | name_next_tok_index = len(query_prefix) 135 | 136 | name_candidates = trie.query(query_prefix) 137 | #print(f'Name candidates [{mode}]: {name_candidates}') 138 | if not name_candidates: 139 | # The prefix fails. Reset the tracker. 140 | inds = list(trie.start_node_inds) 141 | self.guidance_prefix[mode] = [] 142 | else: 143 | name_candidate_tok_texts = [ 144 | c[name_next_tok_index] for c in name_candidates 145 | if len(c) > name_next_tok_index 146 | ] 147 | #print(name_candidate_tok_texts) 148 | inds = [ 149 | self.tokenizer.encode(tok)[0] for tok in name_candidate_tok_texts 150 | ] 151 | name_candidates = {''.join(toks) for toks in name_candidates} 152 | hit = ''.join(query_prefix) in {''.join(toks) for toks in name_candidates} 153 | if hit: 154 | self.in_finished.add(example_id) 155 | self.guidance_prefix[mode] = [] 156 | # Debug ###################################################################### 157 | #print(f'gp after [{mode}]: {self.guidance_prefix}') 158 | #print([self.tokenizer.decode(ind) for ind in inds]) 159 | #print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 160 | #input() 161 | ############################################################################## 162 | return inds 163 | 164 | 165 | class BinaryGuidanceModel(BaseGuidanceModel): 166 | def __init__(self, guidance_lm, tokenizer, args, **kwargs): 167 | super().__init__(guidance_lm, tokenizer, args, **kwargs) 168 | self.guidance_model_type = 'binary' 169 | 170 | def query_guidance(self, entity, **kwargs): 171 | mode = kwargs.get('mode', None) 172 | query_entity = kwargs['query_entity'] 173 | orig_probs = kwargs['orig_probs'] 174 | if 
query_entity is None: 175 | return [], None, dict( 176 | yes_prob=-1.0, 177 | no_prob=-1.0, 178 | ) 179 | if self.args.guidance_model_name == 'oracle': 180 | if self.args.data == 'wordnet': 181 | oracle_words = self.hierarchy[entity] + [entity] 182 | oracle_plural_words = [w + 's' for w in oracle_words] 183 | all_oracle_words = set(oracle_words) | set(oracle_plural_words) 184 | normalized_yes_prob = 1.0 * (query_entity in all_oracle_words) 185 | elif self.args.data == 'wikidata': 186 | words = self.hierarchy[entity] 187 | words = [remove_parenthesis(title) for q_id, title in words] 188 | words = set([w.lower() for w in words]) 189 | normalized_yes_prob = 1.0 * (query_entity in words) 190 | else: 191 | raise ValueError(f'Data arg {self.args.data} not recognized.') 192 | yes_prob = normalized_yes_prob 193 | no_prob = 1.0 - normalized_yes_prob 194 | else: 195 | logits = self.forward(entity, query_entity) 196 | probs = F.softmax(logits, dim=-1) 197 | yes_prob = probs[:, _YES].squeeze(0).item() 198 | no_prob = probs[:, _NO].squeeze(0).item() 199 | normalized_yes_prob = yes_prob / (yes_prob + no_prob) 200 | 201 | if normalized_yes_prob > kwargs['threshold']: 202 | inds = torch.argmax(orig_probs, dim=1) 203 | loss = torch.log(orig_probs.max()) 204 | else: 205 | loss = None 206 | inds = [] 207 | 208 | return inds, loss, dict( 209 | yes_prob=yes_prob, 210 | no_prob=no_prob, 211 | query_entity=query_entity, 212 | entity=entity, 213 | ) 214 | 215 | 216 | class FullGuidanceModel(BaseGuidanceModel): 217 | """ 218 | Full guidance model operates with the next token's full probability, 219 | as opposed to binary guidance model only offers yes/no probability. 
220 | """ 221 | def __init__(self, guidance_lm, tokenizer, args, **kwargs): 222 | super().__init__(guidance_lm, tokenizer, args, **kwargs) 223 | self.guidance_model_type = 'full' 224 | self.in_bow_vecs = None 225 | self.ex_bow_vecs = None 226 | self.in_set = None 227 | self.ex_set = None 228 | 229 | def query_guidance(self, entity, **kwargs): 230 | mode = kwargs.get('mode', None) 231 | query_entity = kwargs['query_entity'] 232 | orig_probs = kwargs['orig_probs'] 233 | 234 | if self.args.guidance_model_name == 'oracle': 235 | if mode == 'in': 236 | self.in_bow_vecs, self.in_set = self._init_bow_vecs(entity, **kwargs) 237 | bow_vecs = self.in_bow_vecs 238 | ind_set = self.in_set 239 | elif mode == 'ex': 240 | self.ex_bow_vecs, self.ex_set = self._init_bow_vecs(entity, **kwargs) 241 | bow_vecs = self.ex_bow_vecs 242 | ind_set = self.ex_set 243 | else: 244 | raise ValueError(f'`mode` {mode} not supported.') 245 | 246 | loss = 0.0 247 | for vec in bow_vecs: 248 | loss += torch.log(torch.mm(orig_probs, torch.t(vec)).sum()) 249 | inds = list(ind_set) 250 | else: 251 | logits = self.forward(entity, **kwargs) 252 | if mode == 'in': 253 | full_guide_topk = self.args.full_guide_topk_in 254 | elif mode == 'ex': 255 | full_guide_topk = self.args.full_guide_topk_ex 256 | else: 257 | raise ValueError(f'`mode` {mode} not supported.') 258 | 259 | _, inds = torch.topk(logits, k=full_guide_topk) 260 | inds = inds.squeeze(0) # 1xk -> k 261 | topk_probs = orig_probs[:, inds] 262 | loss = torch.log(topk_probs.sum()) 263 | 264 | ind_tokens = [self.tokenizer.decode(ind) for ind in inds] 265 | 266 | return inds, loss, dict( 267 | yes_prob=-1.0, 268 | no_prob=-1.0, 269 | query_entity=query_entity, 270 | entity=entity, 271 | ind_tokens=ind_tokens, 272 | ) 273 | 274 | def _init_bow_vecs(self, entity, **kwargs): 275 | mode = kwargs.get('mode', None) 276 | curr_context = kwargs.get('curr_context', None) 277 | if self.args.data == 'wordnet': 278 | if mode == 'in': 279 | words = 
self.hierarchy[entity] 280 | elif mode == 'ex': 281 | words = self.hierarchy[entity] + [entity] 282 | words = words + [w + 's' for w in words] 283 | elif self.args.data == 'wikidata': 284 | if mode == 'in': 285 | words = self.hierarchy[entity][:100] 286 | elif mode == 'ex': 287 | words = self.hierarchy[entity] 288 | else: 289 | raise ValueError(f'`{mode}` not recognized.') 290 | 291 | # If full length is used there's gonna be too many Q's. 292 | words = [remove_parenthesis(q_title) for q_id, q_title in words] 293 | else: 294 | raise ValueError(f'Data arg {self.args.data} not recognized.') 295 | 296 | inds = self.get_trie_filtered_indicies( 297 | mode, 298 | words, 299 | entity, 300 | kwargs['last_token'], 301 | kwargs['curr_context'], 302 | kwargs['example_id'], 303 | ) 304 | 305 | vecs = [] 306 | if inds: 307 | inds = torch.tensor(inds) 308 | onehot = torch.zeros(1, len(self.tokenizer)) 309 | onehot[:, inds] = 1.0 310 | vecs.append(onehot.cuda()) 311 | return vecs, set(inds) 312 | 313 | 314 | class DiscreteGuidanceModel(BaseGuidanceModel): 315 | def __init__(self, guidance_lm, tokenizer, args, **kwargs): 316 | super().__init__(guidance_lm, tokenizer, args, **kwargs) 317 | self.guidance_model_type = 'discrete' 318 | # NOTE: this might cause memory leak. This might be largely fine 319 | # given the expected number of new entity. 
320 | self.enity_to_examples = dict() 321 | 322 | def forward_prefix(self, entity, **kwargs): 323 | datapoint = kwargs['datapoint'] 324 | if self.args.data == 'wikidata': 325 | entity = tuple(entity) 326 | #gen_examples = self.enity_to_examples.get(entity, None) 327 | gen_examples = self.enity_to_examples.get((entity, datapoint['version']), None) 328 | if gen_examples is not None: 329 | return dict(gen_examples=gen_examples) 330 | 331 | # Cache miss: query the guidance LM with the instruction text. 332 | text = datapoint['context_with_instructions'] + ' [SEP]' 333 | 334 | if kwargs['mode'] == 'in': 335 | mode = 'topic' 336 | elif kwargs['mode'] == 'ex': 337 | mode = 'constraint' 338 | else: 339 | raise ValueError(f"mode={kwargs['mode']} not recognized.") 340 | 341 | gen_text = self.guidance_lm.generate(text, mode=mode) 342 | gen_examples = postprocess(gen_text, data_mode=self.args.data) 343 | #self.enity_to_examples[entity] = gen_examples 344 | self.enity_to_examples[(entity, datapoint['version'])] = gen_examples 345 | return dict( 346 | gen_examples=gen_examples, 347 | gen_texts=[gen_text], 348 | ) 349 | 350 | def forward(self, entity, **kwargs): 351 | datapoint = kwargs['datapoint'] 352 | if self.args.data == 'wikidata': 353 | entity = tuple(entity) 354 | #gen_examples = self.enity_to_examples.get(entity, None) 355 | gen_examples = self.enity_to_examples.get((entity, datapoint['version']), None) 356 | if gen_examples is not None: 357 | return dict(gen_examples=gen_examples) 358 | 359 | if self.args.data == 'wordnet': 360 | query = get_wordnet_query( 361 | guidance_model_type=self.guidance_model_type, 362 | entity=entity, 363 | args=self.args, 364 | **kwargs, 365 | ) 366 | elif self.args.data == 'wikidata': 367 | query = get_wikidata_query( 368 | guidance_model_type='discrete', 369 | entity=entity, 370 | **kwargs, 371 | ) 372 | else: 373 | raise ValueError(f'{self.args.data} not recognized.') 374 | 375 | query_tokens = self.tokenizer.encode(query) 376 | query_tokens =
torch.tensor(query_tokens).long().unsqueeze(0).cuda() 377 | 378 | outputs = self.guidance_lm.generate( 379 | query_tokens, 380 | max_length=self.args.discrete_max_length, 381 | #skip_special_tokens=True, 382 | pad_token_id=self.tokenizer.eos_token_id, 383 | no_repeat_ngram_size=2, 384 | do_sample=self.args.discrete_guidance_do_sample, 385 | num_beams=self.args.discrete_guidance_num_beams, 386 | num_beam_groups=self.args.discrete_guidance_num_beam_groups, 387 | top_p=self.args.discrete_guidance_top_p, 388 | top_k=self.args.discrete_guidance_top_k, 389 | temperature=self.args.discrete_guidance_temperature, 390 | diversity_penalty=self.args.discrete_guidance_diversity_penalty, 391 | num_return_sequences=self.args.discrete_guidance_num_return_sequences or 1, 392 | ) 393 | gen_examples = [] 394 | gen_texts = [] 395 | pattern = 'Some examples are:' 396 | for output in outputs: 397 | gen_text = self.tokenizer.decode(output, skip_special_tokens=True) 398 | gen_texts.append(gen_text) 399 | gen_sents = sent_tokenize(gen_text) 400 | gen_sents = [sent for sent in gen_sents if sent.startswith(pattern)] 401 | gen = gen_sents[1].replace(pattern, '').strip().strip('.').split(', ') 402 | gen_examples += gen 403 | gen_examples = list(set(gen_examples)) 404 | #self.enity_to_examples[entity] = gen_examples 405 | self.enity_to_examples[(entity, datapoint['version'])] = gen_examples 406 | return dict( 407 | gen_examples=gen_examples, 408 | gen_texts=gen_texts, 409 | ) 410 | 411 | def query_guidance_with_loss(self, entity, **kwargs): 412 | mode = kwargs.get('mode', None) 413 | gen_examples = self.forward(entity) 414 | 415 | if mode == 'in': 416 | self.in_bow_vecs, self.in_set = self._init_bow_vecs(gen_examples) 417 | bow_vecs = self.in_bow_vecs 418 | elif mode == 'ex': 419 | self.ex_bow_vecs, self.ex_set = self._init_bow_vecs(gen_examples) 420 | bow_vecs = self.ex_bow_vecs 421 | else: 422 | raise ValueError(f'`mode` {mode} not supported.') 423 | 424 | loss = 0.0 425 | for vec in 
bow_vecs: 426 | loss += torch.log(torch.mm(next_token_prob, torch.t(vec)).sum()) 427 | 428 | return dict( 429 | loss=loss, 430 | yes_prob=-1.0, 431 | no_prob=-1.0, 432 | gen_examples=gen_examples, 433 | query_entity=query_entity, 434 | entity=entity, 435 | ) 436 | 437 | def query_guidance(self, entity, **kwargs): 438 | mode = kwargs.get('mode', None) 439 | if self.args.discrete_guidance_instruct2guide_model_dir is not None: 440 | gen_out = self.forward_prefix(entity, **kwargs) 441 | else: 442 | gen_out = self.forward(entity, **kwargs) 443 | 444 | inds = self._init_bow_vecs( 445 | gen_out['gen_examples'], 446 | entity=entity, 447 | return_only_inds=True, 448 | **kwargs 449 | ) 450 | inds = list(inds) 451 | gen_info = dict( 452 | gen_examples=gen_out['gen_examples'], 453 | entity=entity, 454 | ) 455 | return inds, None, gen_info 456 | 457 | def _init_bow_vecs(self, words, entity, return_only_inds=False, **kwargs): 458 | if self.args.data == 'wordnet': 459 | words = words + [w + 's' for w in words] 460 | 461 | # Use trie. 
462 | if self.args.discrete_guidance_use_trie: 463 | inds = self.get_trie_filtered_indicies( 464 | kwargs['mode'], 465 | words, 466 | entity, 467 | kwargs['last_token'], 468 | kwargs['curr_context'], 469 | kwargs['example_id'], 470 | ) 471 | if return_only_inds: 472 | return set(inds) 473 | else: 474 | bow_indices = [ 475 | self.tokenizer.encode(word.strip(), add_prefix_space=True) 476 | for word in words 477 | ] 478 | bow_set = set([ind for inds in bow_indices for ind in inds]) 479 | if return_only_inds: 480 | return bow_set 481 | 482 | 483 | class OracleGuidanceModel(BaseGuidanceModel): 484 | def __init__(self, guidance_lm, tokenizer, args, **kwargs): 485 | super().__init__(guidance_lm, tokenizer, args, **kwargs) 486 | self.guidance_model_type = 'oracle' 487 | 488 | def query_guidance(self, entity, **kwargs): 489 | mode = kwargs.get('mode', None) 490 | 491 | if self.args.data == 'wordnet': 492 | if mode == 'in': 493 | words = self.hierarchy[entity] 494 | elif mode == 'ex': 495 | words = self.hierarchy[entity] + [entity] 496 | else: 497 | raise ValueError(f'`{mode}` not recognized.') 498 | elif self.args.data == 'wikidata': 499 | if mode == 'in': 500 | words = self.hierarchy[entity][:100] 501 | elif mode == 'ex': 502 | words = self.hierarchy[entity][:1000] 503 | else: 504 | raise ValueError(f'`{mode}` not recognized.') 505 | # If full length is used there's gonna be too many Q's. 
506 | words = [remove_parenthesis(q_title) for q_id, q_title in words] 507 | else: 508 | raise ValueError(f'Data arg {self.args.data} not recognized.') 509 | 510 | inds = self._init_bow_vecs( 511 | words, 512 | entity=entity, 513 | return_only_inds=True, 514 | **kwargs 515 | ) 516 | inds = list(inds) 517 | gen_info = dict( 518 | entity=entity, 519 | ) 520 | #print(mode) 521 | #print(entity) 522 | #print(words) 523 | #print([self.tokenizer.decode(ind) for ind in inds]) 524 | #print(self.guidance_prefix) 525 | #print('---') 526 | #input() 527 | return inds, None, gen_info 528 | 529 | def _init_bow_vecs(self, words, entity, return_only_inds=False, **kwargs): 530 | # Use trie. 531 | if self.args.discrete_guidance_use_trie: 532 | inds = self.get_trie_filtered_indicies( 533 | kwargs['mode'], 534 | words, 535 | entity, 536 | kwargs['last_token'], 537 | kwargs['curr_context'], 538 | kwargs['example_id'], 539 | ) 540 | if return_only_inds: 541 | return set(inds) 542 | else: 543 | bow_indices = [ 544 | self.tokenizer.encode(word.strip(), add_prefix_space=True) 545 | for word in words 546 | ] 547 | bow_set = set([ind for inds in bow_indices for ind in inds]) 548 | if return_only_inds: 549 | return bow_set 550 | 551 | 552 | class TrieNode: 553 | """A node in the trie structure""" 554 | 555 | def __init__(self, tok): 556 | self.tok = tok 557 | 558 | self.is_end = False 559 | 560 | # A counter indicating how many times a word 561 | # is inserted (if this node's is_end is True). 562 | self.counter = 0 563 | 564 | # A dictionary of child nodes. Keys are tokens, values are nodes. 565 | self.children = {} 566 | 567 | 568 | class Trie: 569 | """The trie object""" 570 | 571 | def __init__(self): 572 | """ 573 | The trie has at least the root node. 
574 | The root node does not store any token. 575 | """ 576 | self.root = TrieNode('') 577 | self.node_set = set() 578 | self.start_node_inds = set() 579 | 580 | def insert(self, word): 581 | """Insert a word (a sequence of tokens) into the trie""" 582 | node = self.root 583 | 584 | # Loop through each token in the word. 585 | # If no child holds the current token, 586 | # create a new child for the current node. 587 | for tok in word: 588 | if tok in node.children: 589 | node = node.children[tok] 590 | else: 591 | # If a token is not found, create a new node in the trie. 592 | new_node = TrieNode(tok) 593 | node.children[tok] = new_node 594 | node = new_node 595 | 596 | node.is_end = True 597 | 598 | # Increment the counter to indicate that we see this word once more. 599 | node.counter += 1 600 | 601 | def dfs(self, node, prefix): 602 | """ 603 | Depth-first traversal of the trie. 604 | 605 | Args: 606 | - node: the node to start with. 607 | - prefix: the current prefix, for tracing a word while traversing. 608 | """ 609 | if node.is_end: 610 | self.output.append(prefix + [node.tok]) 611 | 612 | for child in node.children.values(): 613 | self.dfs(child, prefix + [node.tok]) 614 | 615 | def query(self, x): 616 | """Given an input (a prefix), retrieve all token sequences 617 | stored in the trie that start with that prefix. Results are 618 | returned in depth-first traversal order; no sorting is applied. 619 | """ 620 | # Use a variable within the class to keep all possible outputs 621 | # as there can be more than one word with such prefix. 622 | self.output = [] 623 | node = self.root 624 | 625 | # Check if the prefix is in the trie. 626 | for tok in x: 627 | if tok in node.children: 628 | node = node.children[tok] 629 | else: 630 | return [] 631 | 632 | # Traverse the trie to get all candidates. 633 | self.dfs(node, x[:-1]) 634 | 635 | # Return the candidates in traversal order.
636 | return self.output 637 | -------------------------------------------------------------------------------- /src/guidance_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities relevant to the guidance models. 3 | """ 4 | import random 5 | import re 6 | from ast import literal_eval 7 | from pathlib import Path 8 | import torch 9 | import spacy 10 | from src.utils import ( 11 | to_namespace, 12 | get_data, 13 | get_lm, 14 | get_tokenizer, 15 | cleanup_gen_text, 16 | reformat_text, 17 | load_ckpt, 18 | ) 19 | 20 | SMALL_CONST = 1e-8 21 | PROMPT_TEXT = ( 22 | # 'List out some examples of us presidents. Some examples are: lincoln, obama, washington.\n' 23 | # 'List out some examples of color. Some examples are: red, blue, green.\n' 24 | # 'List out some examples of cities. Some examples are: new york, oslo, tokyo.\n' 25 | 'List out some examples of movie genres. Some examples are: drama, horror, action.\n' 26 | ) 27 | PROMPT_TEXT_TEMPLATE = { 28 | 'animal': 'List out some examples of animals. Some examples are: cat, dog, lion.\n', 29 | 'food': 'List out some examples of food. Some examples are: pizza, burger, pasta.\n', 30 | 'vehicle': 'List out some examples of vehicles. Some examples are: car, bus, train.\n', 31 | 'art': 'List out some examples of art. Some examples are: painting, sculpture, drawing.\n', 32 | 'sport': 'List out some examples of sports. Some examples are: soccer, basketball, tennis.\n', 33 | } 34 | 35 | spacy_parser = spacy.load("en_core_web_lg") 36 | 37 | 38 | def remove_parenthesis(s): 39 | m = re.search(r'(.*) \(.*\)', s) 40 | if m is None: 41 | return s 42 | else: 43 | return m.group(1) 44 | 45 | 46 | def get_wikidata_query( 47 | guidance_model_type, 48 | entity, 49 | query_entity=None, 50 | **kwargs 51 | ): 52 | p_name, p_value = entity 53 | 54 | if guidance_model_type == 'binary': 55 | if p_name == 'place_of_birth': 56 | prompt = 'Was {0} born in {1}?
The answer is' 57 | elif p_name == 'place_of_death': 58 | prompt = 'Did {0} die in {1}? The answer is' 59 | elif p_name == 'occupation': 60 | prompt = 'Was {0} a {1}? The answer is' 61 | elif p_name == 'country_of_citizenship': 62 | prompt = 'Was {0} a citizen of {1}? The answer is' 63 | elif p_name == 'academic_degree': 64 | prompt = 'Did {0} hold a degree in {1}? The answer is' 65 | elif p_name == 'educated_at': 66 | prompt = 'Did {0} get the education at {1}? The answer is' 67 | else: 68 | raise ValueError(f'Property name `{p_name}` not recognized.') 69 | query = prompt.format(query_entity, p_value) 70 | return query 71 | 72 | elif guidance_model_type in ('full', 'discrete'): 73 | #QUERY_FORMAT = 'List out some famous names of {0}. Some examples are:' 74 | #QUERY_FORMAT = 'List out some famous full names of passed {0}. Some examples are:' 75 | #QUERY_FORMAT = 'List out some historical names of {0}. Some examples are:' 76 | QUERY_FORMAT = 'List out some famous names of dead {0}. Some examples are:' 77 | if p_name == 'place_of_birth': 78 | fill_text = f'people who were born in {p_value}' 79 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n' 80 | elif p_name == 'place_of_death': 81 | fill_text = f'people who died in {p_value}' 82 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n' 83 | elif p_name == 'occupation': 84 | fill_text = p_value 85 | prompt_text = f'List out some famous names of dead people who were tech CEOs. Some examples are: Steve Jobs, Mark Hurd, Bill Campbell.\n' 86 | elif p_name == 'country_of_citizenship': 87 | fill_text = f'people who are citizens of {p_value}' 88 | prompt_text = f'List out some famous names of dead people who has traveled to France. 
Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n' 89 | elif p_name == 'academic_degree': 90 | fill_text = f'people who hold a degree in {p_value}' 91 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n' 92 | elif p_name == 'educated_at': 93 | #fill_text = f'people who had their education at {p_value}' 94 | fill_text = f'people who were educatied at {p_value}' 95 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n' 96 | else: 97 | raise ValueError(f'Property name `{p_name}` not recognized.') 98 | query = QUERY_FORMAT.format(fill_text) 99 | #prompt_text = kwargs.get('prompt_text', PROMPT_TEXT) 100 | return (prompt_text + query).strip() 101 | 102 | else: 103 | raise ValueError(f'Guidance model type {guidance_model_type} not recognized.') 104 | 105 | 106 | def get_wordnet_query( 107 | guidance_model_type, 108 | entity, 109 | query_entity=None, 110 | **kwargs, 111 | ): 112 | args = kwargs.get('args', None) 113 | if guidance_model_type == 'binary': 114 | MAGIC_PROMPT = 'I am an expert in taxonomy.' 115 | QUERY_FORMAT = 'Is {0} a type of {1}? The answer is{2}' 116 | query = [MAGIC_PROMPT] 117 | if args.num_icl_pairs > 0: 118 | pos_icl_pairs, neg_icl_pairs, noisy_queries = load_icl_pairs(args) 119 | icl_text = get_icl_query( 120 | args.num_icl_pairs, 121 | pos_icl_pairs, 122 | neg_icl_pairs, 123 | noisy_queries, 124 | ) 125 | else: 126 | icl_text = [] 127 | query += icl_text 128 | query.append(QUERY_FORMAT.format(query_entity, entity, '')) 129 | query = ' '.join(query).strip() 130 | return query 131 | 132 | elif guidance_model_type in ('full', 'discrete'): 133 | root_node = kwargs['datapoint']['parents'][0][0] 134 | prompt_text = PROMPT_TEXT_TEMPLATE.get(root_node, PROMPT_TEXT) 135 | 136 | QUERY_FORMAT = 'What are some examples of {0}? 
Some examples are:' 137 | query = QUERY_FORMAT.format(entity) 138 | return (prompt_text + query).strip() 139 | 140 | else: 141 | raise ValueError(f'Guidance model type {guidance_model_type} not recognized.') 142 | 143 | 144 | def load_icl_pairs(args): 145 | """ 146 | For in-context learning. 147 | """ 148 | pos_icl_pairs = [] 149 | neg_icl_pairs = [] 150 | roots = ['animal', 'food', 'sport', 'art', 'vehicle'] 151 | for root in roots: 152 | path = Path(args.icl_pair_dir) / f'train_pairs_{root}_new.txt' 153 | with open(path) as f: 154 | for line in f: 155 | cat, pos, neg = literal_eval(line.strip()) 156 | cat = cat.replace('_', ' ') 157 | pos = pos.replace('_', ' ') 158 | neg = neg.replace('_', ' ') 159 | pos_icl_pairs.append((pos, cat, root)) 160 | neg_icl_pairs.append((neg, cat, root)) 161 | noisy_queries = [] 162 | with open('/path/to/your/noisy_queries.txt') as f: 163 | for line in f: 164 | noisy_queries.append(line.strip()) 165 | return pos_icl_pairs, neg_icl_pairs, noisy_queries 166 | 167 | 168 | def get_icl_query(num_icl_pairs, pos_icl_pairs, neg_icl_pairs, noisy_queries): 169 | pos_icl_pairs = random.choices(pos_icl_pairs, k=num_icl_pairs) 170 | neg_icl_pairs = random.choices(neg_icl_pairs, k=num_icl_pairs) 171 | noisy_icl = random.choices(noisy_queries, k=num_icl_pairs) 172 | QUERY_FORMAT = 'Is {0} a type of {1}? The answer is{2}' # Same template as the binary query in get_wordnet_query. 173 | icl = [] 174 | for p, n, noisy_q in zip(pos_icl_pairs, neg_icl_pairs, noisy_icl): 175 | icl += [ 176 | QUERY_FORMAT.format(p[0], p[1], ' yes.'), 177 | QUERY_FORMAT.format(n[0], n[1], ' no.'), 178 | # QUERY_FORMAT.format(noisy_q, p[1], ' no.'), 179 | # QUERY_FORMAT.format(noisy_q, n[1], ' no.'), 180 | ] 181 | random.shuffle(icl) 182 | return icl 183 | 184 | 185 | def get_guidance_model( 186 | args, 187 | tokenizer, 188 | hierarchy, 189 | num_devices=None, 190 | guidance_lm=None, 191 | ): 192 | from src.guidance_models import ( 193 | BinaryGuidanceModel, 194 | FullGuidanceModel, 195 | DiscreteGuidanceModel, 196 | OracleGuidanceModel, 197 | ) 198 | guidance_args = dict( 199 |
tokenizer=tokenizer, 200 | args=args, 201 | hierarchy=hierarchy, 202 | ) 203 | guidance_model_type = args.guidance_model_type 204 | 205 | # Load the core guidance LM (trained, un-trained, or oracle). 206 | if guidance_lm is not None: 207 | # Use the provided guidance_lm from the args. 208 | guidance_model_type = args.guidance_model_type_2 209 | elif args.guidance_model_path: 210 | # Load fine-tuned/prompt-tuned guidance models. 211 | guidance_lm, _ = load_ckpt(load_path=args.guidance_model_path) 212 | elif args.guidance_model_name == 'oracle': 213 | guidance_lm = None 214 | elif args.discrete_guidance_instruct2guide_model_dir is not None: 215 | assert args.guidance_model_type == 'discrete' 216 | from src.instruct2guide.utils import load_checkpoint 217 | guidance_lm = load_checkpoint( 218 | args.discrete_guidance_instruct2guide_model_dir 219 | ) 220 | else: 221 | guidance_lm = get_lm(args, num_devices, load_mode='guidance') 222 | 223 | # Construct the complete guidance model 224 | if args.guidance_model_path: 225 | guidance_model_class = BinaryGuidanceModel 226 | elif guidance_model_type == 'binary': 227 | guidance_model_class = BinaryGuidanceModel 228 | elif guidance_model_type == 'full': 229 | guidance_model_class = FullGuidanceModel 230 | elif guidance_model_type == 'discrete': 231 | guidance_model_class = DiscreteGuidanceModel 232 | elif guidance_model_type == 'oracle': 233 | guidance_model_class = OracleGuidanceModel 234 | guidance_lm = None 235 | else: 236 | raise ValueError(f'Guidance model type: {guidance_model_type} not supported.') 237 | 238 | guidance_args.update(guidance_lm=guidance_lm) 239 | guidance_model = guidance_model_class(**guidance_args) 240 | return guidance_model 241 | 242 | 243 | def get_gradient(loss, curr_history_delta, step_size, args, retain_graph): 244 | if not torch.is_tensor(loss): 245 | # In the case where no loss is returned, we don't need to compute 246 | # the gradient. 
This happens when using the guidance model and it 247 | # predicts that the answer to the query is ``no''. 248 | grad = [ 249 | (torch.zeros_like(key), torch.zeros_like(value)) 250 | for key, value in curr_history_delta 251 | ] 252 | else: 253 | loss.backward(retain_graph=retain_graph) 254 | grad_norms = [ 255 | ( 256 | torch.norm(key.grad) + SMALL_CONST, 257 | torch.norm(value.grad) + SMALL_CONST 258 | ) 259 | for key, value in curr_history_delta 260 | ] 261 | grad = [( 262 | - step_size * (key.grad / key_grad_norm ** args.gamma), 263 | - step_size * (value.grad / value_grad_norm ** args.gamma) 264 | ) 265 | for (key, value), (key_grad_norm, value_grad_norm) 266 | in zip(curr_history_delta, grad_norms) 267 | ] 268 | return grad 269 | 270 | 271 | def get_query_entity(tokens, tokenizer): 272 | """ 273 | Check and return word/entity with the following order: 274 | 1) tokens form a named entity + extra tokens 275 | 2) tokens form a single word + extra tokens 276 | 277 | """ 278 | text = tokenizer.decode(tokens).strip(' ') 279 | parsed = spacy_parser(text) 280 | parsed = [(t.text, t.pos_) for t in parsed] 281 | 282 | consecutive_idx = 0 283 | query_entity = [] 284 | query_tags = [] 285 | for i, (t_text, t_pos) in enumerate(parsed): 286 | if i == consecutive_idx and t_pos == 'PROPN': 287 | query_entity.append(t_text) 288 | query_tags.append(t_pos) 289 | consecutive_idx += 1 290 | 291 | if not query_entity and parsed and parsed[0][1] == 'NOUN': 292 | query_entity = [parsed[0][0]] 293 | query_tags = [parsed[0][1]] 294 | 295 | query_entity = ' '.join(query_entity) 296 | query_entity = query_entity.lower() if query_entity else None 297 | return query_entity, query_tags, parsed 298 | 299 | 300 | def deepcopy_history(history): 301 | return [ 302 | (key.clone(), value.clone()) 303 | for key, value in history 304 | ] 305 | 306 | 307 | def add_key_values(xs, deltas): 308 | added_xs = [] 309 | for x, delta in zip(xs, deltas): 310 | x_k, x_v = x 311 | delta_k, delta_v = delta 312 
| added_xs.append((x_k + delta_k, x_v + delta_v)) 313 | return added_xs 314 | -------------------------------------------------------------------------------- /src/guide.py: -------------------------------------------------------------------------------- 1 | """ 2 | The guidance procedure. 3 | - No Guidance 4 | - PPLM (pplm) 5 | - Weighted Decoding (wd) 6 | """ 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from src.guidance_utils import ( 11 | get_gradient, 12 | get_query_entity, 13 | deepcopy_history, 14 | add_key_values, 15 | ) 16 | 17 | 18 | def run_pplm_step( 19 | example_id, 20 | args, 21 | model, 22 | guidance_model, 23 | in_guidance_model, 24 | tokenizer, 25 | history, 26 | last_token, 27 | curr_context, 28 | topic, 29 | constraint, 30 | datapoint, 31 | ): 32 | """ 33 | Perform PPLM refinement steps to obtain the perturbed history. 34 | """ 35 | history_delta = [ 36 | (torch.zeros_like(key), torch.zeros_like(value)) 37 | for key, value in history 38 | ] 39 | refinement = [] # Track refinement tokens and the guidance probabilities. 40 | for i in range(args.refinement_steps): 41 | curr_history_delta = [ 42 | ( 43 | key.clone().detach().requires_grad_(True), 44 | value.clone().detach().requires_grad_(True) 45 | ) 46 | for key, value in history_delta 47 | ] 48 | perturbed_history = add_key_values(history, curr_history_delta) 49 | 50 | # Generate multiple tokens to query the guidance model.
51 | next_token_ = last_token.clone().detach() 52 | history_ = deepcopy_history(perturbed_history) 53 | multistep_tokens = [] 54 | multistep_probs = [] 55 | for _ in range(args.max_multistep_len): 56 | outputs_ = model(next_token_, past_key_values=history_) 57 | logits_ = outputs_.logits[:, -1, :] 58 | probs_ = F.softmax(logits_, dim=-1) 59 | history_ = outputs_.past_key_values 60 | 61 | next_token_ = torch.argmax(probs_, dim=1).unsqueeze(0) 62 | multistep_tokens.append(next_token_.item()) 63 | multistep_probs.append(probs_) 64 | 65 | # Use the list of probs/histories of the entity tokens for further use. 66 | query_entity, query_tags, parsed = \ 67 | get_query_entity(multistep_tokens, tokenizer) 68 | 69 | # Still use the latest token for generation while using the 70 | # full entity probs to get the guidance signal. 71 | next_token_after_refinement = multistep_tokens[0] 72 | next_token_after_refinement_text = tokenizer.decode( 73 | next_token_after_refinement, 74 | skip_special_tokens=True, 75 | ).strip(' ').replace('\n', '\\n') 76 | perturbed_probs = multistep_probs[0] 77 | 78 | ex_loss = 0.0 79 | ex_grad = None 80 | ex_guidance_outs = None 81 | if 'ex' in args.guidance: 82 | ex_guidance_outs = guidance_model.calc_loss( 83 | perturbed_probs, 84 | query_entity, 85 | constraint, 86 | mode='ex', 87 | threshold=args.g_threshold, 88 | ) 89 | ex_loss = ex_guidance_outs['loss'] 90 | 91 | # NOTE: the sign of the loss is assigned here 92 | ex_grad = get_gradient( 93 | ex_loss, 94 | curr_history_delta, 95 | args.alpha, 96 | args, 97 | retain_graph=True, 98 | ) 99 | 100 | in_loss = 0.0 101 | in_grad = None 102 | in_guidance_outs = None 103 | if 'in' in args.guidance: 104 | in_guidance_outs = in_guidance_model.calc_loss( 105 | perturbed_probs, 106 | query_entity, 107 | topic, 108 | mode='in', 109 | threshold=args.g_threshold, 110 | ) 111 | in_loss = in_guidance_outs['loss'] 112 | 113 | # NOTE: the sign of the loss is assigned here 114 | in_grad = get_gradient( 115 |
-in_loss, 116 | curr_history_delta, 117 | args.beta, 118 | args, 119 | retain_graph=False, 120 | ) 121 | 122 | grad = None 123 | if ex_grad is not None: 124 | grad = ex_grad 125 | if in_grad is not None: 126 | grad = in_grad if grad is None else add_key_values(in_grad, grad) 127 | if grad is not None: 128 | history_delta = add_key_values(history_delta, grad) 129 | for key, value in curr_history_delta: 130 | if key.grad is not None: 131 | key.grad.zero_() 132 | if value.grad is not None: 133 | value.grad.zero_() 134 | 135 | refinement.append(dict( 136 | last_token=tokenizer.decode( 137 | last_token.item(), 138 | skip_special_tokens=True 139 | ).strip(' ').replace('\n', '\\n'), 140 | next_token_after_refinement_text=next_token_after_refinement_text, 141 | query_entity=query_entity, 142 | query_tags=query_tags, 143 | parsed=parsed, 144 | ex_yes_prob=( 145 | ex_guidance_outs['yes_prob'] 146 | if ex_guidance_outs is not None else -1.0 147 | ), 148 | ex_no_prob=( 149 | ex_guidance_outs['no_prob'] 150 | if ex_guidance_outs is not None else -1.0 151 | ), 152 | in_yes_prob=( 153 | in_guidance_outs['yes_prob'] 154 | if in_guidance_outs is not None else -1.0 155 | ), 156 | in_no_prob=( 157 | in_guidance_outs['no_prob'] 158 | if in_guidance_outs is not None else -1.0 159 | ), 160 | )) 161 | perturbed_history = add_key_values(history, history_delta) 162 | final_outputs = model(last_token, past_key_values=perturbed_history) 163 | 164 | final_logits = final_outputs.logits[:, -1, :] 165 | return final_logits, refinement 166 | 167 | 168 | def run_wd_step( 169 | example_id, 170 | args, 171 | model, 172 | guidance_model, 173 | in_guidance_model, 174 | tokenizer, 175 | history, 176 | last_token, 177 | curr_context, 178 | topic, 179 | constraint, 180 | datapoint, 181 | ): 182 | stepwise_info = [] 183 | 184 | orig_outputs = model(last_token, past_key_values=history) 185 | orig_logits = orig_outputs.logits[:, -1, :] 186 | orig_probs = F.softmax(orig_logits, dim=-1) 187 | 188 | # Generate multiple tokens to query the guidance model.
189 | if not args.max_multistep_len: 190 | query_entity = None 191 | else: 192 | next_token_ = last_token.clone().detach() 193 | history_ = deepcopy_history(history) 194 | multistep_tokens = [] 195 | for _ in range(args.max_multistep_len): 196 | outputs_ = model(next_token_, past_key_values=history_) 197 | logits_ = outputs_.logits[:, -1, :] 198 | probs_ = F.softmax(logits_, dim=-1) 199 | history_ = outputs_.past_key_values 200 | next_token_ = torch.argmax(probs_, dim=1).unsqueeze(0) 201 | multistep_tokens.append(next_token_.item()) 202 | 203 | # Use the list of probs/histories of the entity tokens for further use. 204 | query_entity, query_tags, parsed = get_query_entity(multistep_tokens, tokenizer) 205 | if query_entity is None: 206 | return orig_logits, [] 207 | 208 | wd_guidance_args = dict( 209 | example_id=example_id, 210 | orig_probs=orig_probs, 211 | query_entity=query_entity, 212 | threshold=args.g_threshold, 213 | curr_context=curr_context, 214 | last_token=last_token, 215 | datapoint=datapoint, 216 | ) 217 | 218 | ex_inds, _, ex_info = guidance_model.query_guidance( 219 | constraint, 220 | mode='ex', 221 | **wd_guidance_args, 222 | ) 223 | orig_logits[:, ex_inds] = orig_logits[:, ex_inds] - args.alpha 224 | 225 | if 'in' in args.guidance: 226 | in_inds, _, in_info = \ 227 | in_guidance_model.query_guidance(topic, mode='in', **wd_guidance_args) 228 | orig_logits[:, in_inds] = orig_logits[:, in_inds] + args.beta 229 | else: 230 | in_info = dict() 231 | 232 | stepwise_info.append(dict( 233 | exclusion_info=ex_info, 234 | inclusion_info=in_info, 235 | )) 236 | final_logits = orig_logits 237 | return final_logits, stepwise_info 238 | 239 | 240 | def run_constrained_decoding( 241 | example_id, 242 | args, 243 | model, 244 | guidance_model, 245 | in_guidance_model, 246 | tokenizer, 247 | history, 248 | last_token, 249 | curr_context, 250 | topic, 251 | constraint, 252 | datapoint, 253 | ): 254 | orig_outputs = model(last_token, past_key_values=history) 255 | 
orig_logits = orig_outputs.logits[:, -1, :] 256 | 257 | if 'ex' in args.guidance: 258 | words = guidance_model.hierarchy[constraint] 259 | if args.data == 'wordnet': 260 | words = words + [constraint] 261 | elif args.data == 'wikidata': 262 | raise ValueError('Not finished yet. To be completed.') 263 | else: 264 | raise ValueError(f'Data arg {args.data} not recognized.') 265 | 266 | bow_indices = [ 267 | tokenizer.encode(word.strip(), add_prefix_space=True) 268 | for word in words 269 | ] 270 | for inds in bow_indices: 271 | orig_logits[:, inds] = float('-inf') 272 | 273 | return orig_logits, [] 274 | -------------------------------------------------------------------------------- /src/guide_with_offset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Offset Guidance. 3 | 4 | Guide generation using the combination of three probabilities 5 | 1. p_1 = p(x_t | x_{ 0: 86 | generated_text = generated_text[0] 87 | else: 88 | generated_text = '' 89 | 90 | if generated_text == '': 91 | print('Empty text.') 92 | continue 93 | 94 | input_ids = lm_scorer.tokenizer(generated_text, return_tensors="pt").input_ids.cuda() 95 | #print(input_ids.size()) 96 | outputs = lm_scorer.model(input_ids, labels=input_ids) 97 | ppl = torch.exp(outputs.loss) 98 | if torch.isnan(ppl) or ppl > 500: 99 | print('NaN loss.') 100 | continue 101 | ppls.append((prediction_id, ppl.item())) 102 | print(f'avg: {sum(ppl for _, ppl in ppls) / len(ppls)} (n={len(ppls)})') 103 | 104 | save_path = path.parent / 'ppl.txt' 105 | with open(save_path, 'w') as f: 106 | for prediction_id, ppl in ppls: 107 | f.write(f'{prediction_id},{ppl}\n') 108 | f.write(f'avg,{sum(ppl for _, ppl in ppls) / len(ppls)}\n') 109 | print('Saved at:', save_path) 110 | print('----------------------------------') 111 | 112 | 113 | @torch.no_grad() 114 | def try_compute_ppl(lm_scorer): 115 | texts = [ 116 | 'There is a dog hiding behind the tree.', 117 | 'There is a dog is a that hiding 
behind the tree.', 118 | 'The dog is hiding behind the tree.', 119 | 'The dog is hiding in cat of the movie tree.', 120 | 'this is a dog this is a dog this is a dog this is a dog', 121 | 'dog dog dog dog dog dog dog dog dog', 122 | ] 123 | for text in texts: 124 | print('text:') 125 | print(text) 126 | input_ids = lm_scorer.tokenizer(text, return_tensors="pt").input_ids.cuda() 127 | print([lm_scorer.tokenizer.decode(tok) for tok in input_ids[0]]) 128 | outputs = lm_scorer.model(input_ids, labels=input_ids) 129 | print(outputs.loss) 130 | 131 | ppl = torch.exp(outputs.loss).item() 132 | print(f'PPL: {ppl:.4f}') 133 | ppl = lm_scorer.sentence_score(text) 134 | print(f'PPL 2: {ppl:.4f}') 135 | print('---') 136 | 137 | 138 | if __name__ == '__main__': 139 | """ 140 | python -m src.lm_scorer 141 | """ 142 | model_name = 'gpt2-xl' 143 | #model_name = 'EleutherAI/gpt-j-6B' 144 | 145 | paths = [ 146 | # wordnet 147 | #'./runs/wd/wordnet/wordnet_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=discrete_trie_i2g_evalv=test/seed_0/predictions.jsonl', 148 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1_split=test/wordnet/predictions.jsonl', 149 | #'./ci/baselines/runs/gpt_engine="text-davinci-002"_temp=0.9_top_p=0.95_eval=-1_split=test/wordnet/predictions.jsonl', 150 | #'./runs/wd/wordnet/wordnet_nl+ex+in_gen=gpt2-xl_guide=none_a=100.0_b=5.0_in=none_ex=none_evalv=dev/seed_0/predictions.jsonl', 151 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1/wordnet/predictions.jsonl', 152 | #'./ci/baselines/runs/gpt3-legacy/wordnet/predictions.jsonl', 153 | #'./ci/baselines/runs/gpt3-extra-prompt/wordnet/predictions.jsonl', 154 | #'./runs/wd/wordnet/wordnet_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=discrete_trie_evalv=dev/seed_0/predictions.jsonl', 155 | #'./runs/wd/wordnet/selfdebias_a=0.5_b=0.5_pmt=expof/seed_0/predictions.jsonl', 156 | 
#'./ci/baselines/runs/ctrl/wordnet/debug/ctrl_eval/seed_0/predictions.jsonl', 157 | 158 | # wikidata 159 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=10.0_in=discrete_ex=discrete_i2g_evalv=test/seed_0/predictions.jsonl', 160 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1_split=test/wikidata/predictions.jsonl', 161 | #'./ci/baselines/runs/gpt_engine="text-davinci-002"_temp=0.9_top_p=0.95_eval=-1_split=test/wikidata/predictions.jsonl', 162 | #'./runs/wd/wikidata/wikidata_nl+ex+in_gen=gpt2-xl_guide=none_a=100.0_b=10.0_in=none_ex=none_evalv=dev/seed_0/predictions.jsonl', 163 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1/wikidata/predictions.jsonl', 164 | #'./ci/baselines/runs/gpt3-legacy/wikidata/predictions.jsonl', 165 | #'./ci/baselines/runs/gpt3/wikidata/predictions.jsonl', 166 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=oracle_a=100.0_b=10.0_in=none_ex=oracle_evalv=-2/seed_0/predictions.jsonl', 167 | #'./runs/wd/wikidata/_right_results_bad_pred_files/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=binary/seed_0/predictions.jsonl', 168 | #'./runs/wd/wikidata/_right_results_bad_pred_files/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=full_extopk=20/seed_0/predictions.jsonl', 169 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=10.0_in=discrete_ex=discrete_beam=8_topp=0.92_return=8_impprompt/seed_0/predictions.jsonl', 170 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=10.0_in=discrete_ex=discrete_i2g_evalv=dev/seed_0/predictions.jsonl', 171 | #'./runs/wd/wikidata/selfdebias_a=0.5_b=0.5_pmt=expof/predictions.jsonl', 172 | './ci/baselines/runs/ctrl/wikidata/debug/ctrl_eval/seed_0/predictions.jsonl', 173 | ] 174 | 175 | lm_scorer = LMScorer(model_name=model_name) 176 | #try_compute_ppl(lm_scorer) 177 | for path in paths: 178 | print(path) 179 | compute_ppl(path, 
lm_scorer) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import yaml 4 | import json 5 | import time 6 | import random 7 | import logging 8 | import argparse 9 | from argparse import Namespace 10 | from pathlib import Path 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | from rich import print 16 | from alive_progress import alive_bar 17 | 18 | from src.utils import ( 19 | get_data, 20 | get_lm, 21 | get_tokenizer, 22 | cleanup_gen_text, 23 | reformat_text, 24 | top_p_sampling, 25 | ) 26 | from src.diverse_instructions import prepare_context 27 | from src.guide import ( 28 | run_pplm_step, 29 | run_wd_step, 30 | run_constrained_decoding, 31 | ) 32 | from src.guide_with_offset import run_offset_step 33 | from src.guidance_utils import add_key_values, get_guidance_model 34 | from src.metrics import compute_prediction_metrics, aggregate_metrics 35 | from src.lm_scorer import LMScorer 36 | 37 | from src.experiment import ExperimentManager 38 | from src.metric_tracker import MetricTracker 39 | 40 | random.seed(0) 41 | 42 | 43 | def run_generation( 44 | datapoint, 45 | hierarchy, 46 | model, 47 | tokenizer, 48 | args, 49 | guidance_model=None, 50 | in_guidance_model=None 51 | ): 52 | if args.eval_version == -2: 53 | context = datapoint['context'] 54 | else: 55 | context = datapoint['context_with_instructions'] 56 | #context = datapoint['context'] 57 | context = tokenizer.encode(context, return_tensors='pt').cuda() 58 | start_context_length = context.size(1) 59 | topic = datapoint['topic'] 60 | constraint = datapoint['constraint'] 61 | example_id = datapoint['id'] 62 | if in_guidance_model is None: 63 | in_guidance_model = guidance_model 64 | 65 | gen_step = 0 66 | context_length = 0 67 | generated_tokens = [] 68 | end = [] 69 | refinements = [] 70 | 71 | while ( 72 | context_length < 
args.max_gen_length + start_context_length and 73 | end != [' ==', '\n'] 74 | ): 75 | last_token = context[:, -1:] 76 | curr_context = context[:, :-1] 77 | 78 | outputs = model(curr_context) 79 | history = outputs.past_key_values 80 | 81 | gen_state = dict( 82 | example_id=example_id, 83 | args=args, 84 | model=model, 85 | guidance_model=guidance_model, 86 | in_guidance_model=in_guidance_model, 87 | tokenizer=tokenizer, 88 | history=history, 89 | last_token=last_token, 90 | curr_context=curr_context, 91 | topic=topic, 92 | constraint=constraint, 93 | datapoint=datapoint, 94 | ) 95 | if 'pplm' in args.guidance: 96 | final_logits, refinement = run_pplm_step(**gen_state) 97 | final_probs = F.softmax(final_logits, dim=-1) 98 | elif 'wd' in args.guidance: 99 | final_logits, refinement = run_wd_step(**gen_state) 100 | final_probs = F.softmax(final_logits, dim=-1) 101 | elif 'cd' in args.guidance: 102 | final_logits, refinement = run_constrained_decoding(**gen_state) 103 | final_probs = F.softmax(final_logits, dim=-1) 104 | elif args.guidance == 'none' or 'nl' in args.guidance: 105 | final_outputs = model(last_token, past_key_values=history) 106 | final_logits = final_outputs.logits[:, -1, :] 107 | final_probs = F.softmax(final_logits, dim=-1) 108 | refinement = [] 109 | elif 'os' in args.guidance: 110 | final_probs, refinement = run_offset_step(**gen_state) 111 | else: 112 | raise ValueError(f'Guidance type `{args.guidance}` not supported.') 113 | refinements.append(refinement) 114 | 115 | if args.fusion_gamma is not None and args.fusion_gamma != 1.0: 116 | orig_outputs = model(last_token, past_key_values=history) 117 | orig_logits = orig_outputs.logits[:, -1, :] 118 | orig_probs = F.softmax(orig_logits, dim=-1) 119 | final_probs = final_probs.pow(args.fusion_gamma) * \ 120 | orig_probs.pow(1.0 - args.fusion_gamma) 121 | 122 | if args.top_p is not None: 123 | final_logits = top_p_sampling(final_logits, args.top_p) 124 | final_probs = F.softmax(final_logits / 
args.temperature, dim=-1) 125 | next_token = torch.multinomial(final_probs, num_samples=1) 126 | else: 127 | next_token = torch.argmax(final_probs, dim=1) 128 | next_token_text = tokenizer.decode( 129 | next_token, 130 | skip_special_tokens=True 131 | ) 132 | 133 | context = torch.cat([context, next_token[:, None]], dim=-1) 134 | context_length = context.size(1) 135 | 136 | generated_tokens.append(next_token) 137 | gen_step += 1 138 | 139 | end.append(next_token_text) 140 | end = end[-2:] 141 | return generated_tokens, refinements 142 | 143 | 144 | def setup_logger(args): 145 | handlers = [] 146 | run_dir = Path(args.run_dir) / args.name 147 | run_dir.mkdir(parents=True, exist_ok=True) 148 | prediction_path = run_dir / 'predictions.jsonl' 149 | config_path = run_dir / 'config.yaml' 150 | 151 | if args.override == 'manual': 152 | ans = input( 153 | f'The following files will be overridden:\n' 154 | f'`{prediction_path}`\n' 155 | f'`{config_path}`\n' 156 | f'Proceed? [yes/no]: ' 157 | ) 158 | if ans.lower() == 'yes': 159 | handlers.append(logging.FileHandler(prediction_path, 'w')) 160 | with open(config_path, 'w') as f: 161 | yaml.dump(vars(args), f) 162 | else: 163 | sys.exit(0) 164 | elif args.override == 'auto': 165 | handlers.append(logging.FileHandler(prediction_path, 'w')) 166 | with open(config_path, 'w') as f: 167 | yaml.dump(vars(args), f) 168 | elif args.override == 'no': 169 | handlers.append(logging.FileHandler(prediction_path)) 170 | with open(config_path, 'w') as f: 171 | yaml.dump(vars(args), f) 172 | else: 173 | raise ValueError(f'`{args.override}` not recognized.') 174 | 175 | if args.log_to_console: 176 | handlers.append(RichHandler()) 177 | 178 | logging.basicConfig( 179 | format='%(message)s', 180 | level=logging.INFO, 181 | handlers=handlers 182 | ) 183 | 184 | # Save command to file in `run_dir`. 
185 | run_command = ' '.join(['python -m src.main'] + sys.argv[1:]) 186 | with open(run_dir / 'run.sh', 'w') as f: 187 | f.write(run_command) 188 | return run_dir 189 | 190 | 191 | def setup_args(args): 192 | if args.guidance_model_name == 'oracle': 193 | args.max_multistep_len = 0 194 | if args.guidance != 'none': 195 | args.guidance = args.guidance.split('+') 196 | return args 197 | 198 | 199 | def run(args, run_id, exp_manager=None): 200 | #run_dir = setup_logger(args) 201 | args = setup_args(args) 202 | 203 | vs = [] 204 | os = [] 205 | onvs = [] 206 | 207 | num_devices = torch.cuda.device_count() 208 | 209 | datasets, hierarchy, gold = get_data(args) 210 | model = get_lm(args, num_devices) 211 | tokenizer = get_tokenizer(args) 212 | 213 | # Load the guidance model. 214 | if args.guidance_model_type != 'none': 215 | guidance_model = get_guidance_model( 216 | args, 217 | tokenizer, 218 | hierarchy, 219 | num_devices=num_devices 220 | ) 221 | else: 222 | guidance_model = None 223 | 224 | if (args.guidance_model_type_2 != args.guidance_model_type and 225 | args.guidance_model_type_2 != 'none'): 226 | in_guidance_model = get_guidance_model( 227 | args, 228 | tokenizer, 229 | hierarchy, 230 | guidance_lm=guidance_model.guidance_lm 231 | ) 232 | else: 233 | in_guidance_model = None 234 | 235 | lm_scoreer = LMScorer(model=model, tokenizer=tokenizer) 236 | 237 | # Main loop. 
238 | predictions = [] 239 | prediction_metrics = [] 240 | total = len(datasets[args.dataset_split]) 241 | 242 | with alive_bar(total, enrich_print=False) as bar: 243 | for idx, datapoint in enumerate(datasets[args.dataset_split]): 244 | if len(predictions) == args.num_datapoints: 245 | break 246 | 247 | topic = datapoint['topic'] 248 | constraint = datapoint['constraint'] 249 | 250 | if args.eval_version == -1 and args.dataset_split == 'dev': 251 | eval_version = random.choice(range(3, 6)) 252 | elif args.eval_version == -1 and args.dataset_split == 'test': 253 | eval_version = random.choice(range(6, 35)) 254 | else: 255 | eval_version = args.eval_version 256 | 257 | datapoint = prepare_context(datapoint, args, version=eval_version) 258 | 259 | # NOTE: The main logic starts here. 260 | gen_ids, refinements = run_generation( 261 | datapoint, 262 | hierarchy, 263 | model, 264 | tokenizer, 265 | args, 266 | guidance_model, 267 | in_guidance_model=in_guidance_model, 268 | ) 269 | 270 | generated_text = tokenizer.decode( 271 | torch.cat(gen_ids), 272 | skip_special_tokens=True 273 | ) 274 | generated_text = cleanup_gen_text(generated_text) 275 | prediction = dict( 276 | datapoint=datapoint, 277 | guidance=( 278 | '+'.join(args.guidance) 279 | if args.guidance != 'none' else args.guidance 280 | ), 281 | generated_text=generated_text, 282 | refinements=refinements, 283 | generated_tokens=[tokenizer.decode(token) for token in gen_ids], 284 | ) 285 | predictions.append(prediction) 286 | 287 | prediction_metric_outputs = compute_prediction_metrics( 288 | prediction, 289 | hierarchy, 290 | data_mode=args.data, 291 | ) 292 | prediction_metric = prediction_metric_outputs['prediction_metric'] 293 | prediction_metric['id'] = datapoint['id'] 294 | extracted = prediction_metric_outputs['extracted'] 295 | if generated_text: 296 | ppl = lm_scoreer.sentence_score(generated_text) 297 | else: 298 | ppl = 0.0 299 | 300 | prediction_metric.update({'ppl': ppl}) 301 | 
prediction_metrics.append(prediction_metric) 302 | 303 | exp_manager.take( 304 | {**prediction, **prediction_metric}, 305 | log=True, 306 | console=False, 307 | run_id=run_id, 308 | add_metric=True, 309 | ) 310 | 311 | v = prediction_metric['violated'] 312 | o = prediction_metric['on_topic'] 313 | onv = prediction_metric['on_topic_not_violated'] 314 | vs.append(v) 315 | os.append(o) 316 | onvs.append(onv) 317 | print(f'name: {args.name} (num={len(predictions)})') 318 | print(args) 319 | print(f'eval_version: {eval_version}') 320 | print(f'topic: {topic}') 321 | print(datapoint['context']) 322 | print(f'constraint: {constraint}') 323 | print(f'generated_text:') 324 | print(generated_text) 325 | print(f'Extracted: {extracted}') 326 | print(f'v: {v}') 327 | print(f'o: {o}') 328 | print(f'onv: {onv}') 329 | print(f'ppl: {ppl:.4f}') 330 | print(f'Accumulated violated: {sum(vs) / len(vs):.4f}') 331 | print(f'Accumulated on_topic: {sum(os) / len(os):.4f}') 332 | print(f'Accumulated on_topic_not_violated: {sum(onvs) / len(onvs):.4f}') 333 | print('---') 334 | 335 | time.sleep(0.005) 336 | bar() 337 | 338 | stats = exp_manager.aggregate_metrics() 339 | exp_manager.save_results(stats) 340 | print(stats) 341 | print(args) 342 | print('Run finished.') 343 | 344 | 345 | def get_argparse(config=None): 346 | parser = argparse.ArgumentParser() 347 | parser.add_argument("--name", type=str, default=None) 348 | parser.add_argument("--run_dir", type=str, default="./runs") 349 | parser.add_argument("--debug", action="store_true") 350 | parser.add_argument("--override", type=str, default="manual", 351 | choices=['auto', 'manual', 'no'], 352 | help=( 353 | "`auto`: override all, " 354 | "`manual`: ask in terminal before override, " 355 | "`no`: no override." 356 | ) 357 | ) 358 | parser.add_argument("--data", type=str, default='wordnet', 359 | choices=['wordnet', 'wikidata'], 360 | ) 361 | 362 | # Data arguments. 
363 | parser.add_argument("--train_path", type=str, default=None) 364 | parser.add_argument("--dev_path", type=str, default=None) 365 | parser.add_argument("--test_path", type=str, default=None) 366 | parser.add_argument("--hierarchy_path", type=str, default=None) 367 | parser.add_argument("--wiki_gold", type=str, default=None) 368 | parser.add_argument("--num_datapoints", type=int, default=500) 369 | parser.add_argument("--dataset_split", type=str, default='dev') 370 | parser.add_argument("--eval_version", type=int, default=0, 371 | help=( 372 | "-1: use dataset_split to decide which version to use.\n" 373 | "-2: use only the context." 374 | ) 375 | ) 376 | 377 | # Generation model arguments. 378 | parser.add_argument("--model_name", type=str, default="gpt2-xl") 379 | parser.add_argument("--top_p", type=float, default=None) 380 | parser.add_argument("--temperature", type=float, default=1.0) 381 | parser.add_argument("--guidance", type=str, default='none', 382 | choices=[ 383 | 'none', # No constraint 384 | 'pplm+ex', 'pplm+ex+in', # PPLM 385 | 'wd+ex', 'wd+ex+in', # Weighted decoding 386 | 'cd+ex', # Constrained decoding 387 | 'nl+ex', 'nl+ex+in', # Natural language constraint 388 | 'os+ex+in', # Offset guidance 389 | ], 390 | ) 391 | parser.add_argument("--max_gen_length", type=int, default=60) 392 | parser.add_argument("--alpha", type=float, default=None) # 0.02 / 0.1 393 | parser.add_argument("--beta", type=float, default=None) # 0.05 394 | parser.add_argument("--gamma", type=float, default=None) # 1.5 395 | parser.add_argument("--fusion_gamma", type=float, default=1.0) # 1.0 396 | parser.add_argument("--refinement_steps", type=int, default=0) # 3 397 | parser.add_argument("--prev_run_dir", type=str, default=None) 398 | parser.add_argument("--log_to_console", action="store_true") 399 | 400 | # Guidance model. 
401 | parser.add_argument("--guidance_model_name", type=str, default='none', 402 | choices=[ 403 | 'none', 404 | 'oracle', 405 | 'gpt2', 'gpt2-xl', 'gpt2-ft', 406 | 'EleutherAI/gpt-j-6B', 407 | ], 408 | help="Guidance model. Default is `none`." 409 | ) 410 | parser.add_argument("--guidance_model_type", type=str, default='none', 411 | choices=['none', 'full', 'binary', 'discrete'], 412 | help="Guidance model type." 413 | ) 414 | parser.add_argument("--guidance_model_type_2", type=str, default='none', 415 | choices=['none', 'full', 'binary', 'discrete'], 416 | help="Second guidance model (will only be used for INCLUSION if specified)." 417 | ) 418 | parser.add_argument("--guidance_model_path", type=str, default="", 419 | help="Load checkpointed guidance model." 420 | ) 421 | parser.add_argument("--num_icl_pairs", type=int, default=0, help="Default: 0") 422 | parser.add_argument("--g_threshold", type=float, default=0.5, help="Default: 0.5") 423 | parser.add_argument("--max_multistep_len", type=int, default=0, help="Default: 8") 424 | parser.add_argument("--full_guide_topk_in", type=int, default=0, help="Default: 40") 425 | parser.add_argument("--full_guide_topk_ex", type=int, default=0, help="Default: 20") 426 | parser.add_argument("--discrete_max_length", type=int, default=100, 427 | help="Default: 100. The number of tokens generated for discrete guidance." 
428 | ) 429 | parser.add_argument("--discrete_guidance_num_beams", type=int, default=1) 430 | parser.add_argument("--discrete_guidance_num_beam_groups", type=int, default=1) 431 | parser.add_argument("--discrete_guidance_do_sample", type=bool, default=False) 432 | parser.add_argument("--discrete_guidance_top_k", type=int, default=None) 433 | parser.add_argument("--discrete_guidance_top_p", type=float, default=None) 434 | parser.add_argument("--discrete_guidance_temperature", type=float, default=None) 435 | parser.add_argument( 436 | "--discrete_guidance_num_return_sequences", type=int, default=None 437 | ) 438 | parser.add_argument("--discrete_guidance_diversity_penalty", type=float, default=None) 439 | parser.add_argument( 440 | "--discrete_guidance_instruct2guide_model_dir", type=str, default=None, 441 | help="Directory of tuned prefixes for topic and constraint." 442 | ) 443 | parser.add_argument("--discrete_guidance_use_trie", action="store_true") 444 | 445 | if config is not None and isinstance(config, dict): 446 | parser.set_defaults(**config) 447 | args = parser.parse_args() 448 | return args 449 | 450 | 451 | if __name__ == '__main__': 452 | """ 453 | python -m src.main --name RUN_NAME 454 | """ 455 | args = get_argparse() 456 | 457 | exp_manager = ExperimentManager( 458 | name=args.name, 459 | run_dir=args.run_dir, 460 | override=True, 461 | num_runs=1, 462 | metric_tracker=MetricTracker(), 463 | ) 464 | run(args, run_id=0, exp_manager=exp_manager) 465 | 466 | -------------------------------------------------------------------------------- /src/metric_tracker.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | 4 | 5 | class MetricTracker: 6 | def __init__(self): 7 | self._state = defaultdict(lambda: defaultdict(list)) 8 | 9 | @property 10 | def state(self): 11 | return {k: dict(v) for k, v in self._state.items()} 12 | 13 | def add(self, metric: dict, run_id: 
int): 14 | invalid_value_types = {str, list, dict} 15 | for name, value in metric.items(): 16 | if type(value) in invalid_value_types: 17 | continue 18 | self._state[run_id][name].append(value) 19 | 20 | def compute(self, metric_name=None, run_id=None, ignore=('id', 'global_step')): 21 | """ 22 | `metric_name` and `run_id` are expected to be passed in simultaneously, 23 | or they should both be `None`. 24 | """ 25 | stats = defaultdict(dict) 26 | for run_id_, metrics in self._state.items(): 27 | for name, values in metrics.items(): 28 | if metric_name is not None and name != metric_name: 29 | continue 30 | if name in ignore: 31 | continue 32 | avg = sum(values) / len(values) 33 | stats[run_id_][name] = avg 34 | stats[run_id_][name + '_num_datapoints'] = len(values) 35 | stats = dict(stats) 36 | 37 | if run_id is not None and metric_name is not None: 38 | stats = stats[run_id][metric_name] 39 | return stats 40 | 41 | def aggregate(self, avg_runs=True): 42 | per_run_avg_stats = self.compute() 43 | 44 | if not avg_runs: 45 | return per_run_avg_stats 46 | else: 47 | run_avg_stats = defaultdict(list) 48 | metrics = list(per_run_avg_stats.values()) 49 | for metric in metrics: 50 | for name, value in metric.items(): 51 | run_avg_stats[name].append(value) 52 | 53 | # Calculate mean/std. 54 | run_avg_stats = { 55 | k: { 56 | 'mean': float(np.mean(vs)), 57 | 'std': float(np.std(vs)), 58 | 'num_runs': len(vs), 59 | 'num_datapoints': run_avg_stats[k + '_num_datapoints'], 60 | } 61 | for k, vs in run_avg_stats.items() if not k.endswith('_num_datapoints') 62 | } 63 | return run_avg_stats 64 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics. 
3 | """ 4 | import re 5 | from collections import defaultdict 6 | 7 | import spacy 8 | from nltk import ngrams 9 | from nltk.tokenize import word_tokenize 10 | from nltk.translate.bleu_score import sentence_bleu 11 | 12 | from src.guidance_utils import remove_parenthesis 13 | 14 | spacy_parser = spacy.load("en_core_web_lg") 15 | 16 | 17 | def aggregate_metrics(metrics, name): 18 | """ 19 | Aggregate metrics for one run. 20 | """ 21 | run_stats = defaultdict(list) 22 | run_stats['name'] = name 23 | 24 | for metric in metrics: 25 | for k, v in metric.items(): 26 | run_stats[k].append(v) 27 | num_datapoints = [] 28 | for k, vs in run_stats.items(): 29 | if isinstance(vs, list): 30 | run_stats[k] = sum(vs) / len(vs) 31 | num_datapoints.append(len(vs)) 32 | # assert not num_datapoints or len(set(num_datapoints)) == 1, num_datapoints 33 | run_stats['num_datapoints'] = num_datapoints[0] if num_datapoints else 0 34 | 35 | return dict(run_stats) 36 | 37 | 38 | def compute_prediction_metrics(prediction, hierarchy, data_mode): 39 | datapoint = prediction['datapoint'] 40 | context = datapoint['context'] 41 | topic = datapoint['topic'] 42 | constraint = datapoint['constraint'] 43 | generated_text = prediction['generated_text'] 44 | 45 | if data_mode == 'wordnet': 46 | generated_text = generated_text.lower() 47 | forbidden_words = hierarchy[constraint] + [constraint] 48 | forbidden_words = set( 49 | forbidden_words + [w + 's' for w in forbidden_words] 50 | ) 51 | 52 | topical_words = hierarchy[topic] 53 | topical_words = set(topical_words + [w + 's' for w in topical_words]) 54 | 55 | violated = any(forbidden_word in generated_text for forbidden_word in forbidden_words) 56 | on_topic = any(topical_word in generated_text for topical_word in topical_words) 57 | 58 | topical_word_regex = '|'.join(list(topical_words)) 59 | pattern = re.compile(rf'({topical_word_regex})') 60 | topical_word_matches = pattern.findall(generated_text) 61 | 62 | #if len(set(topical_word_matches) & 
set(prediction['datapoint']['current'])) > 0: 63 | # on_topic = False 64 | #print(prediction['datapoint']['current']) 65 | #print(topical_word_matches) 66 | #print(topic) 67 | #print(constraint) 68 | #print('---') 69 | 70 | extracted = None 71 | elif data_mode == 'wikidata': 72 | 73 | gen_text_parsed = spacy_parser(generated_text) 74 | parsed_names = set([ 75 | ent.text.lower() for ent in gen_text_parsed.ents 76 | if ent.label_ == 'PERSON' 77 | ]) 78 | forbidden_words = hierarchy[constraint] 79 | forbidden_words = [remove_parenthesis(q_title) for q_id, q_title in forbidden_words] 80 | violated = False 81 | violated_word = None 82 | for w in forbidden_words: 83 | if w in generated_text: 84 | violated = True 85 | violated_word = w 86 | break 87 | #violated = len(parsed_names & forbidden_words) > 0 88 | 89 | topical_words = hierarchy[topic] 90 | topical_words = [remove_parenthesis(q_title) for q_id, q_title in topical_words] 91 | on_topic = False 92 | on_topic_word = None 93 | for w in topical_words: 94 | if w in generated_text: 95 | on_topic = True 96 | on_topic_word = w 97 | break 98 | #on_topic = len(parsed_names & topical_words) > 0 99 | #extracted = parsed_names 100 | extracted = dict(violated_word=violated_word, on_topic_word=on_topic_word) 101 | else: 102 | raise ValueError(f'Data mode {data_mode} not recognized.') 103 | on_topic_not_violated = (not violated) and on_topic 104 | copying_bleu_score = copying_bleu(context, generated_text) 105 | repetition_scores = get_repetition_scores(generated_text.split()) 106 | prediction_metric = dict( 107 | violated=violated, 108 | on_topic=on_topic, 109 | on_topic_not_violated=on_topic_not_violated, 110 | copying_bleu_score=copying_bleu_score, 111 | ) 112 | prediction_metric = {**prediction_metric, **repetition_scores} 113 | 114 | return dict( 115 | prediction_metric=prediction_metric, 116 | extracted=extracted, 117 | ) 118 | 119 | 120 | def copying_bleu(context, generated_text): 121 | prediction = 
word_tokenize(generated_text.strip('==').strip()) 122 | max_score = float('-inf') 123 | for context_sent in context.strip().split('\n'): 124 | references = [word_tokenize(context_sent.strip('==').strip())] 125 | score = sentence_bleu(references, prediction) 126 | if score > max_score: 127 | max_score = score 128 | return max_score 129 | 130 | 131 | def get_repetition_scores(tokens): 132 | metric = defaultdict(float) 133 | for n in range(1, 5): 134 | ngs = [ng for ng in ngrams(tokens, n)] 135 | unique_ngs = set(ngs) 136 | if not ngs: 137 | metric[f'seq-rep-{n}'] = 0.0 138 | else: 139 | metric[f'seq-rep-{n}'] = 1.0 - (len(unique_ngs) / len(ngs)) 140 | return dict(metric) 141 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import yaml 3 | import random 4 | from pathlib import Path 5 | from ast import literal_eval 6 | from argparse import Namespace 7 | from collections import defaultdict 8 | from os.path import dirname, abspath, join 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from transformers import ( 13 | AutoTokenizer, 14 | AutoModelForCausalLM, 15 | StoppingCriteria, 16 | StoppingCriteriaList, 17 | ) 18 | 19 | random.seed(0) 20 | 21 | 22 | def top_p_sampling(logits, top_p): 23 | """ 24 | logits: (1, vocab_size) 25 | Code taken from: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 26 | """ 27 | logits = logits.squeeze(0) 28 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 29 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 30 | #print(cumulative_probs.tolist()[100]) 31 | #print(cumulative_probs.tolist()[200]) 32 | #print(cumulative_probs.tolist()[500]) 33 | #print(cumulative_probs.tolist()[1000]) 34 | #print(cumulative_probs.tolist()[2000]) 35 | #print(cumulative_probs.tolist()[5000]) 36 | #print(cumulative_probs.tolist()[10000]) 37 | 
#print(cumulative_probs.tolist()[20000]) 38 | #print(cumulative_probs.tolist()[-1]) 39 | sorted_indices_to_remove = cumulative_probs > top_p 40 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 41 | sorted_indices_to_remove[..., 0] = 0 42 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 43 | logits[indices_to_remove] = float('-inf') 44 | return logits 45 | 46 | 47 | class EndOfFunctionCriteria(StoppingCriteria): 48 | """ 49 | Code taken from: https://github.com/huggingface/transformers/blob/1d651868d64e8f54f7bf6b687fbcdac832039334/examples/research_projects/codeparrot/scripts/human_eval.py#L25 50 | Custom `StoppingCriteria` which checks if all generated functions in 51 | the batch are completed. 52 | """ 53 | 54 | def __init__(self, start_length, eof_strings, tokenizer): 55 | self.start_length = start_length 56 | self.eof_strings = eof_strings 57 | self.tokenizer = tokenizer 58 | 59 | def __call__(self, input_ids, scores, **kwargs): 60 | """ 61 | Returns true if all generated sequences contain any of the 62 | end-of-function strings. 63 | """ 64 | decoded_generations = \ 65 | self.tokenizer.batch_decode(input_ids[:, self.start_length :]) 66 | done = [] 67 | for decoded_generation in decoded_generations: 68 | done.append(any([ 69 | stop_string in decoded_generation 70 | for stop_string in self.eof_strings 71 | ])) 72 | return all(done) 73 | 74 | 75 | def to_namespace(config): 76 | return Namespace(**config) 77 | 78 | 79 | def prepare_args(args): 80 | if args.model_type in ('classification', 'prompt-tune', 'fine-tune'): 81 | args.metric_name = 'Accuracy' 82 | elif args.model_type == 'regression': 83 | args.metric_name = 'Class=1 Precision' 84 | else: 85 | raise ValueError(f'Unknown model_type: {args.model_type}') 86 | return args 87 | 88 | 89 | def update_args(args): 90 | """ 91 | Full in null args for later-trained model. 
92 | """ 93 | if not hasattr(args, 'loss_type'): 94 | setattr(args, 'loss_type', 'ce') 95 | return args 96 | 97 | 98 | def get_wikidata_p_text(p_name, p_value): 99 | if p_name == 'place_of_birth': 100 | fill_text = f'people who were born in {p_value}' 101 | elif p_name == 'place_of_death': 102 | fill_text = f'people who died in {p_value}' 103 | elif p_name == 'occupation': 104 | fill_text = p_value 105 | elif p_name == 'country_of_citizenship': 106 | fill_text = f'people who are citizens of {p_value}' 107 | elif p_name == 'academic_degree': 108 | fill_text = f'people who hold a degree in {p_value}' 109 | elif p_name == 'educated_at': 110 | fill_text = f'people who had their education at {p_value}' 111 | return fill_text 112 | 113 | 114 | def cleanup_gen_text(text): 115 | """ 116 | Cleanup items 117 | - drop the last unfinished sentence (should finish with `==`) 118 | 119 | The normal case: `len(text.split('\n')) == 3`. 120 | """ 121 | sents = text.strip().split('\n') 122 | return sents[0] 123 | # print(sents) 124 | # if len(sents) > 2 and not sents[-1].endswith('=='): 125 | # sents = sents[:-1] 126 | # return '\n'.join(sents) 127 | 128 | 129 | def reformat_text(text): 130 | text_tokens = text.split(' ') 131 | text = ' '.join(text_tokens) 132 | return text 133 | 134 | 135 | def get_hierarchy_path_to_children(path): 136 | hierarchy_path_to_children = dict() 137 | with open(path) as f: 138 | for line in f: 139 | obj = json.loads(line.strip()) 140 | hierarchy_path = obj['hierarchy_path'] 141 | children = obj['children'] 142 | hierarchy_path_to_children[tuple(hierarchy_path)] = children 143 | return hierarchy_path_to_children 144 | 145 | 146 | def get_hierarchy(path=None): 147 | if path is None: 148 | path = '/path/to/your/topic_to_leafs.json' 149 | with open(path) as f: 150 | hierarchy_ = json.load(f) 151 | hierarchy = dict() 152 | for topic, leafs in hierarchy_.items(): 153 | new_topic = topic.replace('_', ' ') 154 | new_leafs = [l.replace('_', ' ') for l in leafs] 
155 | hierarchy[new_topic] = new_leafs 156 | return hierarchy 157 | 158 | 159 | class WikidataHierarchy: 160 | def __init__(self, q_to_p, p_to_q): 161 | self.q_to_p = q_to_p 162 | self.p_to_q = p_to_q 163 | 164 | def __contains__(self, key): 165 | if self.__getitem__(key): 166 | return True 167 | else: 168 | return False 169 | 170 | def __getitem__(self, p): 171 | """ 172 | For a given `p` (i.e., topic or constraint), return a list of all Q's 173 | that have this `p`. 174 | Return list like [(Q7339, 'Margot Frank'), ...]. 175 | """ 176 | if isinstance(p, list): 177 | p = tuple(p) 178 | 179 | return self.p_to_q.get(p, []) 180 | 181 | 182 | def get_wikidata_hierarchy(): 183 | from evaluation.scripts.build_wikidata_dataset import ( 184 | load_ranked_properties, 185 | load_all_entities 186 | ) 187 | WIKIDATA_PATH = Path('path/to/your/qid2title.json') 188 | q_to_p, p_to_q = load_all_entities(WIKIDATA_PATH) 189 | hierarchy = WikidataHierarchy(q_to_p, p_to_q) 190 | return hierarchy 191 | 192 | 193 | def normalize_datapoint(datapoint, args, hierarchy): 194 | if args.data == 'wordnet': 195 | return datapoint 196 | elif args.data == 'wikidata': 197 | normalized = dict() 198 | if 'example_id' in datapoint: 199 | normalized['id'] = datapoint['example_id'] 200 | elif 'id' in datapoint: 201 | normalized['id'] = datapoint['id'] 202 | 203 | if 'text' in datapoint: 204 | normalized['context'] = datapoint['text'] 205 | elif 'context' in datapoint: 206 | normalized['context'] = datapoint['context'] 207 | 208 | normalized['topic'] = ( 209 | tuple(datapoint['p']) 210 | if 'topic' not in datapoint 211 | else datapoint['topic'] 212 | ) 213 | normalized['gen_qs'] = datapoint['gen_qs'] 214 | normalized['gen_text'] = datapoint['gen_text'] 215 | 216 | if 'constraint' in datapoint: 217 | normalized['constraint'] = datapoint['constraint'] 218 | else: 219 | constraint_candidates = defaultdict(int) 220 | for gen_q in datapoint['gen_qs']: 221 | gen_ps = hierarchy.q_to_p[tuple(gen_q)] 222 | for 
p_name, p_values in gen_ps.items(): 223 | for p_value in p_values: 224 | constraint_candidates[(p_name, p_value)] += 1 225 | 226 | SKIP = {} 227 | selected_constraint = None 228 | for constraint in constraint_candidates.keys(): 229 | if normalized['topic'][0] != constraint[0] and constraint[0] not in SKIP: 230 | selected_constraint = constraint 231 | break 232 | 233 | normalized['constraint'] = selected_constraint 234 | return normalized 235 | else: 236 | raise ValueError(f'Data arg {args.data} not recognized.') 237 | 238 | 239 | def get_data(args, randomize=False, hierarchy_only=False): 240 | if args.data == 'wordnet': 241 | hierarchy = get_hierarchy(args.hierarchy_path) 242 | elif args.data == 'wikidata': 243 | # Return the modifier to update the datapoint. 244 | hierarchy = get_wikidata_hierarchy() 245 | else: 246 | raise ValueError(f'Data arg {args.data} not recognized.') 247 | 248 | if hierarchy_only: 249 | return hierarchy 250 | 251 | datasets = defaultdict(list) 252 | data_paths = { 253 | 'train': args.train_path, 254 | 'dev': args.dev_path, 255 | 'test': args.test_path, 256 | } 257 | for dataset_split, dataset_path in data_paths.items(): 258 | if dataset_path is not None: 259 | with open(dataset_path) as f: 260 | for line in f: 261 | datapoint = json.loads(line.strip()) 262 | datapoint = normalize_datapoint(datapoint, args, hierarchy) 263 | 264 | topic = datapoint['topic'] 265 | constraint = datapoint['constraint'] 266 | if ( 267 | args.data == 'wordnet' and 268 | ( 269 | constraint not in hierarchy[topic] or 270 | constraint == topic or 271 | constraint not in hierarchy 272 | ) 273 | ): 274 | # skipping bad data 275 | continue 276 | 277 | if args.data == 'wikidata' and constraint is None: 278 | continue 279 | 280 | datasets[dataset_split].append(datapoint) 281 | if randomize: 282 | random.shuffle(datasets[dataset_split]) 283 | 284 | num_datapoints = getattr(args, dataset_split + '_num_datapoints', 100000000) 285 | datasets[dataset_split] = 
datasets[dataset_split][:num_datapoints] 286 | return datasets, hierarchy, None 287 | 288 | 289 | def get_model_class(model_name): 290 | from ci.train_guidance_model import ( 291 | MLPRegressor, 292 | MLPClassifier, 293 | PromptTuningModel, 294 | FineTuningModel, 295 | ) 296 | NAME_TO_CLASS = { 297 | 'MLPRegressor': MLPRegressor, 298 | 'MLPClassifier': MLPClassifier, 299 | 'PromptTuningModel': PromptTuningModel, 300 | 'FineTuningModel': FineTuningModel, 301 | } 302 | return NAME_TO_CLASS[model_name] 303 | 304 | 305 | def put_model_on_gpus(model, model_name, num_devices): 306 | if model_name == 'gpt2-xl': 307 | if num_devices == 8: 308 | device_map = { 309 | 0: list(range(0, 6)), 310 | 1: list(range(6, 12)), 311 | 2: list(range(12, 18)), 312 | 3: list(range(18, 24)), 313 | 4: list(range(24, 30)), 314 | 5: list(range(30, 36)), 315 | 6: list(range(36, 42)), 316 | 7: list(range(42, 48)), 317 | } 318 | model.parallelize(device_map) 319 | elif num_devices == 4: 320 | device_map = { 321 | 0: list(range(0, 12)), 322 | 1: list(range(12, 24)), 323 | 2: list(range(24, 36)), 324 | 3: list(range(36, 48)), 325 | } 326 | model.parallelize(device_map) 327 | elif num_devices == 2: 328 | device_map = { 329 | 0: list(range(0, 24)), 330 | 1: list(range(24, 48)), 331 | } 332 | model.parallelize(device_map) 333 | elif num_devices == 1: 334 | model.cuda() 335 | else: 336 | raise ValueError(f'Num devices ({num_devices}) not supported for gpt2-xl.') 337 | else: 338 | model.cuda() 339 | return model 340 | 341 | 342 | def get_lm(args, num_devices=None, load_mode=None): 343 | if load_mode is None: 344 | model_name = args.model_name 345 | elif load_mode == 'guidance': 346 | model_name = args.guidance_model_name 347 | else: 348 | raise ValueError(f'Unknown load_mode: {load_mode}') 349 | model = AutoModelForCausalLM.from_pretrained( 350 | model_name, 351 | local_files_only=True, 352 | ) 353 | if num_devices is not None: 354 | model = put_model_on_gpus(model, model_name, num_devices) 355 | for 
param in model.parameters(): 356 | param.requires_grad = False 357 | model.eval() 358 | print(f'Model "{model_name}" loaded.') 359 | return model 360 | 361 | 362 | def get_tokenizer(args): 363 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False) 364 | tokenizer.pad_token = tokenizer.eos_token 365 | return tokenizer 366 | 367 | 368 | def get_node_children_in_text(text, node, hierarchy): 369 | """ 370 | Check if `text` contains `node`'s children in the `hierarchy`. 371 | """ 372 | nodes = set(hierarchy[node] + [node]) 373 | return list(set([node for node in nodes if node in text])) 374 | 375 | 376 | def save_ckpt(model, args, run_dir, **kwargs): 377 | stats = dict() 378 | stats['current_epoch'] = kwargs.get('epoch', None) 379 | stats['global_step'] = kwargs.get('global_step', None) 380 | stats['best_metric'] = kwargs.get('best_metric', None) 381 | 382 | with open(run_dir / 'stats.yaml', 'w') as f: 383 | yaml.dump(stats, f) 384 | 385 | args.model_class_name = model.__class__.__name__ 386 | args.ckpt_path = run_dir / 'checkpoint.pt' 387 | checkpoint = dict() 388 | checkpoint['args'] = vars(args) 389 | checkpoint['states'] = model.state_dict() 390 | 391 | torch.save(checkpoint, args.ckpt_path) 392 | print(f"Model saved at: {args.ckpt_path}") 393 | 394 | 395 | def load_ckpt(load_path): 396 | checkpoint = torch.load(load_path, map_location='cpu') 397 | args = Namespace(**checkpoint['args']) 398 | args = update_args(args) 399 | states = checkpoint['states'] 400 | model_class = get_model_class(args.model_class_name) 401 | model = model_class(args) 402 | model.load_state_dict(states) 403 | model.eval() 404 | print('Model loaded from:', load_path) 405 | return model, args 406 | --------------------------------------------------------------------------------
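The nucleus filtering rule that `top_p_sampling` in `src/utils.py` applies with PyTorch can be illustrated without any dependencies. The following is a minimal pure-Python sketch, not code from the repository: the name `top_p_filter` and the list-based interface are assumptions made for illustration. It mirrors the rule of removing a token when the probability mass of the higher-ranked tokens already exceeds `top_p`, while always keeping the top-ranked token.

```python
import math


def top_p_filter(logits, top_p):
    """Return a copy of `logits` with tokens outside the nucleus set to -inf.

    Hypothetical helper for illustration; `logits` is a plain list of floats.
    """
    # Rank token indices by logit, highest first.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    # Softmax probabilities, shifted by the max logit for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)

    filtered = list(logits)
    mass_before = 0.0  # probability mass of all higher-ranked tokens
    for rank, idx in enumerate(order):
        # Remove a token only if the tokens ranked above it already exceed
        # `top_p`; the top-ranked token (rank 0) is always kept.
        if rank > 0 and mass_before > top_p:
            filtered[idx] = float('-inf')
        mass_before += exps[idx] / total
    return filtered


example_logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
example_filtered = top_p_filter(example_logits, top_p=0.7)
```

With `top_p=0.7`, the first two tokens survive (the second one is the token that crosses the threshold) and the last two are masked. Unlike the PyTorch version, which writes `-inf` into the tensor it is given, this sketch returns a new list.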