├── README.md
├── assets
│   ├── cognacgen.png
│   └── teaser.png
├── requirements.txt
└── src
    ├── __init__.py
    ├── data_utils.py
    ├── diverse_instructions.csv
    ├── diverse_instructions.py
    ├── experiment.py
    ├── guidance_models.py
    ├── guidance_utils.py
    ├── guide.py
    ├── guide_with_offset.py
    ├── lm_scorer.py
    ├── main.py
    ├── metric_tracker.py
    ├── metrics.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Cognac
2 | Repo for paper: [Cognac: Controllable Text Generation with Language Constraints](https://arxiv.org/abs/2212.10466)
3 |
4 |
5 | ## Overview
6 |
7 |
8 |
9 |
10 | We propose the Cognac task to stress-test LMs' ability to follow constraints.
11 | The green highlight specifies the topic to be covered, and the red highlight specifies the constraint to conform to. GPT-3 generates a continuation that mentions a politician, thus violating the constraint. Our method, CognacGen, generates a continuation that satisfies both the topic requirement and the constraint.
12 |
13 | ## Setup
14 | We use Python version 3.8. Install the dependencies with pip:
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 |
19 | Download necessary resources:
20 | ```bash
21 | python -m spacy download en_core_web_lg
22 | python -m nltk.downloader wordnet
23 | ```
24 |
25 | ## Cognac Benchmark
26 |
27 | ### WordNet
28 | Download the WordNet data [here](https://drive.google.com/file/d/17mpi7fufaKVEvGdNMAsY_ZBit1FwER7z/view?usp=drive_link). The folder contains files `train.jsonl`, `dev.jsonl`, and `test.jsonl` that include instances of instructions with topics and constraints. The file `topic_to_leafs.json` contains the WordNet hierarchy (used to verify if the generation is conformant). The data is loaded in the code [here](https://github.com/princeton-nlp/Cognac/blob/main/src/utils.py#L239).
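
Each line in these files is one JSON instance. As a rough illustration (the field values below are invented; the field names follow how `src/data_utils.py` and `src/diverse_instructions.py` read the data):

```python
# Illustrative datapoint only: the values are made up, and the real schema may
# contain additional fields (the code also reads e.g. a 'parents' field).
datapoint = {
    "id": 0,
    "topic": "animal",       # category the continuation should stay within
    "constraint": "dog",     # category the continuation must not mention
    "context": "== ... ==",  # demonstration text blocks wrapped by the instruction
}
```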
29 |
30 | ### Wikidata
31 | Coming soon...
32 |
33 | ## CognacGen
34 |
35 |
36 |
37 | The image above shows the step-by-step procedure for CognacGen to handle natural language instructions.
38 |
39 | Stage 1: The LM generates a list of guidance examples from queries that specify the topic and constraint. During self-guidance distillation, the topic and constraint prefixes are tuned using the guidance examples as targets and the instruction with demonstrations as input.
40 |
41 | Stage 2: The guidance model (blue LM & the tuned prefixes) generates guidance examples from the test instance. The guidance examples are used to construct trie trees for both the topic (green) and the constraint (red). The generation (blue) LM's next-token probability is modified by the tries, as sketched below.
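
A minimal sketch of the trie mechanism, assuming `encode` is a tokenizer's encode function (illustrative only; the actual implementation is the `Trie` class in `src/guidance_models.py`, which also tracks per-example decoding state):

```python
# Minimal sketch of trie-based guidance over token ids (illustrative only).
from collections import defaultdict

def build_trie(guidance_examples, encode):
    """Map each token-id prefix of a guidance example to the set of allowed next ids."""
    next_ids = defaultdict(set)
    for phrase in guidance_examples:
        ids = encode(' ' + phrase)
        for i in range(len(ids)):
            next_ids[tuple(ids[:i])].add(ids[i])
    return next_ids

def allowed_next(next_ids, suffix):
    """Token ids that extend a partially matched guidance example,
    falling back to the start tokens of all examples."""
    return next_ids.get(tuple(suffix)) or next_ids[()]
```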
42 |
43 | ### Overall Structure
44 | The main run script is `main.py`.
45 | Some important hyperparameters are described below:
46 | - `eval_version`: determines whether a control code or a natural language instruction is used as context
47 | - `guidance`: the combination of guidance signals to use. `in` applies topic inclusion, `ex` applies constraint exclusion, and `wd` (weighted decoding) is used in CognacGen. Other combinations are also supported.
48 | - `guidance_model_type`: guidance type to use for constraint exclusion; `discrete` is "Textual Guidance", `full` is "Top-K Token", and `binary` is "Binary Verifier" described in the paper
49 | - `guidance_model_type_2`: guidance type to use for topic inclusion
50 | - `alpha`: strength of inclusion applied to the logits during inference
51 | - `beta`: strength of exclusion applied to the logits during inference (a sketch of how `alpha` and `beta` act on the logits follows this list)
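
The sketch below illustrates the roles of `alpha` and `beta`; `topic_ids` and `constraint_ids` are assumed to be the token sets proposed by the topic and constraint guidance, and the exact update in the repository's decoding loop may differ:

```python
import torch

def guided_logits(logits, topic_ids, constraint_ids, alpha=100.0, beta=5.0):
    """Illustrative weighted decoding: boost topic tokens, penalize constraint tokens."""
    out = logits.clone()
    if topic_ids:
        out[:, list(topic_ids)] += alpha       # inclusion strength
    if constraint_ids:
        out[:, list(constraint_ids)] -= beta   # exclusion strength
    return out
```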
52 |
53 | ### Run Control Code Setting on WordNet
54 |
55 | ```bash
56 | python -m src.main \
57 | --name "your_run_name" \
58 | --dataset_split "dev" \
59 | --dev_path "./data/wordnet/dev.jsonl" \
60 | --hierarchy_path "./data/wordnet/topic_to_leafs.json" \
61 | --eval_version -2 \
62 | --guidance "wd+ex+in" \
63 | --guidance_model_name "gpt2-xl" \
64 | --guidance_model_type "discrete" \
65 | --guidance_model_type_2 "discrete" \
66 | --discrete_max_length 200 \
67 | --discrete_guidance_use_trie \
68 | --alpha 100.0 \
69 | --beta 5.0 \
70 | --top_p 0.92 \
71 | --temperature 0.7
72 | ```
73 | Note that the control code setting applies only stage 2 and does not fine-tune the guidance model.
74 |
75 | ### Run Natural Language Instruction Setting on WordNet
76 |
77 | ```bash
78 | python -m src.main \
79 | --name "your_run_name" \
80 | --discrete_guidance_instruct2guide_model_dir "path/to/your/prefix/tuned/model/folder" \
81 | --dataset_split "dev" \
82 | --dev_path "./data/wordnet/dev.jsonl" \
83 | --hierarchy_path "./data/wordnet/topic_to_leafs.json" \
84 | --eval_version -1 \
85 | --guidance "wd+ex+in" \
86 | --guidance_model_name "gpt2-xl" \
87 | --guidance_model_type "discrete" \
88 | --guidance_model_type_2 "discrete" \
89 | --discrete_max_length 200 \
90 | --discrete_guidance_use_trie \
91 | --alpha 100.0 \
92 | --beta 5.0
93 | ```
94 | The prefix-tuned model in stage 1 can be downloaded [here](https://drive.google.com/file/d/1gTxwetdyK3X-IkUnw0FFdfPw1KyLI5DK/view?usp=drive_link).
95 |
96 |
97 | ### Run Control Code Setting on Wikidata
98 | Coming soon...
99 |
100 | ### Run Natural Language Instruction Setting on Wikidata
101 | Coming soon...
102 |
103 | ## Prefix-Tuning the Guidance Model
104 | Coming soon...
105 |
106 | ## Questions
107 |
108 | Please contact Howard Chen (`howardchen@cs.princeton.edu`) if you have any questions.
109 |
110 | ## Citation
111 |
112 | ```bibtex
113 | @inproceedings{chen2022cognac,
114 | title={{Cognac}: Controllable Text Generation with Language Constraints},
115 | author={Chen, Howard and Li, Huihan and Chen, Danqi and Narasimhan, Karthik},
116 | booktitle={arXiv},
117 | year={2022}
118 | }
119 | ```
120 |
--------------------------------------------------------------------------------
/assets/cognacgen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/Cognac/c1b304b884f76b667d2f76b325ecbadfb2d1c90c/assets/cognacgen.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/Cognac/c1b304b884f76b667d2f76b325ecbadfb2d1c90c/assets/teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml
2 | torch
3 | alive-progress
4 | transformers
5 | spacy
6 | nltk
7 | numpy
8 | rich
9 | tqdm
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/Cognac/c1b304b884f76b667d2f76b325ecbadfb2d1c90c/src/__init__.py
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | """
3 | import json
4 | import random
5 | from collections import defaultdict
6 |
7 | from tqdm import tqdm
8 | from rich import print
9 |
10 | import torch
11 | from nltk.tokenize import sent_tokenize
12 |
13 |
14 | def get_hierarchy(path):
15 | with open(path) as f:
16 | hierarchy_ = json.load(f)
17 | hierarchy = dict()
18 | for topic, leafs in hierarchy_.items():
19 | new_topic = topic.replace('_', ' ')
20 | new_leafs = [l.replace('_', ' ') for l in leafs]
21 | hierarchy[new_topic] = new_leafs
22 | return hierarchy
23 |
24 |
25 | def load_datasets(data_paths, hierarchy, args):
26 | def process(datapoints, data_path):
27 |         if 'pairs' in data_path:
28 | pass
29 | else:
30 | datapoints = [
31 | d for d in datapoints
32 | if hierarchy is not None and
33 | d['constraint'] in hierarchy[d['topic']] and
34 | d['constraint'] != d['topic']
35 | ]
36 | return datapoints
37 |
38 | datasets = dict()
39 | for dataset_name, dataset_path in data_paths.items():
40 | if dataset_path is not None:
41 | datapoints = []
42 | with open(dataset_path) as f:
43 | for line in f:
44 | datapoint = json.loads(line.strip())
45 | datapoints.append(datapoint)
46 | datapoints = process(datapoints, dataset_path)
47 |
48 | if args.randomize_dataset:
49 | random.seed(1)
50 | random.shuffle(datapoints)
51 | datasets[dataset_name] = datapoints
52 | return datasets
53 |
54 |
55 | def get_guidance_data(args):
56 | if args.hierarchy_path is not None:
57 | hierarchy = get_hierarchy(args.hierarchy_path)
58 | else:
59 | hierarchy = None
60 |
61 | data_paths = {
62 | 'train': args.train_path,
63 | 'dev': args.dev_path,
64 | 'test': args.test_path,
65 | }
66 | datasets = load_datasets(data_paths, hierarchy, args)
67 |
68 | if args.wiki_gold is not None:
69 | gold = dict()
70 | with open('data/wordnet_wiki_clean.jsonl') as f:
71 | for line in f:
72 | obj = json.loads(line.strip())
73 | gold[obj['node'].replace('_', ' ')] = obj
74 | else:
75 | gold = None
76 |
77 | print(f'Dataset loaded.')
78 | return datasets, hierarchy, gold
79 |
80 |
81 | def get_bow(words, tokenizer):
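    """
    Build a 1 x vocab-size bag-of-words indicator vector: entries are 1.0 for
    every token id obtained by encoding each word with a leading space.
    """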
82 | inds = []
83 | for word in words:
84 | inds += tokenizer.encode(word, add_prefix_space=True)
85 | inds = list(set(inds))
86 |
87 | inds = torch.tensor(inds)
88 | bow = torch.zeros(1, len(tokenizer))
89 | bow[:, inds] = 1.0
90 |
91 | return bow
92 |
93 |
94 | def regressor_batcher(datapoints, batch_size, max_length, **kwargs):
95 | hierarchy = kwargs.get('hierarchy', None)
96 | tokenizer = kwargs.get('tokenizer', None)
97 | use_cuda = kwargs.get('use_cuda', False)
98 | num_batches = len(datapoints) // batch_size + 1
99 |
100 | for i in range(num_batches):
101 | batch = datapoints[i * batch_size:(i + 1) * batch_size]
102 |
103 | if not batch:
104 | continue
105 |
106 | topic_texts = [f"Talk about {b['topic']}" for b in batch]
107 | topic_inputs = tokenizer.batch_encode_plus(
108 | topic_texts,
109 | padding='longest'
110 | )
111 | constraint_texts = [f"Don't talk about {b['constraint']}" for b in batch]
112 | constraint_inputs = tokenizer.batch_encode_plus(
113 | constraint_texts,
114 | padding='longest'
115 | )
116 |
117 | topics = [b['topic'] for b in batch]
118 | constraints = [b['constraint'] for b in batch]
119 | topic_targets = torch.cat([
120 | get_bow(hierarchy.get(topic, []) + [topic], tokenizer)
121 | for topic in topics
122 | ], dim=0)
123 | constraint_targets = torch.cat([
124 | get_bow(hierarchy.get(constraint, []) + [constraint], tokenizer)
125 | for constraint in constraints
126 | ], dim=0)
127 |
128 | if use_cuda:
129 | topic_targets = topic_targets.cuda()
130 | constraint_targets = constraint_targets.cuda()
131 |
132 | yield dict(
133 | ids=[b['id'] for b in batch],
134 | topic_inputs=topic_inputs,
135 | constraint_inputs=constraint_inputs,
136 | topic_targets=topic_targets,
137 | constraint_targets=constraint_targets,
138 | topics=topics,
139 | constraints=constraints,
140 | )
141 |
142 |
143 | def classifier_batcher(datapoints, batch_size, max_length, **kwargs):
144 | hierarchy = kwargs.get('hierarchy', None)
145 | tokenizer = kwargs.get('tokenizer', None)
146 | args = kwargs.get('args', None)
147 | num_batches = len(datapoints) // batch_size + 1
148 |
149 | for i in range(num_batches):
150 | batch = datapoints[i * batch_size:(i + 1) * batch_size]
151 | if not batch:
152 | continue
153 |
154 | if args.model_type == 'classification':
155 | texts = [sent_tokenize(b['text'])[1] for b in batch]
156 | label_texts = [b['label'] for b in batch]
157 | labels = torch.tensor([1 if b['label'] == 'yes' else 0 for b in batch]).long()
158 | elif args.model_type in ('prompt-tune', 'fine-tune'):
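            # GPT-2 token ids for the yes/no answer tokens
            # (the same constants are defined in src/guidance_models.py).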
159 | _YES = 3763
160 | _NO = 645
161 | texts = [b['text'] for b in batch]
162 | label_texts = [b['label'] for b in batch]
163 | labels = torch.tensor([_YES if b['label'] == 'yes' else _NO for b in batch]).long()
164 | else:
165 | raise ValueError(f'Unknown model type: {args.model_type}')
166 | categories = [b['category'] for b in batch]
167 |
168 | encoded = tokenizer.batch_encode_plus(
169 | texts,
170 | max_length=max_length,
171 | padding='max_length',
172 | )
173 | input_ids = torch.tensor(encoded['input_ids']).long()
174 | attention_mask = torch.tensor(encoded['attention_mask']).long()
175 |
176 | input_ids = input_ids.cuda()
177 | attention_mask = attention_mask.cuda()
178 | labels = labels.cuda()
179 |
180 | yield dict(
181 | input_ids=input_ids,
182 | attention_mask=attention_mask,
183 | texts=texts,
184 | categories=categories,
185 | labels=labels,
186 | label_texts=label_texts,
187 | )
188 |
189 |
190 | class Node:
191 | def __init__(self, name, parent, children):
192 | assert isinstance(parent, Node) or parent is None
193 | assert isinstance(children, list)
194 | self.name = name
195 | self.parent = parent
196 | self.children = children
197 |
198 | def __repr__(self):
199 | #parent = self.parent.name if self.parent is not None else None
200 | #children = [c.name for c in self.children]
201 | return f'Node(name={self.name})'
202 |
203 |
204 | class Hierarchy:
205 | def __init__(self):
206 | self.name_to_node = dict()
207 | self.load()
208 |
209 | def get_node(self, name):
210 | return self.name_to_node.get(name, None)
211 |
212 | def get_leafs(self, node):
213 | pass
214 |
215 | def attach_children_to_parent(self, parent_name, children_names):
216 | parent = self.get_node(parent_name)
217 |
218 | if parent is None:
219 | parent = Node(parent_name, None, [])
220 | self.name_to_node[parent_name] = parent
221 |
222 | for child_name in children_names:
223 | child = self.get_node(child_name)
224 | if child is None:
225 | child = Node(child_name, parent, [])
226 | self.name_to_node[child_name] = child
227 | elif child.parent is None:
228 | child.parent = parent
229 | else:
230 | continue
231 | parent.children.append(child)
232 |
233 | def load(self):
234 | with open('data/hierarchy_path_to_children.jsonl') as f:
235 | for line in f:
236 | obj = json.loads(line.strip())
237 | hierarchy_path = [''] + obj['hierarchy_path']
238 | for p, c in zip(hierarchy_path, hierarchy_path[1:]):
239 | self.attach_children_to_parent(p, [c])
240 | parent = hierarchy_path[-1]
241 | children = obj['children']
242 | self.attach_children_to_parent(parent, children)
243 |
244 |
245 | if __name__ == '__main__':
246 | """
247 |     python -m src.data_utils
248 | """
249 | from transformers import AutoTokenizer
250 | hierarchy = get_hierarchy('data/topic_to_leafs.json')
251 | tokenizer = AutoTokenizer.from_pretrained('gpt2-xl', use_fast=False)
252 |
253 | datapoints = []
254 | with open('data/gpt2-xl/wiki_id_ood/sample/train.jsonl') as f:
255 | for line in f:
256 | datapoint = json.loads(line.strip())
257 | datapoints.append(datapoint)
258 | print(len(datapoints))
259 |     datapoints = add_bow_to_datapoint(datapoints, hierarchy, tokenizer)  # NOTE: not defined in this module
260 | print(len(datapoints))
261 |
--------------------------------------------------------------------------------
/src/diverse_instructions.csv:
--------------------------------------------------------------------------------
1 | index,topic,constraint,type,mode
2 | 0,Write down examples of {}.,Continue listing them but do not include examples of {}.,begin+end,train
3 | 1,Below we show examples of {1}.,"Following these examples, keep listing but don’t mention {0}.",begin,train
4 | 2,The above showed examples of {1}.,Now list more examples but don’t mention {0}.,end,train
5 | 3,Generate examples that are under the category of {}.,Now keep generating but exclude anything that's in the category of {}.,begin+end,dev
6 | 4,List out examples of {1}:,"Right after these examples, continue listing but avoid mentioning {0}.",begin,dev
7 | 5,The above are sentences describing {1}.,Now write similar sentences as the above while omitting any mention of {0}.,end,dev
8 | 6,Generate examples about {}:,Now generate more examples but refrain from generating {}.,begin+end,test
9 | 7,"Task: write down three examples of {1}.,",and after the three examples keep writing more but don’t refer to any {0}.,begin,test
10 | 8,The above sentences are examples of {1}.,Your task is to list more sentences but exclude mentions of {0}.,end,test
11 | 9,Make a list of {}.,Do not list any {}.,begin+end,test
12 | 10,Give me a list of {}.,Extend the list which does not pertain to {}.,begin+end,test
13 | 11,Listed below are examples of {1}.,"Keep going down this list, but don't include any {0}.",begin,test
14 | 12,Examples of {1} are shown below.,"List more examples, but don't mention {0} this time.",begin,test
15 | 13,There are examples of {1} below.,Write similar sentences after these that omits mention of {0}.,begin,test
16 | 14,The examples provided were of {1}.,provide additional examples without referring to the constraint {0}.,end,test
17 | 15,The examples given were of different instances of {1}.,give more examples that don't go against the original constraint {0}.,end,test
18 | 16,The examples illustrated were of {1}.,list more examples that don't bring up {0}.,end,test
19 | 17,The aforementioned showed examples of {1}.,Write similar sentences omitting any mention of {0}.,end,test
20 | 18,Create examples for {}.,"Keep generating ideas, but don't include anything that falls under the category of {}.",begin+end,test
21 | 19,Make up examples for {}.,"Keep coming up with ideas, but don't include anything that would fit under {}.",begin+end,test
22 | 20,Come up with examples for {}.,"Keep brainstorming, but don't write down anything that could be classified as {}.",begin+end,test
23 | 21,Provide a list of examples for {1}.,"After these examples, keep going but don't mention {0}.",begin,test
24 | 22,Give a list of examples for {1}.,"After these examples, continue listing but avoid mentioning {0}.",begin,test
25 | 23,Share a list of examples for {1}.,"Keep listing after these examples, but don't mention {0}.",begin,test
26 | 24,A list of examples for {1} is as follows:,"List more examples after these, but don't mention {0}.",begin,test
27 | 25,The sentences above describe {1}.,Write similar sentences without mentioning {0}.,end,test
28 | 26,The foregoing sentences describe {1}.,Write similar sentences while omitting mention of {0}.,end,test
29 | 27,These sentences describe {1}.,Write similar sentences that does not mention {0}.,end,test
30 | 28,Above are sentences of {1}.,Your task is to list more sentences without mentioning {0}.,end,test
31 | 29,The sentences below show instances of {1}.,Your goal is to list more instances that don't mention {0}.,end,test
32 | 30,The sentences below are examples of {1}.,Your task is to enumerate more of them which don't mention {0}.,end,test
33 | 31,The sentences below are examples of {1}.,Your task is to list more sentences that do not mention {0}.,end,test
34 | 32,Write down examples of {}.,"Generate more examples, but don't create any that include {}.",begin+end,test
35 | 33,"Given the topic {}, write down three examples:","Write down more examples, but don't mention {} in any of them.",begin+end,test
36 | 34,Come up with instances of the topic {}.,"Come up with more examples, but keep {} out of them.",begin+end,test
37 |
--------------------------------------------------------------------------------
/src/diverse_instructions.py:
--------------------------------------------------------------------------------
1 | """
2 | Topic/constraint descriptions as the instruction.
3 | """
4 | import csv
5 | import copy
6 | from src.utils import get_wikidata_p_text
7 |
8 | MAPPING = {
9 | -2: [
10 | 'begin+end',
11 | 'dummy topic {}',
12 | 'dummy constraint {}',
13 | 'dummy',
14 | ]
15 | }
16 | with open('src/diverse_instructions.csv', 'r') as f:
17 | reader = csv.DictReader(f)
18 | for row in reader:
19 | index = int(row['index'])
20 | topic_inst = row['topic'].strip()
21 | constraint_inst = row['constraint'].strip()
22 | inst_type = row['type']
23 | mode = row['mode']
24 |
25 | if inst_type == 'begin+end':
26 | MAPPING[int(row['index'])] = [
27 | inst_type,
28 | topic_inst,
29 | constraint_inst,
30 | mode,
31 | ]
32 | else:
33 | MAPPING[int(row['index'])] = [
34 | inst_type,
35 | topic_inst + ' ' + constraint_inst,
36 | mode,
37 | ]
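# After the loop, MAPPING holds entries such as (taken from diverse_instructions.csv):
#   MAPPING[0] == ['begin+end', 'Write down examples of {}.',
#                  'Continue listing them but do not include examples of {}.', 'train']
#   MAPPING[2] == ['end',
#                  'The above showed examples of {1}. Now list more examples but don’t mention {0}.',
#                  'train']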
38 |
39 |
40 | def get_instruction(datapoint, args, version):
41 | if args.data == 'wordnet':
42 | topic = datapoint['topic']
43 | constraint = datapoint['constraint']
44 | insert_position, *instruction_templates, mode = MAPPING[version]
45 | if insert_position == 'begin+end':
46 | instructions = [
47 | instruction_templates[0].format(topic),
48 | instruction_templates[1].format(constraint),
49 | ]
50 | elif insert_position == 'begin':
51 | instructions = [instruction_templates[0].format(constraint, topic), ""]
52 | elif insert_position == 'end':
53 | instructions = ["", instruction_templates[0].format(constraint, topic)]
54 | else:
55 | raise ValueError(f'`{insert_position}` not recognized.')
56 |
57 | elif args.data == 'wikidata':
58 | topic = datapoint['topic']
59 | constraint = datapoint['constraint']
60 | insert_position, *instruction_templates, mode = MAPPING[version]
61 |
62 | topic = get_wikidata_p_text(topic[0], topic[1])
63 | constraint = get_wikidata_p_text(constraint[0], constraint[1])
64 |
65 | if insert_position == 'begin+end':
66 | instructions = [
67 | instruction_templates[0].format(topic),
68 | instruction_templates[1].format(constraint),
69 | ]
70 | elif insert_position == 'begin':
71 | instructions = [instruction_templates[0].format(constraint, topic), ""]
72 | elif insert_position == 'end':
73 | instructions = ["", instruction_templates[0].format(constraint, topic)]
74 | else:
75 | raise ValueError(f'`{insert_position}` not recognized.')
76 |
77 | return dict(
78 | begin=instructions[0],
79 | end=instructions[1],
80 | insert_position=insert_position,
81 | )
82 |
83 |
84 | def prepare_context(datapoint, args, version):
85 | new_datapoint = copy.deepcopy(datapoint)
86 | blocks = new_datapoint['context']
87 | blocks = ['== ' + sent.strip('==').strip() + ' ==' for sent in blocks.split('==\n')]
88 |
89 | instruction = get_instruction(new_datapoint, args, version)
90 | begin = instruction['begin']
91 | end = instruction['end']
92 | insert_position = instruction['insert_position']
93 | blocks = [begin] + blocks + [end]
94 | context_with_instructions = '\n'.join(blocks).strip()
95 |
96 | new_datapoint['context_with_instructions'] = context_with_instructions
97 | new_datapoint['version'] = version
98 | new_datapoint['begin'] = begin
99 | new_datapoint['end'] = end
100 | new_datapoint['insert_position'] = insert_position
101 | return new_datapoint
102 |
--------------------------------------------------------------------------------
/src/experiment.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import yaml
4 | import uuid
5 | import logging
6 | import argparse
7 | import tempfile
8 | from pathlib import Path
9 |
10 |
11 | class ExperimentManager:
12 | """
13 | The only object needed to be instantiated throughout an experiment run.
14 |
15 | Functionalities:
16 | 1. Create run dir
17 | 2. Manage all logging
18 | a. Save command
19 | b. Save config/args
20 | c. Save predictions
21 | d. Save final eval stats
22 | 3. Manage all metric calculations
23 | """
24 | def __init__(self,
25 | run_dir=None,
26 | name=None,
27 | override=True,
28 | num_runs=1,
29 | **kwargs
30 | ):
31 | """
32 | Create run dir and setup file paths.
33 | By default, one run folder will be created that contains the predictions file.
34 |         If `num_runs` > 1, that many sub dirs will be created.
35 | The sub dirs will be named `seed_{run_id}`.
36 | """
37 | self.num_runs = num_runs
38 |
39 | # Set up run dir path.
40 | if run_dir is None:
41 | run_dir = '/tmp/unnamed-experiments'
42 | if name is None:
43 | name = str(uuid.uuid4()).split('-')[-1]
44 |
45 | self.run_dir = Path(run_dir) / name
46 |
47 | self.config_path = self.run_dir / 'config.yaml'
48 |
49 | self.prediction_paths = [
50 | self.run_dir / f'seed_{run_id}' / 'predictions.jsonl'
51 | for run_id in range(num_runs)
52 | ]
53 |
54 | self.agg_result_path = self.run_dir / 'results.yaml'
55 | self.result_paths = [
56 | self.run_dir / f'seed_{run_id}' / 'results.yaml'
57 | for run_id in range(num_runs)
58 | ]
59 |
60 | # Create run dir and the needed paths.
61 | if not self.run_dir.exists() or override:
62 | self.run_dir.mkdir(parents=True, exist_ok=True)
63 | for run_id in range(num_runs):
64 | seed_run_dir = self.run_dir / f'seed_{run_id}'
65 | seed_run_dir.mkdir(parents=True, exist_ok=True)
66 |             print(f'Run dir `{self.run_dir}` created (or already existed).')
67 | else:
68 |             print(f'Run dir `{self.run_dir}` already exists; not overriding.')
69 |
70 | # Set up file loggers.
71 | self.file_loggers = []
72 | for run_id in range(self.num_runs):
73 | self.setup_logger(f'file_logger_{run_id}', run_id)
74 | self.file_loggers.append(logging.getLogger(f'file_logger_{run_id}'))
75 |
76 | # Set up other modules.
77 | self.metric_tracker = kwargs.get('metric_tracker', None)
78 |
79 | def setup_logger(self, logger_name, run_id=None):
80 | logger = logging.getLogger(logger_name)
81 | logger.setLevel(logging.INFO)
82 | formatter = logging.Formatter('%(message)s')
83 | handler = logging.FileHandler(self.prediction_paths[run_id], mode='w')
84 | handler.setFormatter(formatter)
85 | logger.addHandler(handler)
86 | logger.propagate = False
87 |
88 | def console(self, content):
89 | print(content)
90 |
91 | def log(self, content, run_id):
92 | assert run_id is not None, 'With `save=True`, `run_id` must be provided.'
93 | content_to_log = (
94 | json.dumps(content) if isinstance(content, dict) else content
95 | )
96 | self.file_loggers[run_id].info(content_to_log)
97 |
98 | def add_metric(self, content, run_id):
99 | assert self.metric_tracker is not None and run_id is not None, (
100 |             'With `add_metric=True`, the `metric_tracker` cannot be `None` and '
101 | '`run_id` must be specified.')
102 | self.metric_tracker.add(content, run_id=run_id)
103 |
104 | def take(self, content, console=False, log=False, run_id=None, add_metric=False):
105 | """
106 | To log to files, a `run_id` must be provided.
107 | `content` can be of type `str` or in most cases `dict`.
108 |
109 | With default arguments, nothing will be performed.
110 | The function is meant to be used when multiple uses are needed at once.
111 | """
112 | if console:
113 | self.console(content)
114 | if log:
115 | self.log(content, run_id)
116 | if add_metric:
117 | self.add_metric(content, run_id)
118 |
119 | def save_results(self, results: dict, run_id=None):
120 | if run_id is None:
121 | with open(self.agg_result_path, 'w') as f:
122 | yaml.dump(results, f, default_flow_style=False)
123 | else:
124 | with open(self.result_paths[run_id], 'w') as f:
125 | yaml.dump(results, f, default_flow_style=False)
126 |
127 | def save_config(self, config):
128 | """
129 | Save config and command used to execute the main script.
130 | """
131 | if isinstance(config, argparse.Namespace):
132 | config = vars(config)
133 |
134 | with open(self.config_path, 'w') as f:
135 | yaml.dump(config, f)
136 |
137 | def compute_metric(self, metric_name, run_id):
138 | assert self.metric_tracker is not None
139 | return self.metric_tracker.compute(metric_name, run_id)
140 |
141 | def aggregate_metrics(self, avg_runs=True):
142 | assert self.metric_tracker is not None
143 | return self.metric_tracker.aggregate(avg_runs=avg_runs)
144 |
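
if __name__ == '__main__':
    # Minimal usage sketch (illustrative values; in the full pipeline a metric
    # tracker from src/metric_tracker.py can be passed as `metric_tracker=...`).
    manager = ExperimentManager(run_dir='runs', name='demo', num_runs=1)
    manager.save_config({'alpha': 100.0, 'beta': 5.0})
    manager.take({'id': 0, 'generation': '...'}, console=True, log=True, run_id=0)
    manager.save_results({'num_predictions': 1}, run_id=0)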
145 |
--------------------------------------------------------------------------------
/src/guidance_models.py:
--------------------------------------------------------------------------------
1 | import re
2 | import random
3 | from pathlib import Path
4 | from ast import literal_eval
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from nltk.tokenize import sent_tokenize
11 |
12 | from src.guidance_utils import (
13 | remove_parenthesis,
14 | get_wikidata_query,
15 | get_wordnet_query,
16 | )
17 | from src.instruct2guide.utils import postprocess
18 |
19 | _YES = 3763 # gpt2 tokenizer
20 | _NO = 645 # gpt2 tokenizer
21 | _PAD = 50256 # gpt2 tokenizer
22 |
23 | random.seed(0)
24 |
25 |
26 | class BaseGuidanceModel(nn.Module):
27 | def __init__(self, guidance_lm, tokenizer, args, **kwargs):
28 | super().__init__()
29 | self.args = args
30 | self.tokenizer = tokenizer
31 | self.guidance_lm = guidance_lm
32 | self.hierarchy = kwargs.get('hierarchy', None)
33 | self.entity_to_trie = dict()
34 |
35 | # Track generation state and use the next token for guidance.
36 | self.guidance_prefix = {
37 | 'in': [],
38 | 'ex': [],
39 | }
40 |
41 | # Track which `example_id` is done and stop guidance.
42 | self.in_finished = set()
43 |
44 | def forward(self, entity, query_entity=None, **kwargs):
45 | if self.args.data == 'wordnet':
46 | query = get_wordnet_query(
47 | guidance_model_type=self.guidance_model_type,
48 | entity=entity,
49 | query_entity=query_entity,
50 | args=self.args,
51 | **kwargs,
52 | )
53 | elif self.args.data == 'wikidata':
54 | query = get_wikidata_query(
55 | guidance_model_type=self.guidance_model_type,
56 | entity=entity,
57 | query_entity=query_entity,
58 | args=self.args,
59 | **kwargs,
60 | )
61 | else:
62 | raise ValueError(f'Unknown data: {self.args.data}')
63 |
64 | query_tokens = self.tokenizer.encode(query)
65 | query_tokens = torch.tensor(query_tokens).long().unsqueeze(0).cuda()
66 | out = self.guidance_lm(query_tokens)
67 | return out.logits[:, -1, :]
68 |
69 | def query_guidance(self, entity, **kwargs):
70 | """
71 |         Returns indices, loss, and info.
72 | """
73 | raise NotImplementedError
74 |
75 |     def get_trie_filtered_indices(
76 | self,
77 | mode,
78 | words,
79 | entity,
80 | last_token,
81 | curr_context,
82 | example_id,
83 | ):
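        """
        Return the token ids the trie allows next: if the last generated token
        extends a tracked guidance-example prefix, follow that branch of the
        trie; otherwise propose the start tokens of all guidance examples.
        """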
84 | if self.args.data == 'wikidata':
85 | entity = tuple(entity)
86 | # Debug ########################################################
87 | #print('id:', example_id)
88 | #print('mode:', mode)
89 | #print('gen_examples:', words)
90 | #print('entity:', entity)
91 | #print('in_finished:', self.in_finished)
92 | ################################################################
93 | if mode == 'in' and example_id in self.in_finished:
94 | return []
95 |
96 | if entity not in self.entity_to_trie:
97 | trie = Trie()
98 | for word in words:
99 | name_tok_inds = self.tokenizer.encode(' ' + word)
100 | name_toks = [
101 | self.tokenizer.decode(ind, skip_special_tokens=True)
102 | for ind in name_tok_inds
103 | ]
104 |
105 | trie.start_node_inds.add(name_tok_inds[0])
106 | trie.insert(name_toks)
107 | trie.node_set.update(name_toks)
108 | self.entity_to_trie[entity] = trie
109 |
110 | # Get the trie corresponding to the current entity.
111 | trie = self.entity_to_trie[entity]
112 | last_token_text = self.tokenizer.decode(last_token.squeeze(0))
113 |
114 | # Debug #############################################################
115 | #curr_context_text = self.tokenizer.decode(curr_context.squeeze(0))
116 | #print('last:', last_token_text)
117 | #print('curr context:', curr_context_text)
118 | #print('node_set:', trie.node_set)
119 | #print(f'gp before [{mode}]: {self.guidance_prefix}')
120 | #####################################################################
121 |
122 | if last_token_text not in trie.node_set:
123 | # Debug ##########################################
124 | #print(f'[{mode}] here 1 -> [{last_token_text}]')
125 | ##################################################
126 |             # Use the start tokens for guidance.
127 | inds = list(trie.start_node_inds)
128 | else:
129 | # Debug ##########################################
130 | #print(f'[{mode}] here 2 -> [{last_token_text}]')
131 | ##################################################
132 | self.guidance_prefix[mode].append(last_token_text)
133 | query_prefix = self.guidance_prefix[mode]
134 | name_next_tok_index = len(query_prefix)
135 |
136 | name_candidates = trie.query(query_prefix)
137 | #print(f'Name candidates [{mode}]: {name_candidates}')
138 | if not name_candidates:
139 | # The prefix fails. Reset the tracker.
140 | inds = list(trie.start_node_inds)
141 | self.guidance_prefix[mode] = []
142 | else:
143 | name_candidate_tok_texts = [
144 | c[name_next_tok_index] for c in name_candidates
145 | if len(c) > name_next_tok_index
146 | ]
147 | #print(name_candidate_tok_texts)
148 | inds = [
149 | self.tokenizer.encode(tok)[0] for tok in name_candidate_tok_texts
150 | ]
151 |                 name_candidates = {''.join(toks) for toks in name_candidates}
152 |                 hit = ''.join(query_prefix) in name_candidates
153 | if hit:
154 | self.in_finished.add(example_id)
155 | self.guidance_prefix[mode] = []
156 | # Debug ######################################################################
157 | #print(f'gp after [{mode}]: {self.guidance_prefix}')
158 | #print([self.tokenizer.decode(ind) for ind in inds])
159 | #print('++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
160 | #input()
161 | ##############################################################################
162 | return inds
163 |
164 |
165 | class BinaryGuidanceModel(BaseGuidanceModel):
166 | def __init__(self, guidance_lm, tokenizer, args, **kwargs):
167 | super().__init__(guidance_lm, tokenizer, args, **kwargs)
168 | self.guidance_model_type = 'binary'
169 |
170 | def query_guidance(self, entity, **kwargs):
171 | mode = kwargs.get('mode', None)
172 | query_entity = kwargs['query_entity']
173 | orig_probs = kwargs['orig_probs']
174 | if query_entity is None:
175 | return [], None, dict(
176 | yes_prob=-1.0,
177 | no_prob=-1.0,
178 | )
179 | if self.args.guidance_model_name == 'oracle':
180 | if self.args.data == 'wordnet':
181 | oracle_words = self.hierarchy[entity] + [entity]
182 | oracle_plural_words = [w + 's' for w in oracle_words]
183 | all_oracle_words = set(oracle_words) | set(oracle_plural_words)
184 | normalized_yes_prob = 1.0 * (query_entity in all_oracle_words)
185 | elif self.args.data == 'wikidata':
186 | words = self.hierarchy[entity]
187 | words = [remove_parenthesis(title) for q_id, title in words]
188 | words = set([w.lower() for w in words])
189 | normalized_yes_prob = 1.0 * (query_entity in words)
190 | else:
191 | raise ValueError(f'Data arg {self.args.data} not recognized.')
192 | yes_prob = normalized_yes_prob
193 | no_prob = 1.0 - normalized_yes_prob
194 | else:
195 | logits = self.forward(entity, query_entity)
196 | probs = F.softmax(logits, dim=-1)
197 | yes_prob = probs[:, _YES].squeeze(0).item()
198 | no_prob = probs[:, _NO].squeeze(0).item()
199 | normalized_yes_prob = yes_prob / (yes_prob + no_prob)
200 |
201 | if normalized_yes_prob > kwargs['threshold']:
202 | inds = torch.argmax(orig_probs, dim=1)
203 | loss = torch.log(orig_probs.max())
204 | else:
205 | loss = None
206 | inds = []
207 |
208 | return inds, loss, dict(
209 | yes_prob=yes_prob,
210 | no_prob=no_prob,
211 | query_entity=query_entity,
212 | entity=entity,
213 | )
214 |
215 |
216 | class FullGuidanceModel(BaseGuidanceModel):
217 | """
218 |     The full guidance model operates on the next token's full probability
219 |     distribution, as opposed to the binary guidance model, which only offers a yes/no probability.
220 | """
221 | def __init__(self, guidance_lm, tokenizer, args, **kwargs):
222 | super().__init__(guidance_lm, tokenizer, args, **kwargs)
223 | self.guidance_model_type = 'full'
224 | self.in_bow_vecs = None
225 | self.ex_bow_vecs = None
226 | self.in_set = None
227 | self.ex_set = None
228 |
229 | def query_guidance(self, entity, **kwargs):
230 | mode = kwargs.get('mode', None)
231 | query_entity = kwargs['query_entity']
232 | orig_probs = kwargs['orig_probs']
233 |
234 | if self.args.guidance_model_name == 'oracle':
235 | if mode == 'in':
236 | self.in_bow_vecs, self.in_set = self._init_bow_vecs(entity, **kwargs)
237 | bow_vecs = self.in_bow_vecs
238 | ind_set = self.in_set
239 | elif mode == 'ex':
240 | self.ex_bow_vecs, self.ex_set = self._init_bow_vecs(entity, **kwargs)
241 | bow_vecs = self.ex_bow_vecs
242 | ind_set = self.ex_set
243 | else:
244 | raise ValueError(f'`mode` {mode} not supported.')
245 |
246 | loss = 0.0
247 | for vec in bow_vecs:
248 | loss += torch.log(torch.mm(orig_probs, torch.t(vec)).sum())
249 | inds = list(ind_set)
250 | else:
251 | logits = self.forward(entity, **kwargs)
252 | if mode == 'in':
253 | full_guide_topk = self.args.full_guide_topk_in
254 | elif mode == 'ex':
255 | full_guide_topk = self.args.full_guide_topk_ex
256 | else:
257 | raise ValueError(f'`mode` {mode} not supported.')
258 |
259 | _, inds = torch.topk(logits, k=full_guide_topk)
260 | inds = inds.squeeze(0) # 1xk -> k
261 | topk_probs = orig_probs[:, inds]
262 | loss = torch.log(topk_probs.sum())
263 |
264 | ind_tokens = [self.tokenizer.decode(ind) for ind in inds]
265 |
266 | return inds, loss, dict(
267 | yes_prob=-1.0,
268 | no_prob=-1.0,
269 | query_entity=query_entity,
270 | entity=entity,
271 | ind_tokens=ind_tokens,
272 | )
273 |
274 | def _init_bow_vecs(self, entity, **kwargs):
275 | mode = kwargs.get('mode', None)
276 | curr_context = kwargs.get('curr_context', None)
277 | if self.args.data == 'wordnet':
278 | if mode == 'in':
279 | words = self.hierarchy[entity]
280 | elif mode == 'ex':
281 | words = self.hierarchy[entity] + [entity]
282 | words = words + [w + 's' for w in words]
283 | elif self.args.data == 'wikidata':
284 | if mode == 'in':
285 | words = self.hierarchy[entity][:100]
286 | elif mode == 'ex':
287 | words = self.hierarchy[entity]
288 | else:
289 | raise ValueError(f'`{mode}` not recognized.')
290 |
291 | # If full length is used there's gonna be too many Q's.
292 | words = [remove_parenthesis(q_title) for q_id, q_title in words]
293 | else:
294 | raise ValueError(f'Data arg {self.args.data} not recognized.')
295 |
296 |         inds = self.get_trie_filtered_indices(
297 | mode,
298 | words,
299 | entity,
300 | kwargs['last_token'],
301 | kwargs['curr_context'],
302 | kwargs['example_id'],
303 | )
304 |
305 | vecs = []
306 | if inds:
307 | inds = torch.tensor(inds)
308 | onehot = torch.zeros(1, len(self.tokenizer))
309 | onehot[:, inds] = 1.0
310 | vecs.append(onehot.cuda())
311 | return vecs, set(inds)
312 |
313 |
314 | class DiscreteGuidanceModel(BaseGuidanceModel):
315 | def __init__(self, guidance_lm, tokenizer, args, **kwargs):
316 | super().__init__(guidance_lm, tokenizer, args, **kwargs)
317 | self.guidance_model_type = 'discrete'
318 |         # NOTE: this might cause a memory leak. It should be largely fine
319 |         # given the expected number of new entities.
320 |         self.entity_to_examples = dict()
321 |
322 | def forward_prefix(self, entity, **kwargs):
323 | datapoint = kwargs['datapoint']
324 | if self.args.data == 'wikidata':
325 | entity = tuple(entity)
326 |         #gen_examples = self.entity_to_examples.get(entity, None)
327 |         gen_examples = self.entity_to_examples.get((entity, datapoint['version']), None)
328 | if gen_examples is not None:
329 | return dict(gen_examples=gen_examples)
330 |
331 | datapoint = kwargs['datapoint']
332 | text = datapoint['context_with_instructions'] + ' [SEP]'
333 |
334 | if kwargs['mode'] == 'in':
335 | mode = 'topic'
336 | elif kwargs['mode'] == 'ex':
337 | mode = 'constraint'
338 | else:
339 |             raise ValueError(f'mode={kwargs["mode"]} not recognized.')
340 |
341 | gen_text = self.guidance_lm.generate(text, mode=mode)
342 | gen_examples = postprocess(gen_text, data_mode=self.args.data)
343 |         #self.entity_to_examples[entity] = gen_examples
344 |         self.entity_to_examples[(entity, datapoint['version'])] = gen_examples
345 | return dict(
346 | gen_examples=gen_examples,
347 | gen_texts=[gen_text],
348 | )
349 |
350 | def forward(self, entity, **kwargs):
351 | datapoint = kwargs['datapoint']
352 | if self.args.data == 'wikidata':
353 | entity = tuple(entity)
354 |         #gen_examples = self.entity_to_examples.get(entity, None)
355 |         gen_examples = self.entity_to_examples.get((entity, datapoint['version']), None)
356 | if gen_examples is not None:
357 | return dict(gen_examples=gen_examples)
358 |
359 | if self.args.data == 'wordnet':
360 | query = get_wordnet_query(
361 | guidance_model_type=self.guidance_model_type,
362 | entity=entity,
363 | args=self.args,
364 | **kwargs,
365 | )
366 | elif self.args.data == 'wikidata':
367 | query = get_wikidata_query(
368 | guidance_model_type='discrete',
369 | entity=entity,
370 | **kwargs,
371 | )
372 | else:
373 | raise ValueError(f'{self.args.data} not recognized.')
374 |
375 | query_tokens = self.tokenizer.encode(query)
376 | query_tokens = torch.tensor(query_tokens).long().unsqueeze(0).cuda()
377 |
378 | outputs = self.guidance_lm.generate(
379 | query_tokens,
380 | max_length=self.args.discrete_max_length,
381 | #skip_special_tokens=True,
382 | pad_token_id=self.tokenizer.eos_token_id,
383 | no_repeat_ngram_size=2,
384 | do_sample=self.args.discrete_guidance_do_sample,
385 | num_beams=self.args.discrete_guidance_num_beams,
386 | num_beam_groups=self.args.discrete_guidance_num_beam_groups,
387 | top_p=self.args.discrete_guidance_top_p,
388 | top_k=self.args.discrete_guidance_top_k,
389 | temperature=self.args.discrete_guidance_temperature,
390 | diversity_penalty=self.args.discrete_guidance_diversity_penalty,
391 | num_return_sequences=self.args.discrete_guidance_num_return_sequences or 1,
392 | )
393 | gen_examples = []
394 | gen_texts = []
395 | pattern = 'Some examples are:'
396 | for output in outputs:
397 | gen_text = self.tokenizer.decode(output, skip_special_tokens=True)
398 | gen_texts.append(gen_text)
399 | gen_sents = sent_tokenize(gen_text)
400 | gen_sents = [sent for sent in gen_sents if sent.startswith(pattern)]
401 | gen = gen_sents[1].replace(pattern, '').strip().strip('.').split(', ')
402 | gen_examples += gen
403 | gen_examples = list(set(gen_examples))
404 |         #self.entity_to_examples[entity] = gen_examples
405 |         self.entity_to_examples[(entity, datapoint['version'])] = gen_examples
406 | return dict(
407 | gen_examples=gen_examples,
408 | gen_texts=gen_texts,
409 | )
410 |
411 |     def query_guidance_with_loss(self, entity, **kwargs):
412 |         query_entity = kwargs.get('query_entity', None)
413 |         orig_probs = kwargs['orig_probs']
414 |         gen_examples = self.forward(entity, **kwargs)['gen_examples']
415 | 
416 |         # Build a bag-of-words indicator over the generated guidance examples
417 |         # and compute the log of the probability mass they receive (mirrors
418 |         # the oracle branch of FullGuidanceModel.query_guidance).
419 |         inds = self._init_bow_vecs(
420 |             gen_examples, entity=entity, return_only_inds=True, **kwargs
421 |         )
422 |         bow_vec = torch.zeros(1, len(self.tokenizer), device=orig_probs.device)
423 |         if inds:
424 |             bow_vec[:, list(inds)] = 1.0
425 |         loss = torch.log(torch.mm(orig_probs, torch.t(bow_vec)).sum())
426 | 
427 |         return dict(
428 |             loss=loss,
429 |             yes_prob=-1.0,
430 |             no_prob=-1.0,
431 |             gen_examples=gen_examples,
432 |             query_entity=query_entity,
433 |             entity=entity,
434 |         )
436 |
437 | def query_guidance(self, entity, **kwargs):
438 | mode = kwargs.get('mode', None)
439 | if self.args.discrete_guidance_instruct2guide_model_dir is not None:
440 | gen_out = self.forward_prefix(entity, **kwargs)
441 | else:
442 | gen_out = self.forward(entity, **kwargs)
443 |
444 | inds = self._init_bow_vecs(
445 | gen_out['gen_examples'],
446 | entity=entity,
447 | return_only_inds=True,
448 | **kwargs
449 | )
450 | inds = list(inds)
451 | gen_info = dict(
452 | gen_examples=gen_out['gen_examples'],
453 | entity=entity,
454 | )
455 | return inds, None, gen_info
456 |
457 | def _init_bow_vecs(self, words, entity, return_only_inds=False, **kwargs):
458 | if self.args.data == 'wordnet':
459 | words = words + [w + 's' for w in words]
460 |
461 | # Use trie.
462 | if self.args.discrete_guidance_use_trie:
463 |             inds = self.get_trie_filtered_indices(
464 | kwargs['mode'],
465 | words,
466 | entity,
467 | kwargs['last_token'],
468 | kwargs['curr_context'],
469 | kwargs['example_id'],
470 | )
471 | if return_only_inds:
472 | return set(inds)
473 | else:
474 | bow_indices = [
475 | self.tokenizer.encode(word.strip(), add_prefix_space=True)
476 | for word in words
477 | ]
478 | bow_set = set([ind for inds in bow_indices for ind in inds])
479 | if return_only_inds:
480 | return bow_set
481 |
482 |
483 | class OracleGuidanceModel(BaseGuidanceModel):
484 | def __init__(self, guidance_lm, tokenizer, args, **kwargs):
485 | super().__init__(guidance_lm, tokenizer, args, **kwargs)
486 | self.guidance_model_type = 'oracle'
487 |
488 | def query_guidance(self, entity, **kwargs):
489 | mode = kwargs.get('mode', None)
490 |
491 | if self.args.data == 'wordnet':
492 | if mode == 'in':
493 | words = self.hierarchy[entity]
494 | elif mode == 'ex':
495 | words = self.hierarchy[entity] + [entity]
496 | else:
497 | raise ValueError(f'`{mode}` not recognized.')
498 | elif self.args.data == 'wikidata':
499 | if mode == 'in':
500 | words = self.hierarchy[entity][:100]
501 | elif mode == 'ex':
502 | words = self.hierarchy[entity][:1000]
503 | else:
504 | raise ValueError(f'`{mode}` not recognized.')
505 | # If full length is used there's gonna be too many Q's.
506 | words = [remove_parenthesis(q_title) for q_id, q_title in words]
507 | else:
508 | raise ValueError(f'Data arg {self.args.data} not recognized.')
509 |
510 | inds = self._init_bow_vecs(
511 | words,
512 | entity=entity,
513 | return_only_inds=True,
514 | **kwargs
515 | )
516 | inds = list(inds)
517 | gen_info = dict(
518 | entity=entity,
519 | )
520 | #print(mode)
521 | #print(entity)
522 | #print(words)
523 | #print([self.tokenizer.decode(ind) for ind in inds])
524 | #print(self.guidance_prefix)
525 | #print('---')
526 | #input()
527 | return inds, None, gen_info
528 |
529 | def _init_bow_vecs(self, words, entity, return_only_inds=False, **kwargs):
530 | # Use trie.
531 | if self.args.discrete_guidance_use_trie:
532 |             inds = self.get_trie_filtered_indices(
533 | kwargs['mode'],
534 | words,
535 | entity,
536 | kwargs['last_token'],
537 | kwargs['curr_context'],
538 | kwargs['example_id'],
539 | )
540 | if return_only_inds:
541 | return set(inds)
542 | else:
543 | bow_indices = [
544 | self.tokenizer.encode(word.strip(), add_prefix_space=True)
545 | for word in words
546 | ]
547 | bow_set = set([ind for inds in bow_indices for ind in inds])
548 | if return_only_inds:
549 | return bow_set
550 |
551 |
552 | class TrieNode:
553 | """A node in the trie structure"""
554 |
555 | def __init__(self, tok):
556 | self.tok = tok
557 |
558 | self.is_end = False
559 |
560 | # A counter indicating how many times a word
561 | # is inserted (if this node's is_end is True).
562 | self.counter = 0
563 |
564 | # A dictionary of child nodes. Keys are tokens, values are nodes.
565 | self.children = {}
566 |
567 |
568 | class Trie:
569 | """The trie object"""
570 |
571 | def __init__(self):
572 | """
573 | The trie has at least the root node.
574 | The root node does not store any character
575 | """
576 | self.root = TrieNode('')
577 | self.node_set = set()
578 | self.start_node_inds = set()
579 |
580 | def insert(self, word):
581 | """Insert a word into the trie"""
582 | node = self.root
583 |
584 | # Loop through each token in the word.
585 |         # If there is no child containing the token,
586 | # create a new child for the current node.
587 | for tok in word:
588 | if tok in node.children:
589 | node = node.children[tok]
590 | else:
591 | # If a token is not found, create a new node in the trie.
592 | new_node = TrieNode(tok)
593 | node.children[tok] = new_node
594 | node = new_node
595 |
596 | node.is_end = True
597 |
598 | # Increment the counter to indicate that we see this word once more.
599 | node.counter += 1
600 |
601 | def dfs(self, node, prefix):
602 | """
603 | Depth-first traversal of the trie.
604 |
605 | Args:
606 | - node: the node to start with.
607 | - prefix: the current prefix, for tracing a word while traversing.
608 | """
609 | if node.is_end:
610 | self.output.append(prefix + [node.tok])
611 |
612 | for child in node.children.values():
613 | self.dfs(child, prefix + [node.tok])
614 |
615 | def query(self, x):
616 |         """Given an input (a prefix), retrieve all words stored in
617 |         the trie with that prefix.
618 |         """
620 | # Use a variable within the class to keep all possible outputs
621 | # as there can be more than one word with such prefix.
622 | self.output = []
623 | node = self.root
624 |
625 | # Check if the prefix is in the trie.
626 | for tok in x:
627 | if tok in node.children:
628 | node = node.children[tok]
629 | else:
630 | return []
631 |
632 | # Traverse the trie to get all candidates.
633 | self.dfs(node, x[:-1])
634 |
635 |         # Return all candidates found under the prefix.
636 | return self.output
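

if __name__ == '__main__':
    # Quick illustration (not part of the pipeline): the trie stores words as
    # sequences of sub-word token strings and is queried by token prefix.
    trie = Trie()
    trie.insert(['do', 'g'])
    trie.insert(['do', 'lph', 'in'])
    print(trie.query(['do']))  # [['do', 'g'], ['do', 'lph', 'in']]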
637 |
--------------------------------------------------------------------------------
/src/guidance_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities relevant to the guidance models.
3 | """
4 | import re
5 | import random
6 | from ast import literal_eval
7 | from pathlib import Path
6 |
7 | import torch
8 | import spacy
9 |
10 | from src.utils import (
11 | to_namespace,
12 | get_data,
13 | get_lm,
14 | get_tokenizer,
15 | cleanup_gen_text,
16 | reformat_text,
17 | load_ckpt,
18 | )
19 |
20 | SMALL_CONST = 1e-8
21 | PROMPT_TEXT = (
22 | # 'List out some examples of us presidents. Some examples are: lincoln, obama, washignton.\n'
23 | # 'List out some examples of color. Some examples are: red, blue, green.\n'
24 | # 'List out some examples of cities. Some examples are: new york, oslo, tokyo.\n'
25 | 'List out some examples of movie genres. Some examples are: drama, horror, action.\n'
26 | )
27 | PROMPT_TEXT_TEMPLATE = {
28 | 'animal': 'List out some examples of animals. Some examples are: cat, dog, lion.\n',
29 | 'food': 'List out some examples of food. Some examples are: pizza, burger, pasta.\n',
30 | 'vehicle': 'List out some examples of vehicles. Some examples are: car, bus, train.\n',
31 | 'art': 'List out some examples of art. Some examples are: painting, sculpture, drawing.\n',
32 | 'sport': 'List out some examples of sports. Some examples are: soccer, basketball, tennis.\n',
33 | }
34 |
35 | spacy_parser = spacy.load("en_core_web_lg")
36 |
37 |
38 | def remove_parenthesis(s):
39 | m = re.search(r'(.*) \(.*\)', s)
40 | if m is None:
41 | return s
42 | else:
43 | return m.group(1)
44 |
45 |
46 | def get_wikidata_query(
47 | guidance_model_type,
48 | entity,
49 | query_entity=None,
50 | **kwargs
51 | ):
52 | p_name, p_value = entity
53 |
54 | if guidance_model_type == 'binary':
55 | if p_name == 'place_of_birth':
56 | prompt = 'Was {0} born in {1}? The answer is'
57 | elif p_name == 'place_of_death':
58 | prompt = 'Did {0} die in {1}? The answer is'
59 | elif p_name == 'occupation':
60 | prompt = 'Was {0} a {1}? The answer is'
61 | elif p_name == 'country_of_citizenship':
62 | prompt = 'Was {0} a citizen of {1}? The answer is'
63 | elif p_name == 'academic_degree':
64 | prompt = 'Did {0} hold a degree in {1}? The answer is'
65 | elif p_name == 'educated_at':
66 | prompt = 'Did {0} get the education at {1}? The answer is'
67 | else:
68 | raise ValueError(f'Property name `{p_name}` not recognized.')
69 | query = prompt.format(query_entity, p_value)
70 | return query
71 |
72 | elif guidance_model_type in ('full', 'discrete'):
73 | #QUERY_FORMAT = 'List out some famous names of {0}. Some examples are:'
74 | #QUERY_FORMAT = 'List out some famous full names of passed {0}. Some examples are:'
75 | #QUERY_FORMAT = 'List out some historical names of {0}. Some examples are:'
76 | QUERY_FORMAT = 'List out some famous names of dead {0}. Some examples are:'
77 | if p_name == 'place_of_birth':
78 | fill_text = f'people who were born in {p_value}'
79 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n'
80 | elif p_name == 'place_of_death':
81 | fill_text = f'people who died in {p_value}'
82 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n'
83 | elif p_name == 'occupation':
84 | fill_text = p_value
85 | prompt_text = f'List out some famous names of dead people who were tech CEOs. Some examples are: Steve Jobs, Mark Hurd, Bill Campbell.\n'
86 | elif p_name == 'country_of_citizenship':
87 | fill_text = f'people who are citizens of {p_value}'
88 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n'
89 | elif p_name == 'academic_degree':
90 | fill_text = f'people who hold a degree in {p_value}'
91 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n'
92 | elif p_name == 'educated_at':
93 | #fill_text = f'people who had their education at {p_value}'
94 |         fill_text = f'people who were educated at {p_value}'
95 | prompt_text = f'List out some famous names of dead people who has traveled to France. Some examples are: Ernest Hemingway, Miles Davis, Oscar Wilde.\n'
96 | else:
97 | raise ValueError(f'Property name `{p_name}` not recognized.')
98 | query = QUERY_FORMAT.format(fill_text)
99 | #prompt_text = kwargs.get('prompt_text', PROMPT_TEXT)
100 | return (prompt_text + query).strip()
101 |
102 | else:
103 | raise ValueError(f'Guidance model type {guidance_model_type} not recognized.')
104 |
105 |
106 | def get_wordnet_query(
107 | guidance_model_type,
108 | entity,
109 | query_entity=None,
110 | **kwargs,
111 | ):
112 | args = kwargs.get('args', None)
113 | if guidance_model_type == 'binary':
114 | MAGIC_PROMPT = 'I am an expert in taxonomy.'
115 | QUERY_FORMAT = 'Is {0} a type of {1}? The answer is{2}'
116 | query = [MAGIC_PROMPT]
117 | if args.num_icl_pairs > 0:
118 | pos_icl_pairs, neg_icl_pairs, noisy_queries = load_icl_pairs(args)
119 | icl_text = get_icl_query(
120 | args.num_icl_pairs,
121 | pos_icl_pairs,
122 | neg_icl_pairs,
123 | noisy_queries,
124 | )
125 | else:
126 | icl_text = []
127 | query += icl_text
128 | query.append(QUERY_FORMAT.format(query_entity, entity, ''))
129 | query = ' '.join(query).strip()
130 | return query
131 |
132 | elif guidance_model_type in ('full', 'discrete'):
133 | root_node = kwargs['datapoint']['parents'][0][0]
134 | prompt_text = PROMPT_TEXT_TEMPLATE.get(root_node, PROMPT_TEXT)
135 |
136 | QUERY_FORMAT = 'What are some examples of {0}? Some examples are:'
137 | query = QUERY_FORMAT.format(entity)
138 | return (prompt_text + query).strip()
139 |
140 | else:
141 | raise ValueError(f'Guidance model type {guidance_model_type} not recognized.')
142 |
143 |
144 | def load_icl_pairs(args):
145 | """
146 | For in-context learning.
147 | """
148 | pos_icl_pairs = []
149 | neg_icl_pairs = []
150 | roots = ['animal', 'food', 'sport', 'art', 'vehicle']
151 | for root in roots:
152 | path = Path(args.icl_pair_dir) / f'train_pairs_{root}_new.txt'
153 | with open(path) as f:
154 | for line in f:
155 | cat, pos, neg = literal_eval(line.strip())
156 | cat = cat.replace('_', ' ')
157 | pos = pos.replace('_', ' ')
158 | neg = neg.replace('_', ' ')
159 | pos_icl_pairs.append((pos, cat, root))
160 | neg_icl_pairs.append((neg, cat, root))
161 | noisy_queries = []
162 | with open('/path/to/your/noisy_queries.txt') as f:
163 | for line in f:
164 | noisy_queries.append(line.strip())
165 | return pos_icl_pairs, neg_icl_pairs, noisy_queries
166 |
167 |
168 | def get_icl_query(num_icl_pairs, pos_icl_pairs, neg_icl_pairs, noisy_queries):
169 |     # Binary query template (same format as in get_wordnet_query).
170 |     QUERY_FORMAT = 'Is {0} a type of {1}? The answer is{2}'
171 |     pos_icl_pairs = random.choices(pos_icl_pairs, k=num_icl_pairs)
170 | neg_icl_pairs = random.choices(neg_icl_pairs, k=num_icl_pairs)
171 | noisy_icl = random.choices(noisy_queries, k=num_icl_pairs)
172 |
173 | icl = []
174 | for p, n, noisy_q in zip(pos_icl_pairs, neg_icl_pairs, noisy_icl):
175 | icl += [
176 | QUERY_FORMAT.format(p[0], p[1], ' yes.'),
177 | QUERY_FORMAT.format(n[0], n[1], ' no.'),
178 | # QUERY_FORMAT.format(noisy_q, p[1], ' no.'),
179 | # QUERY_FORMAT.format(noisy_q, n[1], ' no.'),
180 | ]
181 | random.shuffle(icl)
182 | return icl
183 |
184 |
185 | def get_guidance_model(
186 | args,
187 | tokenizer,
188 | hierarchy,
189 | num_devices=None,
190 | guidance_lm=None,
191 | ):
192 | from src.guidance_models import (
193 | BinaryGuidanceModel,
194 | FullGuidanceModel,
195 | DiscreteGuidanceModel,
196 | OracleGuidanceModel,
197 | )
198 | guidance_args = dict(
199 | tokenizer=tokenizer,
200 | args=args,
201 | hierarchy=hierarchy,
202 | )
203 | guidance_model_type = args.guidance_model_type
204 |
205 | # Load the core guidance LM (trained, un-trained, or oracle).
206 | if guidance_lm is not None:
207 | # Use the provided guidance_lm from the args.
208 | guidance_model_type = args.guidance_model_type_2
209 | elif args.guidance_model_path:
210 | # Load fine-tuned/prompt-tuned guidance models.
211 | guidance_lm, _ = load_ckpt(load_path=args.guidance_model_path)
212 | elif args.guidance_model_name == 'oracle':
213 | guidance_lm = None
214 | elif args.discrete_guidance_instruct2guide_model_dir is not None:
215 | assert args.guidance_model_type == 'discrete'
216 | from src.instruct2guide.utils import load_checkpoint
217 | guidance_lm = load_checkpoint(
218 | args.discrete_guidance_instruct2guide_model_dir
219 | )
220 | else:
221 | guidance_lm = get_lm(args, num_devices, load_mode='guidance')
222 |
223 | # Construct the complete guidance model
224 | if args.guidance_model_path:
225 | guidance_model_class = BinaryGuidanceModel
226 | elif guidance_model_type == 'binary':
227 | guidance_model_class = BinaryGuidanceModel
228 | elif guidance_model_type == 'full':
229 | guidance_model_class = FullGuidanceModel
230 | elif guidance_model_type == 'discrete':
231 | guidance_model_class = DiscreteGuidanceModel
232 | elif guidance_model_type == 'oracle':
233 | guidance_model_class = OracleGuidanceModel
234 | guidance_lm = None
235 | else:
236 | raise ValueError(f'Guidance model type: {guidance_model_type} not supported.')
237 |
238 | guidance_args.update(guidance_lm=guidance_lm)
239 | guidance_model = guidance_model_class(**guidance_args)
240 | return guidance_model
241 |
242 |
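# PPLM-style gradient step: each (key, value) delta is moved along the
# normalized gradient, delta <- delta - step_size * grad / ||grad||^gamma,
# where `args.gamma` damps the influence of large gradient norms.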
243 | def get_gradient(loss, curr_history_delta, step_size, args, retain_graph):
244 | if not torch.is_tensor(loss):
245 |         # If no loss is returned, there is no gradient to compute. This
246 |         # happens when the guidance model predicts that the answer to
247 |         # the query is "no".
248 | grad = [
249 | (torch.zeros_like(key), torch.zeros_like(value))
250 | for key, value in curr_history_delta
251 | ]
252 | else:
253 | loss.backward(retain_graph=retain_graph)
254 | grad_norms = [
255 | (
256 | torch.norm(key.grad) + SMALL_CONST,
257 | torch.norm(value.grad) + SMALL_CONST
258 | )
259 | for key, value in curr_history_delta
260 | ]
261 | grad = [(
262 | - step_size * (key.grad / key_grad_norm ** args.gamma),
263 | - step_size * (value.grad / value_grad_norm ** args.gamma)
264 | )
265 | for (key, value), (key_grad_norm, value_grad_norm)
266 | in zip(curr_history_delta, grad_norms)
267 | ]
268 | return grad
269 |
270 |
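# Illustrative behavior: if the lookahead tokens decode to
# " Michael Jordan played", spaCy tags PROPN PROPN VERB, so the leading
# proper-noun span is returned as the query entity "michael jordan".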
271 | def get_query_entity(tokens, tokenizer):
272 | """
273 | Check and return word/entity with the following order:
274 | 1) tokens form a named entity + extra tokens
275 | 2) tokens form a single word + extra tokens
276 |
277 | """
278 | text = tokenizer.decode(tokens).strip(' ')
279 | parsed = spacy_parser(text)
280 | parsed = [(t.text, t.pos_) for t in parsed]
281 |
282 | consecutive_idx = 0
283 | query_entity = []
284 | query_tags = []
285 | for i, (t_text, t_pos) in enumerate(parsed):
286 | if i == consecutive_idx and t_pos == 'PROPN':
287 | query_entity.append(t_text)
288 | query_tags.append(t_pos)
289 | consecutive_idx += 1
290 |
291 | if not query_entity and parsed and parsed[0][1] == 'NOUN':
292 | query_entity = [parsed[0][0]]
293 | query_tags = [parsed[0][1]]
294 |
295 | query_entity = ' '.join(query_entity)
296 | query_entity = query_entity.lower() if query_entity else None
297 | return query_entity, query_tags, parsed
298 |
299 |
300 | def deepcopy_history(history):
301 | return [
302 | (key.clone(), value.clone())
303 | for key, value in history
304 | ]
305 |
306 |
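# Elementwise addition of (key, value) delta tensors onto a
# past_key_values-style list of (key, value) pairs.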
307 | def add_key_values(xs, deltas):
308 | added_xs = []
309 | for x, delta in zip(xs, deltas):
310 | x_k, x_v = x
311 | delta_k, delta_v = delta
312 | added_xs.append((x_k + delta_k, x_v + delta_v))
313 | return added_xs
314 |
--------------------------------------------------------------------------------
/src/guide.py:
--------------------------------------------------------------------------------
1 | """
2 | The guidance procedure.
3 | - No Guidance
4 | - PPLM (pplm)
5 | - Weighted Decoding (wd)
- Constrained Decoding (cd)
6 | """
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from src.guidance_utils import (
11 | get_gradient,
12 | get_query_entity,
13 | deepcopy_history,
14 | add_key_values,
15 | )
16 |
17 |
18 | def run_pplm_step(
19 | example_id,
20 | args,
21 | model,
22 | guidance_model,
23 | in_guidance_model,
24 | tokenizer,
25 | history,
26 | last_token,
27 | curr_context,
28 | topic,
29 | constraint,
30 | datapoint,
31 | ):
32 | """
33 | Perform PPLM refinement steps to obtain the perturbed history.
34 | """
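    # The KV-cache deltas start at zero and are refined for
    # `args.refinement_steps` iterations using gradients from the guidance
    # model(s); the perturbed cache then produces the final logits.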
35 | history_delta = [
36 | (torch.zeros_like(key), torch.zeros_like(value))
37 | for key, value in history
38 | ]
39 |     refinement = []  # Track refinement tokens and the guidance probabilities.
40 | for i in range(args.refinement_steps):
41 | curr_history_delta = [
42 | (
43 | key.clone().detach().requires_grad_(True),
44 | value.clone().detach().requires_grad_(True)
45 | )
46 | for key, value in history_delta
47 | ]
48 | perturbed_history = add_key_values(history, curr_history_delta)
49 |
50 | # Generate multiple tokens to query the guidance model.
51 | next_token_ = last_token.clone().detach()
52 | history_ = deepcopy_history(perturbed_history)
53 | multistep_tokens = []
54 | multistep_probs = []
55 | for _ in range(args.max_multistep_len):
56 | outputs_ = model(next_token_, past_key_values=history_)
57 | logits_ = outputs_.logits[:, -1, :]
58 | probs_ = F.softmax(logits_, dim=-1)
59 | history_ = outputs_.past_key_values
60 |
61 | next_token_ = torch.argmax(probs_, dim=1).unsqueeze(0)
62 | multistep_tokens.append(next_token_.item())
63 | multistep_probs.append(probs_)
64 |
65 | # Use the list of probs/histories of the entity tokens for further use.
66 | query_entity, query_tags, parsed = \
67 | get_query_entity(multistep_tokens, tokenizer)
68 |
69 |         # Still use the latest token for generation while using the
70 |         # full entity probs to get the guidance signal.
71 | next_token_after_refinement = multistep_tokens[0]
72 | next_token_after_refinement_text = tokenizer.decode(
73 | next_token_after_refinement,
74 | skip_special_tokens=True,
75 | ).strip(' ').replace('\n', '\\n')
76 | perturbed_probs = multistep_probs[0]
77 |
78 | ex_loss = 0.0
79 | ex_grad = None
80 | ex_guidance_outs = None
81 | if 'ex' in args.guidance:
82 | ex_guidance_outs = guidance_model.calc_loss(
83 | perturbed_probs,
84 | query_entity,
85 | constraint,
86 | mode='ex',
87 | threshold=args.g_threshold,
88 | )
89 | ex_loss = ex_guidance_outs['loss']
90 |
91 | # NOTE: the sign of the loss is assigned here
92 | ex_grad = get_gradient(
93 | ex_loss,
94 | curr_history_delta,
95 | args.alpha,
96 | args,
97 | retain_graph=True,
98 | )
99 |
100 | in_loss = 0.0
101 | in_grad = None
102 | in_guidance_outs = None
103 | if 'in' in args.guidance:
104 | in_guidance_outs = in_guidance_model.calc_loss(
105 | perturbed_probs,
106 | query_entity,
107 | topic,
108 | mode='in',
109 | threshold=args.g_threshold,
110 | )
111 | in_loss = in_guidance_outs['loss']
112 |
113 | # NOTE: the sign of the loss is assigned here
114 | in_grad = get_gradient(
115 | -in_loss,
116 | curr_history_delta,
117 | args.beta,
118 | args,
119 | retain_graph=False,
120 | )
121 |
122 |         grad = None
123 |         if ex_grad is not None:
124 |             grad = ex_grad
125 |         if in_grad is not None:
126 |             grad = in_grad if grad is None else add_key_values(in_grad, ex_grad)
127 |
128 |         if grad is not None:
            history_delta = add_key_values(history_delta, grad)
129 | for key, value in curr_history_delta:
130 | if key.grad is not None:
131 | key.grad.zero_()
132 | if value.grad is not None:
133 | value.grad.zero_()
134 |
135 | refinement.append(dict(
136 | last_token=tokenizer.decode(
137 | last_token.item(),
138 | skip_special_tokens=True
139 | ).strip(' ').replace('\n', '\\n'),
140 | next_token_after_refinement_text=next_token_after_refinement_text,
141 | query_entity=query_entity,
142 | query_tags=query_tags,
143 | parsed=parsed,
144 | ex_yes_prob=(
145 | ex_guidance_outs['yes_prob']
146 | if ex_guidance_outs is not None else -1.0
147 | ),
148 | ex_no_prob=(
149 | ex_guidance_outs['no_prob']
150 | if ex_guidance_outs is not None else -1.0
151 | ),
152 | in_yes_prob=(
153 | in_guidance_outs['yes_prob']
154 | if in_guidance_outs is not None else -1.0
155 | ),
156 | in_no_prob=(
157 | in_guidance_outs['no_prob']
158 | if in_guidance_outs is not None else -1.0
159 | ),
160 | ))
161 | perturbed_history = add_key_values(history, history_delta)
162 | final_outputs = model(last_token, past_key_values=perturbed_history)
163 |
164 | final_logits = final_outputs.logits[:, -1, :]
165 | return final_logits, refinement
166 |
167 |
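# Weighted decoding: the guidance model returns vocabulary indices to
# penalize (constraint-related tokens, logits -= alpha) or to boost
# (topic-related tokens, logits += beta) before the next token is sampled.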
168 | def run_wd_step(
169 | example_id,
170 | args,
171 | model,
172 | guidance_model,
173 | in_guidance_model,
174 | tokenizer,
175 | history,
176 | last_token,
177 | curr_context,
178 | topic,
179 | constraint,
180 | datapoint,
181 | ):
182 | stepwise_info = []
183 |
184 | orig_outputs = model(last_token, past_key_values=history)
185 | orig_logits = orig_outputs.logits[:, -1, :]
186 | orig_probs = F.softmax(orig_logits, dim=-1)
187 |
188 | # Generate multiple tokens to query the guidance model.
189 | if not args.max_multistep_len:
190 | query_entity = None
191 | else:
192 | next_token_ = last_token.clone().detach()
193 | history_ = deepcopy_history(history)
194 | multistep_tokens = []
195 | for _ in range(args.max_multistep_len):
196 | outputs_ = model(next_token_, past_key_values=history_)
197 | logits_ = outputs_.logits[:, -1, :]
198 | probs_ = F.softmax(logits_, dim=-1)
199 | history_ = outputs_.past_key_values
200 | next_token_ = torch.argmax(probs_, dim=1).unsqueeze(0)
201 | multistep_tokens.append(next_token_.item())
202 |
203 | # Use the list of probs/histories of the entity tokens for further use.
204 | query_entity, query_tags, parsed = get_query_entity(multistep_tokens, tokenizer)
205 | if query_entity is None:
206 | return orig_logits, []
207 |
208 | wd_guidance_args = dict(
209 | example_id=example_id,
210 | orig_probs=orig_probs,
211 | query_entity=query_entity,
212 | threshold=args.g_threshold,
213 | curr_context=curr_context,
214 | last_token=last_token,
215 | datapoint=datapoint,
216 | )
217 |
218 | ex_inds, _, ex_info = guidance_model.query_guidance(
219 | constraint,
220 | mode='ex',
221 | **wd_guidance_args,
222 | )
223 | orig_logits[:, ex_inds] = orig_logits[:, ex_inds] - args.alpha
224 |
225 | if 'in' in args.guidance:
226 | in_inds, _, in_info = \
227 | in_guidance_model.query_guidance(topic, mode='in', **wd_guidance_args)
228 | orig_logits[:, in_inds] = orig_logits[:, in_inds] + args.beta
229 | else:
230 | in_info = dict()
231 |
232 | stepwise_info.append(dict(
233 | exclusion_info=ex_info,
234 | inclusion_info=in_info,
235 | ))
236 | final_logits = orig_logits
237 | return final_logits, stepwise_info
238 |
239 |
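# Hard constrained decoding: every sub-token of every forbidden word is
# masked to -inf so that constraint-related tokens can never be sampled.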
240 | def run_constrained_decoding(
241 | example_id,
242 | args,
243 | model,
244 | guidance_model,
245 | in_guidance_model,
246 | tokenizer,
247 | history,
248 | last_token,
249 | curr_context,
250 | topic,
251 | constraint,
252 | datapoint,
253 | ):
254 | orig_outputs = model(last_token, past_key_values=history)
255 | orig_logits = orig_outputs.logits[:, -1, :]
256 |
257 | if 'ex' in args.guidance:
258 | words = guidance_model.hierarchy[constraint]
259 | if args.data == 'wordnet':
260 | words = words + [constraint]
261 | elif args.data == 'wikidata':
262 | raise ValueError('Not finished yet. To be completed.')
263 | else:
264 | raise ValueError(f'Data arg {args.data} not recognized.')
265 |
266 | bow_indices = [
267 | tokenizer.encode(word.strip(), add_prefix_space=True)
268 | for word in words
269 | ]
270 | for inds in bow_indices:
271 | orig_logits[:, inds] = float('-inf')
272 |
273 |     return orig_logits, []
274 |
--------------------------------------------------------------------------------
/src/guide_with_offset.py:
--------------------------------------------------------------------------------
1 | """
2 | Offset Guidance.
3 |
4 | Guide generation using the combination of three probabilities
5 | 1. p_1 = p(x_t | x_{<t})
--------------------------------------------------------------------------------
/src/lm_scorer.py:
--------------------------------------------------------------------------------
85 |         if len(generated_text) > 0:
86 | generated_text = generated_text[0]
87 | else:
88 | generated_text = ''
89 |
90 | if generated_text == '':
91 | print('Empty text.')
92 | continue
93 |
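        # Perplexity is exp of the mean token-level cross-entropy, which HF
        # causal LMs return as `loss` when labels == input_ids.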
94 | input_ids = lm_scorer.tokenizer(generated_text, return_tensors="pt").input_ids.cuda()
96 | outputs = lm_scorer.model(input_ids, labels=input_ids)
97 | ppl = torch.exp(outputs.loss)
98 | if torch.isnan(ppl) or ppl > 500:
99 |             print('NaN or extreme PPL; skipping.')
100 | continue
101 | ppls.append((prediction_id, ppl.item()))
102 | print(f'avg: {sum(ppl for _, ppl in ppls) / len(ppls)} (n={len(ppls)})')
103 |
104 | save_path = path.parent / 'ppl.txt'
105 | with open(save_path, 'w') as f:
106 | for prediction_id, ppl in ppls:
107 | f.write(f'{prediction_id},{ppl}\n')
108 | f.write(f'avg,{sum(ppl for _, ppl in ppls) / len(ppls)}\n')
109 | print('Saved at:', save_path)
110 | print('----------------------------------')
111 |
112 |
113 | @torch.no_grad()
114 | def try_compute_ppl(lm_scorer):
115 | texts = [
116 | 'There is a dog hiding behind the tree.',
117 | 'There is a dog is a that hiding behind the tree.',
118 | 'The dog is hiding behind the tree.',
119 | 'The dog is hiding in cat of the movie tree.',
120 | 'this is a dog this is a dog this is a dog this is a dog',
121 | 'dog dog dog dog dog dog dog dog dog',
122 | ]
123 | for text in texts:
124 | print('text:')
125 | print(text)
126 | input_ids = lm_scorer.tokenizer(text, return_tensors="pt").input_ids.cuda()
127 | print([lm_scorer.tokenizer.decode(tok) for tok in input_ids[0]])
128 | outputs = lm_scorer.model(input_ids, labels=input_ids)
129 | print(outputs.loss)
130 |
131 | ppl = torch.exp(outputs.loss).item()
132 | print(f'PPL: {ppl:.4f}')
133 | ppl = lm_scorer.sentence_score(text)
134 | print(f'PPL 2: {ppl:.4f}')
135 | print('---')
136 |
137 |
138 | if __name__ == '__main__':
139 | """
140 |     python -m src.lm_scorer
141 | """
142 | model_name = 'gpt2-xl'
143 | #model_name = 'EleutherAI/gpt-j-6B'
144 |
145 | paths = [
146 | # wordnet
147 | #'./runs/wd/wordnet/wordnet_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=discrete_trie_i2g_evalv=test/seed_0/predictions.jsonl',
148 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1_split=test/wordnet/predictions.jsonl',
149 | #'./ci/baselines/runs/gpt_engine="text-davinci-002"_temp=0.9_top_p=0.95_eval=-1_split=test/wordnet/predictions.jsonl',
150 | #'./runs/wd/wordnet/wordnet_nl+ex+in_gen=gpt2-xl_guide=none_a=100.0_b=5.0_in=none_ex=none_evalv=dev/seed_0/predictions.jsonl',
151 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1/wordnet/predictions.jsonl',
152 | #'./ci/baselines/runs/gpt3-legacy/wordnet/predictions.jsonl',
153 | #'./ci/baselines/runs/gpt3-extra-prompt/wordnet/predictions.jsonl',
154 | #'./runs/wd/wordnet/wordnet_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=discrete_trie_evalv=dev/seed_0/predictions.jsonl',
155 | #'./runs/wd/wordnet/selfdebias_a=0.5_b=0.5_pmt=expof/seed_0/predictions.jsonl',
156 | #'./ci/baselines/runs/ctrl/wordnet/debug/ctrl_eval/seed_0/predictions.jsonl',
157 |
158 | # wikidata
159 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=10.0_in=discrete_ex=discrete_i2g_evalv=test/seed_0/predictions.jsonl',
160 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1_split=test/wikidata/predictions.jsonl',
161 | #'./ci/baselines/runs/gpt_engine="text-davinci-002"_temp=0.9_top_p=0.95_eval=-1_split=test/wikidata/predictions.jsonl',
162 | #'./runs/wd/wikidata/wikidata_nl+ex+in_gen=gpt2-xl_guide=none_a=100.0_b=10.0_in=none_ex=none_evalv=dev/seed_0/predictions.jsonl',
163 | #'./ci/baselines/runs/gpt_engine="davinci"_temp=0.9_top_p=0.95_eval=-1/wikidata/predictions.jsonl',
164 | #'./ci/baselines/runs/gpt3-legacy/wikidata/predictions.jsonl',
165 | #'./ci/baselines/runs/gpt3/wikidata/predictions.jsonl',
166 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=oracle_a=100.0_b=10.0_in=none_ex=oracle_evalv=-2/seed_0/predictions.jsonl',
167 | #'./runs/wd/wikidata/_right_results_bad_pred_files/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=binary/seed_0/predictions.jsonl',
168 | #'./runs/wd/wikidata/_right_results_bad_pred_files/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=5.0_in=discrete_ex=full_extopk=20/seed_0/predictions.jsonl',
169 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=10.0_in=discrete_ex=discrete_beam=8_topp=0.92_return=8_impprompt/seed_0/predictions.jsonl',
170 | #'./runs/wd/wikidata/wikidata_wd+ex+in_gen=gpt2-xl_guide=gpt2-xl_a=100.0_b=10.0_in=discrete_ex=discrete_i2g_evalv=dev/seed_0/predictions.jsonl',
171 | #'./runs/wd/wikidata/selfdebias_a=0.5_b=0.5_pmt=expof/predictions.jsonl',
172 | './ci/baselines/runs/ctrl/wikidata/debug/ctrl_eval/seed_0/predictions.jsonl',
173 | ]
174 |
175 | lm_scorer = LMScorer(model_name=model_name)
176 | #try_compute_ppl(lm_scorer)
177 | for path in paths:
178 | print(path)
179 | compute_ppl(path, lm_scorer)
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import yaml
4 | import json
5 | import time
6 | import random
7 | import logging
8 | import argparse
9 | from argparse import Namespace
10 | from pathlib import Path
11 |
12 | import torch
13 | import torch.nn.functional as F
14 |
15 | from rich import print
from rich.logging import RichHandler
16 | from alive_progress import alive_bar
17 |
18 | from src.utils import (
19 | get_data,
20 | get_lm,
21 | get_tokenizer,
22 | cleanup_gen_text,
23 | reformat_text,
24 | top_p_sampling,
25 | )
26 | from src.diverse_instructions import prepare_context
27 | from src.guide import (
28 | run_pplm_step,
29 | run_wd_step,
30 | run_constrained_decoding,
31 | )
32 | from src.guide_with_offset import run_offset_step
33 | from src.guidance_utils import add_key_values, get_guidance_model
34 | from src.metrics import compute_prediction_metrics, aggregate_metrics
35 | from src.lm_scorer import LMScorer
36 |
37 | from src.experiment import ExperimentManager
38 | from src.metric_tracker import MetricTracker
39 |
40 | random.seed(0)
41 |
42 |
43 | def run_generation(
44 | datapoint,
45 | hierarchy,
46 | model,
47 | tokenizer,
48 | args,
49 | guidance_model=None,
50 | in_guidance_model=None
51 | ):
52 | if args.eval_version == -2:
53 | context = datapoint['context']
54 | else:
55 | context = datapoint['context_with_instructions']
56 | #context = datapoint['context']
57 | context = tokenizer.encode(context, return_tensors='pt').cuda()
58 | start_context_length = context.size(1)
59 | topic = datapoint['topic']
60 | constraint = datapoint['constraint']
61 | example_id = datapoint['id']
62 | if in_guidance_model is None:
63 | in_guidance_model = guidance_model
64 |
65 | gen_step = 0
66 | context_length = 0
67 | generated_tokens = []
68 | end = []
69 | refinements = []
70 |
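    # Generate one token at a time until the length budget is reached or the
    # last two decoded tokens are ' ==' followed by '\n', the sentence
    # delimiter used by the data format.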
71 | while (
72 | context_length < args.max_gen_length + start_context_length and
73 | end != [' ==', '\n']
74 | ):
75 | last_token = context[:, -1:]
76 | curr_context = context[:, :-1]
77 |
78 | outputs = model(curr_context)
79 | history = outputs.past_key_values
80 |
81 | gen_state = dict(
82 | example_id=example_id,
83 | args=args,
84 | model=model,
85 | guidance_model=guidance_model,
86 | in_guidance_model=in_guidance_model,
87 | tokenizer=tokenizer,
88 | history=history,
89 | last_token=last_token,
90 | curr_context=curr_context,
91 | topic=topic,
92 | constraint=constraint,
93 | datapoint=datapoint,
94 | )
95 | if 'pplm' in args.guidance:
96 | final_logits, refinement = run_pplm_step(**gen_state)
97 | final_probs = F.softmax(final_logits, dim=-1)
98 | elif 'wd' in args.guidance:
99 | final_logits, refinement = run_wd_step(**gen_state)
100 | final_probs = F.softmax(final_logits, dim=-1)
101 | elif 'cd' in args.guidance:
102 | final_logits, refinement = run_constrained_decoding(**gen_state)
103 | final_probs = F.softmax(final_logits, dim=-1)
104 | elif args.guidance == 'none' or 'nl' in args.guidance:
105 | final_outputs = model(last_token, past_key_values=history)
106 | final_logits = final_outputs.logits[:, -1, :]
107 | final_probs = F.softmax(final_logits, dim=-1)
108 | refinement = []
109 | elif 'os' in args.guidance:
110 | final_probs, refinement = run_offset_step(**gen_state)
111 | else:
112 | raise ValueError(f'Guidance type `{args.guidance}` not supported.')
113 | refinements.append(refinement)
114 |
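        # Probability fusion: final_probs^gamma * orig_probs^(1 - gamma);
        # gamma == 1.0 keeps the guided distribution unchanged, so the extra
        # forward pass is skipped in that case.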
115 |         if args.fusion_gamma is not None and args.fusion_gamma != 1.0:
116 | orig_outputs = model(last_token, past_key_values=history)
117 | orig_logits = orig_outputs.logits[:, -1, :]
118 | orig_probs = F.softmax(orig_logits, dim=-1)
119 | final_probs = final_probs.pow(args.fusion_gamma) * \
120 | orig_probs.pow(1.0 - args.fusion_gamma)
121 |
122 | if args.top_p is not None:
123 | final_logits = top_p_sampling(final_logits, args.top_p)
124 | final_probs = F.softmax(final_logits / args.temperature, dim=-1)
125 |             next_token = torch.multinomial(final_probs, num_samples=1).squeeze(-1)  # shape (1,), matching the argmax branch
126 | else:
127 | next_token = torch.argmax(final_probs, dim=1)
128 | next_token_text = tokenizer.decode(
129 | next_token,
130 | skip_special_tokens=True
131 | )
132 |
133 | context = torch.cat([context, next_token[:, None]], dim=-1)
134 | context_length = context.size(1)
135 |
136 | generated_tokens.append(next_token)
137 | gen_step += 1
138 |
139 | end.append(next_token_text)
140 | end = end[-2:]
141 | return generated_tokens, refinements
142 |
143 |
144 | def setup_logger(args):
145 | handlers = []
146 | run_dir = Path(args.run_dir) / args.name
147 | run_dir.mkdir(parents=True, exist_ok=True)
148 | prediction_path = run_dir / 'predictions.jsonl'
149 | config_path = run_dir / 'config.yaml'
150 |
151 | if args.override == 'manual':
152 | ans = input(
153 |             f'The following files will be overridden:\n'
154 | f'`{prediction_path}`\n'
155 | f'`{config_path}`\n'
156 | f'Proceed? [yes/no]: '
157 | )
158 | if ans.lower() == 'yes':
159 | handlers.append(logging.FileHandler(prediction_path, 'w'))
160 | with open(config_path, 'w') as f:
161 | yaml.dump(vars(args), f)
162 | else:
163 | sys.exit(0)
164 | elif args.override == 'auto':
165 | handlers.append(logging.FileHandler(prediction_path, 'w'))
166 | with open(config_path, 'w') as f:
167 | yaml.dump(vars(args), f)
168 | elif args.override == 'no':
169 | handlers.append(logging.FileHandler(prediction_path))
170 | with open(config_path, 'w') as f:
171 | yaml.dump(vars(args), f)
172 | else:
173 | raise ValueError(f'`{args.override}` not recognized.')
174 |
175 | if args.log_to_console:
176 | handlers.append(RichHandler())
177 |
178 | logging.basicConfig(
179 | format='%(message)s',
180 | level=logging.INFO,
181 | handlers=handlers
182 | )
183 |
184 |     # Save the run command to a file in `run_dir`.
185 |     run_command = ' '.join(['python -m src.main'] + sys.argv[1:])
186 | with open(run_dir / 'run.sh', 'w') as f:
187 | f.write(run_command)
188 | return run_dir
189 |
190 |
191 | def setup_args(args):
192 | if args.guidance_model_name == 'oracle':
193 | args.max_multistep_len = 0
194 | if args.guidance != 'none':
195 | args.guidance = args.guidance.split('+')
196 | return args
197 |
198 |
199 | def run(args, run_id, exp_manager=None):
200 | #run_dir = setup_logger(args)
201 | args = setup_args(args)
202 |
203 | vs = []
204 | os = []
205 | onvs = []
206 |
207 | num_devices = torch.cuda.device_count()
208 |
209 | datasets, hierarchy, gold = get_data(args)
210 | model = get_lm(args, num_devices)
211 | tokenizer = get_tokenizer(args)
212 |
213 | # Load the guidance model.
214 | if args.guidance_model_type != 'none':
215 | guidance_model = get_guidance_model(
216 | args,
217 | tokenizer,
218 | hierarchy,
219 | num_devices=num_devices
220 | )
221 | else:
222 | guidance_model = None
223 |
224 | if (args.guidance_model_type_2 != args.guidance_model_type and
225 | args.guidance_model_type_2 != 'none'):
226 | in_guidance_model = get_guidance_model(
227 | args,
228 | tokenizer,
229 | hierarchy,
230 | guidance_lm=guidance_model.guidance_lm
231 | )
232 | else:
233 | in_guidance_model = None
234 |
235 |     lm_scorer = LMScorer(model=model, tokenizer=tokenizer)
236 |
237 | # Main loop.
238 | predictions = []
239 | prediction_metrics = []
240 | total = len(datasets[args.dataset_split])
241 |
242 | with alive_bar(total, enrich_print=False) as bar:
243 | for idx, datapoint in enumerate(datasets[args.dataset_split]):
244 | if len(predictions) == args.num_datapoints:
245 | break
246 |
247 | topic = datapoint['topic']
248 | constraint = datapoint['constraint']
249 |
250 | if args.eval_version == -1 and args.dataset_split == 'dev':
251 | eval_version = random.choice(range(3, 6))
252 | elif args.eval_version == -1 and args.dataset_split == 'test':
253 | eval_version = random.choice(range(6, 35))
254 | else:
255 | eval_version = args.eval_version
256 |
257 | datapoint = prepare_context(datapoint, args, version=eval_version)
258 |
259 | # NOTE: The main logic starts here.
260 | gen_ids, refinements = run_generation(
261 | datapoint,
262 | hierarchy,
263 | model,
264 | tokenizer,
265 | args,
266 | guidance_model,
267 | in_guidance_model=in_guidance_model,
268 | )
269 |
270 | generated_text = tokenizer.decode(
271 | torch.cat(gen_ids),
272 | skip_special_tokens=True
273 | )
274 | generated_text = cleanup_gen_text(generated_text)
275 | prediction = dict(
276 | datapoint=datapoint,
277 | guidance=(
278 | '+'.join(args.guidance)
279 | if args.guidance != 'none' else args.guidance
280 | ),
281 | generated_text=generated_text,
282 | refinements=refinements,
283 | generated_tokens=[tokenizer.decode(token) for token in gen_ids],
284 | )
285 | predictions.append(prediction)
286 |
287 | prediction_metric_outputs = compute_prediction_metrics(
288 | prediction,
289 | hierarchy,
290 | data_mode=args.data,
291 | )
292 | prediction_metric = prediction_metric_outputs['prediction_metric']
293 | prediction_metric['id'] = datapoint['id']
294 | extracted = prediction_metric_outputs['extracted']
295 | if generated_text:
296 |                 ppl = lm_scorer.sentence_score(generated_text)
297 | else:
298 | ppl = 0.0
299 |
300 | prediction_metric.update({'ppl': ppl})
301 | prediction_metrics.append(prediction_metric)
302 |
303 | exp_manager.take(
304 | {**prediction, **prediction_metric},
305 | log=True,
306 | console=False,
307 | run_id=run_id,
308 | add_metric=True,
309 | )
310 |
311 | v = prediction_metric['violated']
312 | o = prediction_metric['on_topic']
313 | onv = prediction_metric['on_topic_not_violated']
314 | vs.append(v)
315 | os.append(o)
316 | onvs.append(onv)
317 | print(f'name: {args.name} (num={len(predictions)})')
318 | print(args)
319 | print(f'eval_version: {eval_version}')
320 | print(f'topic: {topic}')
321 | print(datapoint['context'])
322 | print(f'constraint: {constraint}')
323 | print(f'generated_text:')
324 | print(generated_text)
325 | print(f'Extracted: {extracted}')
326 | print(f'v: {v}')
327 | print(f'o: {o}')
328 | print(f'onv: {onv}')
329 | print(f'ppl: {ppl:.4f}')
330 | print(f'Accumulated violated: {sum(vs) / len(vs):.4f}')
331 | print(f'Accumulated on_topic: {sum(os) / len(os):.4f}')
332 | print(f'Accumulated on_topic_not_violated: {sum(onvs) / len(onvs):.4f}')
333 | print('---')
334 |
335 | time.sleep(0.005)
336 | bar()
337 |
338 | stats = exp_manager.aggregate_metrics()
339 | exp_manager.save_results(stats)
340 | print(stats)
341 | print(args)
342 | print('Run finished.')
343 |
344 |
345 | def get_argparse(config=None):
346 | parser = argparse.ArgumentParser()
347 | parser.add_argument("--name", type=str, default=None)
348 | parser.add_argument("--run_dir", type=str, default="./runs")
349 | parser.add_argument("--debug", action="store_true")
350 | parser.add_argument("--override", type=str, default="manual",
351 | choices=['auto', 'manual', 'no'],
352 | help=(
353 | "`auto`: override all, "
354 | "`manual`: ask in terminal before override, "
355 | "`no`: no override."
356 | )
357 | )
358 | parser.add_argument("--data", type=str, default='wordnet',
359 | choices=['wordnet', 'wikidata'],
360 | )
361 |
362 | # Data arguments.
363 | parser.add_argument("--train_path", type=str, default=None)
364 | parser.add_argument("--dev_path", type=str, default=None)
365 | parser.add_argument("--test_path", type=str, default=None)
366 | parser.add_argument("--hierarchy_path", type=str, default=None)
367 | parser.add_argument("--wiki_gold", type=str, default=None)
368 | parser.add_argument("--num_datapoints", type=int, default=500)
369 | parser.add_argument("--dataset_split", type=str, default='dev')
370 | parser.add_argument("--eval_version", type=int, default=0,
371 | help=(
372 | "-1: use dataset_split to decide which version to use.\n"
373 | "-2: use only the context."
374 | )
375 | )
376 |
377 | # Generation model arguments.
378 | parser.add_argument("--model_name", type=str, default="gpt2-xl")
379 | parser.add_argument("--top_p", type=float, default=None)
380 | parser.add_argument("--temperature", type=float, default=1.0)
381 | parser.add_argument("--guidance", type=str, default='none',
382 | choices=[
383 | 'none', # No constraint
384 | 'pplm+ex', 'pplm+ex+in', # PPLM
385 | 'wd+ex', 'wd+ex+in', # Weighted decoding
386 | 'cd+ex', # Constrained decoding
387 | 'nl+ex', 'nl+ex+in', # Natural language constraint
388 | 'os+ex+in', # Offset guidance
389 | ],
390 | )
391 | parser.add_argument("--max_gen_length", type=int, default=60)
392 | parser.add_argument("--alpha", type=float, default=None) # 0.02 / 0.1
393 | parser.add_argument("--beta", type=float, default=None) # 0.05
394 | parser.add_argument("--gamma", type=float, default=None) # 1.5
395 | parser.add_argument("--fusion_gamma", type=float, default=1.0) # 1.0
396 | parser.add_argument("--refinement_steps", type=int, default=0) # 3
397 | parser.add_argument("--prev_run_dir", type=str, default=None)
398 | parser.add_argument("--log_to_console", action="store_true")
399 |
400 | # Guidance model.
401 | parser.add_argument("--guidance_model_name", type=str, default='none',
402 | choices=[
403 | 'none',
404 | 'oracle',
405 | 'gpt2', 'gpt2-xl', 'gpt2-ft',
406 | 'EleutherAI/gpt-j-6B',
407 | ],
408 |         help="Guidance model. Default: 'none'."
409 | )
410 | parser.add_argument("--guidance_model_type", type=str, default='none',
411 | choices=['none', 'full', 'binary', 'discrete'],
412 | help="Guidance model type."
413 | )
414 | parser.add_argument("--guidance_model_type_2", type=str, default='none',
415 | choices=['none', 'full', 'binary', 'discrete'],
416 | help="Second guidance model (will only be used for INCLUSION if specified)."
417 | )
418 | parser.add_argument("--guidance_model_path", type=str, default="",
419 | help="Load checkpointed guidance model."
420 | )
421 | parser.add_argument("--num_icl_pairs", type=int, default=0, help="Default: 0")
422 | parser.add_argument("--g_threshold", type=float, default=0.5, help="Default: 0.5")
423 |     parser.add_argument("--max_multistep_len", type=int, default=0, help="Default: 0 (recommended: 8).")
424 |     parser.add_argument("--full_guide_topk_in", type=int, default=0, help="Default: 0 (recommended: 40).")
425 |     parser.add_argument("--full_guide_topk_ex", type=int, default=0, help="Default: 0 (recommended: 20).")
426 | parser.add_argument("--discrete_max_length", type=int, default=100,
427 | help="Default: 100. The number of tokens generated for discrete guidance."
428 | )
429 | parser.add_argument("--discrete_guidance_num_beams", type=int, default=1)
430 | parser.add_argument("--discrete_guidance_num_beam_groups", type=int, default=1)
431 |     parser.add_argument("--discrete_guidance_do_sample", action="store_true")  # boolean flag; pass the option to enable sampling
432 | parser.add_argument("--discrete_guidance_top_k", type=int, default=None)
433 | parser.add_argument("--discrete_guidance_top_p", type=float, default=None)
434 | parser.add_argument("--discrete_guidance_temperature", type=float, default=None)
435 | parser.add_argument(
436 | "--discrete_guidance_num_return_sequences", type=int, default=None
437 | )
438 | parser.add_argument("--discrete_guidance_diversity_penalty", type=float, default=None)
439 | parser.add_argument(
440 | "--discrete_guidance_instruct2guide_model_dir", type=str, default=None,
441 | help="Directory of tuned prefixes for topic and constraint."
442 | )
443 | parser.add_argument("--discrete_guidance_use_trie", action="store_true")
444 |
445 | if config is not None and isinstance(config, dict):
446 | parser.set_defaults(**config)
447 | args = parser.parse_args()
448 | return args
449 |
450 |
451 | if __name__ == '__main__':
452 | """
453 | python -m src.main --name RUN_NAME
454 | """
455 | args = get_argparse()
456 |
457 | exp_manager = ExperimentManager(
458 | name=args.name,
459 | run_dir=args.run_dir,
460 | override=True,
461 | num_runs=1,
462 | metric_tracker=MetricTracker(),
463 | )
464 | run(args, run_id=0, exp_manager=exp_manager)
465 |
466 |
--------------------------------------------------------------------------------
/src/metric_tracker.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import numpy as np
3 |
4 |
5 | class MetricTracker:
6 | def __init__(self):
7 | self._state = defaultdict(lambda: defaultdict(list))
8 |
9 | @property
10 | def state(self):
11 | return {k: dict(v) for k, v in self._state.items()}
12 |
13 | def add(self, metric: dict, run_id: int):
14 | invalid_value_types = {str, list, dict}
15 | for name, value in metric.items():
16 | if type(value) in invalid_value_types:
17 | continue
18 | self._state[run_id][name].append(value)
19 |
20 | def compute(self, metric_name=None, run_id=None, ignore=('id', 'global_step')):
21 | """
22 | `metric_name` and `run_id` are expected to be passed in simultaneously,
23 | or they should both be `None`.
24 | """
25 | stats = defaultdict(dict)
26 | for run_id_, metrics in self._state.items():
27 | for name, values in metrics.items():
28 | if metric_name is not None and name != metric_name:
29 | continue
30 | if name in ignore:
31 | continue
32 | avg = sum(values) / len(values)
33 | stats[run_id_][name] = avg
34 | stats[run_id_][name + '_num_datapoints'] = len(values)
35 | stats = dict(stats)
36 |
37 | if run_id is not None and metric_name is not None:
38 | stats = stats[run_id][metric_name]
39 | return stats
40 |
41 | def aggregate(self, avg_runs=True):
42 | per_run_avg_stats = self.compute()
43 |
44 | if not avg_runs:
45 | return per_run_avg_stats
46 | else:
47 | run_avg_stats = defaultdict(list)
48 | metrics = list(per_run_avg_stats.values())
49 | for metric in metrics:
50 | for name, value in metric.items():
51 | run_avg_stats[name].append(value)
52 |
53 | # Calculate mean/std.
54 | run_avg_stats = {
55 | k: {
56 | 'mean': float(np.mean(vs)),
57 | 'std': float(np.std(vs)),
58 | 'num_runs': len(vs),
59 | 'num_datapoints': run_avg_stats[k + '_num_datapoints'],
60 | }
61 | for k, vs in run_avg_stats.items() if not k.endswith('_num_datapoints')
62 | }
63 | return run_avg_stats
64 |
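# Minimal usage sketch (metric names and values are illustrative):
#   tracker = MetricTracker()
#   tracker.add({'violated': 0, 'on_topic': 1}, run_id=0)
#   tracker.add({'violated': 1, 'on_topic': 1}, run_id=0)
#   tracker.compute()    # {0: {'violated': 0.5, 'on_topic': 1.0, ...}}
#   tracker.aggregate()  # {'violated': {'mean': 0.5, 'std': 0.0, ...}, ...}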
--------------------------------------------------------------------------------
/src/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Metrics.
3 | """
4 | import re
5 | from collections import defaultdict
6 |
7 | import spacy
8 | from nltk import ngrams
9 | from nltk.tokenize import word_tokenize
10 | from nltk.translate.bleu_score import sentence_bleu
11 |
12 | from src.guidance_utils import remove_parenthesis
13 |
14 | spacy_parser = spacy.load("en_core_web_lg")
15 |
16 |
17 | def aggregate_metrics(metrics, name):
18 | """
19 | Aggregate metrics for one run.
20 | """
21 | run_stats = defaultdict(list)
22 | run_stats['name'] = name
23 |
24 | for metric in metrics:
25 | for k, v in metric.items():
26 | run_stats[k].append(v)
27 | num_datapoints = []
28 | for k, vs in run_stats.items():
29 | if isinstance(vs, list):
30 | run_stats[k] = sum(vs) / len(vs)
31 | num_datapoints.append(len(vs))
32 | # assert not num_datapoints or len(set(num_datapoints)) == 1, num_datapoints
33 | run_stats['num_datapoints'] = num_datapoints[0] if num_datapoints else 0
34 |
35 | return dict(run_stats)
36 |
37 |
38 | def compute_prediction_metrics(prediction, hierarchy, data_mode):
39 | datapoint = prediction['datapoint']
40 | context = datapoint['context']
41 | topic = datapoint['topic']
42 | constraint = datapoint['constraint']
43 | generated_text = prediction['generated_text']
44 |
45 | if data_mode == 'wordnet':
46 | generated_text = generated_text.lower()
47 | forbidden_words = hierarchy[constraint] + [constraint]
48 | forbidden_words = set(
49 | forbidden_words + [w + 's' for w in forbidden_words]
50 | )
51 |
52 | topical_words = hierarchy[topic]
53 | topical_words = set(topical_words + [w + 's' for w in topical_words])
54 |
55 | violated = any(forbidden_word in generated_text for forbidden_word in forbidden_words)
56 | on_topic = any(topical_word in generated_text for topical_word in topical_words)
57 |
58 |         topical_word_regex = '|'.join(re.escape(w) for w in topical_words)  # escape in case words contain regex metacharacters
59 | pattern = re.compile(rf'({topical_word_regex})')
60 | topical_word_matches = pattern.findall(generated_text)
61 |
62 | #if len(set(topical_word_matches) & set(prediction['datapoint']['current'])) > 0:
63 | # on_topic = False
64 | #print(prediction['datapoint']['current'])
65 | #print(topical_word_matches)
66 | #print(topic)
67 | #print(constraint)
68 | #print('---')
69 |
70 | extracted = None
71 | elif data_mode == 'wikidata':
72 |
73 | gen_text_parsed = spacy_parser(generated_text)
74 | parsed_names = set([
75 | ent.text.lower() for ent in gen_text_parsed.ents
76 | if ent.label_ == 'PERSON'
77 | ])
78 | forbidden_words = hierarchy[constraint]
79 | forbidden_words = [remove_parenthesis(q_title) for q_id, q_title in forbidden_words]
80 | violated = False
81 | violated_word = None
82 | for w in forbidden_words:
83 | if w in generated_text:
84 | violated = True
85 | violated_word = w
86 | break
87 | #violated = len(parsed_names & forbidden_words) > 0
88 |
89 | topical_words = hierarchy[topic]
90 | topical_words = [remove_parenthesis(q_title) for q_id, q_title in topical_words]
91 | on_topic = False
92 | on_topic_word = None
93 | for w in topical_words:
94 | if w in generated_text:
95 | on_topic = True
96 | on_topic_word = w
97 | break
98 | #on_topic = len(parsed_names & topical_words) > 0
99 | #extracted = parsed_names
100 | extracted = dict(violated_word=violated_word, on_topic_word=on_topic_word)
101 | else:
102 | raise ValueError(f'Data mode {data_mode} not recognized.')
103 | on_topic_not_violated = (not violated) and on_topic
104 | copying_bleu_score = copying_bleu(context, generated_text)
105 | repetition_scores = get_repetition_scores(generated_text.split())
106 | prediction_metric = dict(
107 | violated=violated,
108 | on_topic=on_topic,
109 | on_topic_not_violated=on_topic_not_violated,
110 | copying_bleu_score=copying_bleu_score,
111 | )
112 | prediction_metric = {**prediction_metric, **repetition_scores}
113 |
114 | return dict(
115 | prediction_metric=prediction_metric,
116 | extracted=extracted,
117 | )
118 |
119 |
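# Illustrative WordNet case: with topic 'dog' and constraint 'poodle', a
# generation mentioning 'beagle' is on_topic (a leaf under 'dog') and not
# violated, while mentioning 'poodle' (or any of its leaves) sets violated.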
120 | def copying_bleu(context, generated_text):
121 | prediction = word_tokenize(generated_text.strip('==').strip())
122 | max_score = float('-inf')
123 | for context_sent in context.strip().split('\n'):
124 | references = [word_tokenize(context_sent.strip('==').strip())]
125 | score = sentence_bleu(references, prediction)
126 | if score > max_score:
127 | max_score = score
128 |     return max_score
129 |
130 |
131 | def get_repetition_scores(tokens):
132 | metric = defaultdict(float)
133 | for n in range(1, 5):
134 | ngs = [ng for ng in ngrams(tokens, n)]
135 | unique_ngs = set(ngs)
136 | if not ngs:
137 | metric[f'seq-rep-{n}'] = 0.0
138 | else:
139 | metric[f'seq-rep-{n}'] = 1.0 - (len(unique_ngs) / len(ngs))
140 | return dict(metric)
141 |
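# Example: tokens = ['dog', 'dog', 'dog'] gives three 1-grams with a single
# unique value, so seq-rep-1 = 1 - 1/3 ≈ 0.67; 0.0 means no repetition.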
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import yaml
3 | import random
4 | from pathlib import Path
5 | from ast import literal_eval
6 | from argparse import Namespace
7 | from collections import defaultdict
8 | from os.path import dirname, abspath, join
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from transformers import (
13 | AutoTokenizer,
14 | AutoModelForCausalLM,
15 | StoppingCriteria,
16 | StoppingCriteriaList,
17 | )
18 |
19 | random.seed(0)
20 |
21 |
22 | def top_p_sampling(logits, top_p):
23 | """
24 | logits: (1, vocab_size)
25 | Code taken from: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
26 | """
27 | logits = logits.squeeze(0)
28 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
29 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
39 | sorted_indices_to_remove = cumulative_probs > top_p
40 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
41 | sorted_indices_to_remove[..., 0] = 0
42 | indices_to_remove = sorted_indices[sorted_indices_to_remove]
43 | logits[indices_to_remove] = float('-inf')
44 | return logits
45 |
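# Example (illustrative numbers): for sorted probs [0.5, 0.3, 0.15, 0.05]
# and top_p=0.9, the cumulative mass first exceeds 0.9 at the third token;
# after the one-position shift, only the last token's logit is set to -inf,
# keeping the nucleus {0, 1, 2}.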
46 |
47 | class EndOfFunctionCriteria(StoppingCriteria):
48 | """
49 | Code taken from: https://github.com/huggingface/transformers/blob/1d651868d64e8f54f7bf6b687fbcdac832039334/examples/research_projects/codeparrot/scripts/human_eval.py#L25
50 | Custom `StoppingCriteria` which checks if all generated functions in
51 | the batch are completed.
52 | """
53 |
54 | def __init__(self, start_length, eof_strings, tokenizer):
55 | self.start_length = start_length
56 | self.eof_strings = eof_strings
57 | self.tokenizer = tokenizer
58 |
59 | def __call__(self, input_ids, scores, **kwargs):
60 | """
61 | Returns true if all generated sequences contain any of the
62 | end-of-function strings.
63 | """
64 | decoded_generations = \
65 | self.tokenizer.batch_decode(input_ids[:, self.start_length :])
66 | done = []
67 | for decoded_generation in decoded_generations:
68 | done.append(any([
69 | stop_string in decoded_generation
70 | for stop_string in self.eof_strings
71 | ]))
72 | return all(done)
73 |
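# Usage sketch with HF `generate` (argument values are illustrative):
#   criteria = StoppingCriteriaList([EndOfFunctionCriteria(
#       start_length=input_ids.size(1),
#       eof_strings=[' ==', '\n'],
#       tokenizer=tokenizer,
#   )])
#   model.generate(input_ids, stopping_criteria=criteria)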
74 |
75 | def to_namespace(config):
76 | return Namespace(**config)
77 |
78 |
79 | def prepare_args(args):
80 | if args.model_type in ('classification', 'prompt-tune', 'fine-tune'):
81 | args.metric_name = 'Accuracy'
82 | elif args.model_type == 'regression':
83 | args.metric_name = 'Class=1 Precision'
84 | else:
85 | raise ValueError(f'Unknown model_type: {args.model_type}')
86 | return args
87 |
88 |
89 | def update_args(args):
90 | """
91 |     Fill in args missing from checkpoints trained with older code.
92 | """
93 | if not hasattr(args, 'loss_type'):
94 | setattr(args, 'loss_type', 'ce')
95 | return args
96 |
97 |
98 | def get_wikidata_p_text(p_name, p_value):
99 | if p_name == 'place_of_birth':
100 | fill_text = f'people who were born in {p_value}'
101 | elif p_name == 'place_of_death':
102 | fill_text = f'people who died in {p_value}'
103 | elif p_name == 'occupation':
104 | fill_text = p_value
105 | elif p_name == 'country_of_citizenship':
106 | fill_text = f'people who are citizens of {p_value}'
107 | elif p_name == 'academic_degree':
108 | fill_text = f'people who hold a degree in {p_value}'
109 | elif p_name == 'educated_at':
110 | fill_text = f'people who had their education at {p_value}'
    else:
        raise ValueError(f'Unknown p_name: {p_name}')
111 |     return fill_text
112 |
113 |
114 | def cleanup_gen_text(text):
115 | """
116 | Cleanup items
117 | - drop the last unfinished sentence (should finish with `==`)
118 |
119 | The normal case: `len(text.split('\n')) == 3`.
120 | """
121 | sents = text.strip().split('\n')
122 | return sents[0]
123 | # print(sents)
124 | # if len(sents) > 2 and not sents[-1].endswith('=='):
125 | # sents = sents[:-1]
126 | # return '\n'.join(sents)
127 |
128 |
129 | def reformat_text(text):
130 | text_tokens = text.split(' ')
131 | text = ' '.join(text_tokens)
132 | return text
133 |
134 |
135 | def get_hierarchy_path_to_children(path):
136 | hierarchy_path_to_children = dict()
137 | with open(path) as f:
138 | for line in f:
139 | obj = json.loads(line.strip())
140 | hierarchy_path = obj['hierarchy_path']
141 | children = obj['children']
142 | hierarchy_path_to_children[tuple(hierarchy_path)] = children
143 | return hierarchy_path_to_children
144 |
145 |
146 | def get_hierarchy(path=None):
147 | if path is None:
148 | path = '/path/to/your/topic_to_leafs.json'
149 | with open(path) as f:
150 | hierarchy_ = json.load(f)
151 | hierarchy = dict()
152 | for topic, leafs in hierarchy_.items():
153 | new_topic = topic.replace('_', ' ')
154 | new_leafs = [l.replace('_', ' ') for l in leafs]
155 | hierarchy[new_topic] = new_leafs
156 | return hierarchy
157 |
158 |
159 | class WikidataHierarchy:
160 | def __init__(self, q_to_p, p_to_q):
161 | self.q_to_p = q_to_p
162 | self.p_to_q = p_to_q
163 |
164 | def __contains__(self, key):
165 | if self.__getitem__(key):
166 | return True
167 | else:
168 | return False
169 |
170 | def __getitem__(self, p):
171 | """
172 | For a given `p` (i.e., topic or constraint), return a list of all Q's
173 | that have this `p`.
174 | Return list like [(Q7339, 'Margot Frank'), ...].
175 | """
176 | if isinstance(p, list):
177 | p = tuple(p)
178 |
179 | return self.p_to_q.get(p, [])
180 |
181 |
182 | def get_wikidata_hierarchy():
183 | from evaluation.scripts.build_wikidata_dataset import (
184 | load_ranked_properties,
185 | load_all_entities
186 | )
187 | WIKIDATA_PATH = Path('path/to/your/qid2title.json')
188 | q_to_p, p_to_q = load_all_entities(WIKIDATA_PATH)
189 | hierarchy = WikidataHierarchy(q_to_p, p_to_q)
190 | return hierarchy
191 |
192 |
193 | def normalize_datapoint(datapoint, args, hierarchy):
194 | if args.data == 'wordnet':
195 | return datapoint
196 | elif args.data == 'wikidata':
197 | normalized = dict()
198 | if 'example_id' in datapoint:
199 | normalized['id'] = datapoint['example_id']
200 | elif 'id' in datapoint:
201 | normalized['id'] = datapoint['id']
202 |
203 | if 'text' in datapoint:
204 | normalized['context'] = datapoint['text']
205 | elif 'context' in datapoint:
206 | normalized['context'] = datapoint['context']
207 |
208 | normalized['topic'] = (
209 | tuple(datapoint['p'])
210 | if 'topic' not in datapoint
211 | else datapoint['topic']
212 | )
213 | normalized['gen_qs'] = datapoint['gen_qs']
214 | normalized['gen_text'] = datapoint['gen_text']
215 |
216 | if 'constraint' in datapoint:
217 | normalized['constraint'] = datapoint['constraint']
218 | else:
219 | constraint_candidates = defaultdict(int)
220 | for gen_q in datapoint['gen_qs']:
221 | gen_ps = hierarchy.q_to_p[tuple(gen_q)]
222 | for p_name, p_values in gen_ps.items():
223 | for p_value in p_values:
224 | constraint_candidates[(p_name, p_value)] += 1
225 |
226 |             SKIP = set()  # p_names to exclude from constraint selection (none for now)
227 | selected_constraint = None
228 | for constraint in constraint_candidates.keys():
229 | if normalized['topic'][0] != constraint[0] and constraint[0] not in SKIP:
230 | selected_constraint = constraint
231 | break
232 |
233 | normalized['constraint'] = selected_constraint
234 | return normalized
235 | else:
236 | raise ValueError(f'Data arg {args.data} not recognized.')
237 |
238 |
239 | def get_data(args, randomize=False, hierarchy_only=False):
240 | if args.data == 'wordnet':
241 | hierarchy = get_hierarchy(args.hierarchy_path)
242 | elif args.data == 'wikidata':
243 | # Return the modifier to update the datapoint.
244 | hierarchy = get_wikidata_hierarchy()
245 | else:
246 | raise ValueError(f'Data arg {args.data} not recognized.')
247 |
248 | if hierarchy_only:
249 | return hierarchy
250 |
251 | datasets = defaultdict(list)
252 | data_paths = {
253 | 'train': args.train_path,
254 | 'dev': args.dev_path,
255 | 'test': args.test_path,
256 | }
257 | for dataset_split, dataset_path in data_paths.items():
258 | if dataset_path is not None:
259 | with open(dataset_path) as f:
260 | for line in f:
261 | datapoint = json.loads(line.strip())
262 | datapoint = normalize_datapoint(datapoint, args, hierarchy)
263 |
264 | topic = datapoint['topic']
265 | constraint = datapoint['constraint']
266 | if (
267 | args.data == 'wordnet' and
268 | (
269 | constraint not in hierarchy[topic] or
270 | constraint == topic or
271 | constraint not in hierarchy
272 | )
273 | ):
274 | # skipping bad data
275 | continue
276 |
277 | if args.data == 'wikidata' and constraint is None:
278 | continue
279 |
280 | datasets[dataset_split].append(datapoint)
281 | if randomize:
282 | random.shuffle(datasets[dataset_split])
283 |
284 | num_datapoints = getattr(args, dataset_split + '_num_datapoints', 100000000)
285 | datasets[dataset_split] = datasets[dataset_split][:num_datapoints]
286 | return datasets, hierarchy, None
287 |
288 |
289 | def get_model_class(model_name):
290 | from ci.train_guidance_model import (
291 | MLPRegressor,
292 | MLPClassifier,
293 | PromptTuningModel,
294 | FineTuningModel,
295 | )
296 | NAME_TO_CLASS = {
297 | 'MLPRegressor': MLPRegressor,
298 | 'MLPClassifier': MLPClassifier,
299 | 'PromptTuningModel': PromptTuningModel,
300 | 'FineTuningModel': FineTuningModel,
301 | }
302 | return NAME_TO_CLASS[model_name]
303 |
304 |
305 | def put_model_on_gpus(model, model_name, num_devices):
306 | if model_name == 'gpt2-xl':
307 | if num_devices == 8:
308 | device_map = {
309 | 0: list(range(0, 6)),
310 | 1: list(range(6, 12)),
311 | 2: list(range(12, 18)),
312 | 3: list(range(18, 24)),
313 | 4: list(range(24, 30)),
314 | 5: list(range(30, 36)),
315 | 6: list(range(36, 42)),
316 | 7: list(range(42, 48)),
317 | }
318 | model.parallelize(device_map)
319 | elif num_devices == 4:
320 | device_map = {
321 | 0: list(range(0, 12)),
322 | 1: list(range(12, 24)),
323 | 2: list(range(24, 36)),
324 | 3: list(range(36, 48)),
325 | }
326 | model.parallelize(device_map)
327 | elif num_devices == 2:
328 | device_map = {
329 | 0: list(range(0, 24)),
330 | 1: list(range(24, 48)),
331 | }
332 | model.parallelize(device_map)
333 | elif num_devices == 1:
334 | model.cuda()
335 | else:
336 | raise ValueError(f'Num devices ({num_devices}) not supported for gpt2-xl.')
337 | else:
338 | model.cuda()
339 | return model
340 |
341 |
342 | def get_lm(args, num_devices=None, load_mode=None):
343 | if load_mode is None:
344 | model_name = args.model_name
345 | elif load_mode == 'guidance':
346 | model_name = args.guidance_model_name
347 | else:
348 | raise ValueError(f'Unknown load_mode: {load_mode}')
349 | model = AutoModelForCausalLM.from_pretrained(
350 | model_name,
351 | local_files_only=True,
352 | )
353 | if num_devices is not None:
354 | model = put_model_on_gpus(model, model_name, num_devices)
355 | for param in model.parameters():
356 | param.requires_grad = False
357 | model.eval()
358 | print(f'Model "{model_name}" loaded.')
359 | return model
360 |
361 |
362 | def get_tokenizer(args):
363 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
364 | tokenizer.pad_token = tokenizer.eos_token
365 | return tokenizer
366 |
367 |
368 | def get_node_children_in_text(text, node, hierarchy):
369 | """
370 | Check if `text` contains `node`'s children in the `hierarchy`.
371 | """
372 | nodes = set(hierarchy[node] + [node])
373 |     return [n for n in nodes if n in text]
374 |
375 |
376 | def save_ckpt(model, args, run_dir, **kwargs):
377 | stats = dict()
378 | stats['current_epoch'] = kwargs.get('epoch', None)
379 | stats['global_step'] = kwargs.get('global_step', None)
380 | stats['best_metric'] = kwargs.get('best_metric', None)
381 |
382 | with open(run_dir / 'stats.yaml', 'w') as f:
383 | yaml.dump(stats, f)
384 |
385 | args.model_class_name = model.__class__.__name__
386 | args.ckpt_path = run_dir / 'checkpoint.pt'
387 | checkpoint = dict()
388 | checkpoint['args'] = vars(args)
389 | checkpoint['states'] = model.state_dict()
390 |
391 | torch.save(checkpoint, args.ckpt_path)
392 | print(f"Model saved at: {args.ckpt_path}")
393 |
394 |
395 | def load_ckpt(load_path):
396 | checkpoint = torch.load(load_path, map_location='cpu')
397 | args = Namespace(**checkpoint['args'])
398 | args = update_args(args)
399 | states = checkpoint['states']
400 | model_class = get_model_class(args.model_class_name)
401 | model = model_class(args)
402 | model.load_state_dict(states)
403 | model.eval()
404 | print('Model loaded from:', load_path)
405 | return model, args
406 |
--------------------------------------------------------------------------------