├── rankgen_model
│   ├── __init__.py
│   └── modeling_rankgen.py
├── rankgen
│   ├── parallel
│   │   ├── parallel_logs
│   │   │   └── .gitkeep
│   │   ├── parallel_schedulers
│   │   │   └── .gitkeep
│   │   ├── parallel_template_gpu.sh
│   │   ├── merge.py
│   │   └── schedule.py
│   ├── __init__.py
│   ├── avg_lengths.py
│   ├── choose_examples.py
│   ├── preprocess_story_cloze.py
│   ├── retriever_score.py
│   ├── ab_tests.py
│   ├── export_results.py
│   ├── ab_tests_generations.py
│   ├── score_ab.py
│   ├── token_overlap_generate.py
│   ├── build_new_ab_split.py
│   ├── shorten_prefix.py
│   ├── shorten_suffix.py
│   ├── plot_divergence_curves.py
│   ├── token_overlap.py
│   ├── test_rankgen_encoder.py
│   ├── human_eval.py
│   ├── rankgen_beam_search.py
│   ├── gpt3_score.py
│   ├── score_multi_beam.py
│   ├── gpt2_generate.py
│   ├── gpt2_generate_contrastive_search.py
│   ├── score_ab_text.py
│   ├── rankgen_generator.py
│   ├── gpt2_generate_contrastive_decoding.py
│   ├── rankgen_encoder.py
│   ├── score_multi_tsv.py
│   ├── rankgen_beam_search_choose_eg.py
│   ├── utils.py
│   ├── gpt2_score.py
│   ├── rankgen_beam_search_shifting.py
│   ├── score_multi.py
│   └── convert_checkpoint.py
├── paper.pdf
├── old
│   └── requirements-bak.txt
├── pyproject.toml
├── .gitignore
├── README.md
└── LICENSE

/rankgen_model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/rankgen/parallel/parallel_logs/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/rankgen/parallel/parallel_schedulers/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martiansideofthemoon/rankgen/HEAD/paper.pdf
--------------------------------------------------------------------------------
/rankgen/__init__.py:
--------------------------------------------------------------------------------
1 | from .rankgen_encoder import RankGenEncoder
2 | from .rankgen_generator import RankGenGenerator
--------------------------------------------------------------------------------
/rankgen/parallel/parallel_template_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #SBATCH --job-name=job_<exp_id>_<local_rank>
3 | #SBATCH -o /work/kalpeshkrish_umass_edu/rankgen/rankgen/parallel/parallel_logs/logs_exp_<exp_id>/log_<local_rank>.txt
4 | #SBATCH --partition=<partition>
5 | #SBATCH --gres=gpu:1
6 | #SBATCH --cpus-per-task=2
7 | #SBATCH --mem=45000
8 | #SBATCH -d singleton
9 | 
10 | cd /work/kalpeshkrish_umass_edu/rankgen
11 | 
12 | <command> --local_rank <local_rank> --num_shards <num_shards>
13 | 
--------------------------------------------------------------------------------
/old/requirements-bak.txt:
--------------------------------------------------------------------------------
1 | certifi==2022.6.15
2 | charset-normalizer==2.1.0
3 | filelock==3.7.1
4 | gdown==4.5.1
5 | huggingface-hub==0.8.1
6 | idna==3.3
7 | importlib-metadata==4.12.0
8 | numpy==1.22
9 | packaging==21.3
10 | Pillow==9.2.0
11 | pyparsing==3.0.9
12 | PyYAML==6.0
13 | regex==2022.6.2
14 | requests==2.28.1
15 | sentencepiece==0.1.96
16 | tokenizers==0.12.1
17 | torch==1.12.0
18 | torchvision==0.13.0
19 | tqdm==4.64.0
20 | transformers==4.20.1
21 | typing_extensions==4.3.0
22 | urllib3==1.26.9
23 | zipp==3.8.0
24 | 
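For concreteness, here is what one generated scheduler script could look like once rankgen/parallel/schedule.py (included later in this listing) has substituted the placeholders in parallel_template_gpu.sh above. The experiment id (3) and shard index (0) are hypothetical; the partition and base command are the defaults from schedule.py:

#!/bin/sh
#SBATCH --job-name=job_3_0
#SBATCH -o /work/kalpeshkrish_umass_edu/rankgen/rankgen/parallel/parallel_logs/logs_exp_3/log_0.txt
#SBATCH --partition=gypsum-2080ti
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=2
#SBATCH --mem=45000
#SBATCH -d singleton

cd /work/kalpeshkrish_umass_edu/rankgen

python rankgen/gpt2_generate.py --model_size medium --output_file outputs/wiki_gpt2_medium_typical_p90.tsv --num_samples 20 --typical_p 0.9 --local_rank 0 --num_shards 20

Every one of the num_shards jobs receives the same base command plus its own --local_rank, so each worker processes exactly one partition of the dataset and writes to its own sharded output file.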
-------------------------------------------------------------------------------- /rankgen/avg_lengths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--dataset', default="data/story_cloze_spring_2016_val.tsv") 7 | args = parser.parse_args() 8 | 9 | data = [] 10 | 11 | with open(args.dataset, 'r') as f: 12 | data = [x.split("\t") for x in f.read().strip().split("\n")] 13 | 14 | prefix_lens = [] 15 | suffix_lens = [] 16 | 17 | for dd in data: 18 | prefix_lens.append(len(dd[0].split())) 19 | suffix_lens.append(len(dd[1].split())) 20 | 21 | print(f"Average prefix = {np.mean(prefix_lens)}") 22 | print(f"Average suffix = {np.mean(suffix_lens)}") 23 | -------------------------------------------------------------------------------- /rankgen_model/modeling_rankgen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from torch import nn 4 | from transformers import T5PreTrainedModel, T5EncoderModel 5 | 6 | class T5EncoderWithProjection(T5PreTrainedModel): 7 | def __init__(self, config): 8 | super().__init__(config) 9 | self.config = config 10 | self.t5_encoder = T5EncoderModel(config) 11 | self.projection = nn.Linear(config.d_model, config.d_model, bias=False) 12 | # Initialize weights and apply final processing 13 | self.post_init() 14 | 15 | def forward(self, **input_args): 16 | hidden_states = self.t5_encoder(**input_args).last_hidden_state 17 | hidden_states = hidden_states[:, 0, :] 18 | batch_embeddings = self.projection(hidden_states) 19 | return batch_embeddings 20 | -------------------------------------------------------------------------------- /rankgen/parallel/merge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import pickle 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--input_pattern', default="openwebtext_vectors/2016-06.pkl_0_small.pkl.matches_entity_*", type=str) 7 | parser.add_argument('--output_file', default=None, type=str) 8 | args = parser.parse_args() 9 | 10 | files = glob.glob(args.input_pattern) 11 | file_with_ids = [(int(f.split("_")[-1]), f) for f in files] 12 | file_with_ids.sort(key=lambda x: x[0]) 13 | 14 | data = "" 15 | for file in file_with_ids: 16 | with open(file[1], "r") as f: 17 | data += f.read() 18 | 19 | if args.output_file is not None: 20 | output_file = args.output_file 21 | else: 22 | output_file = ".".join(args.input_pattern.split(".")[:-1]) 23 | print(output_file) 24 | with open(output_file, "w") as f: 25 | f.write(data) 26 | -------------------------------------------------------------------------------- /rankgen/choose_examples.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import argparse 4 | from utils import export_server 5 | import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--folder', default="rankgen_train_data_samples/pg19") 9 | parser.add_argument('--output_dir', default="rankgen_train_data_samples/pg19_html") 10 | args = parser.parse_args() 11 | 12 | files = os.listdir(args.folder) 13 | 14 | books = [] 15 | for file in tqdm.tqdm(files): 16 | with open(f"{args.folder}/{file}", 'r') as f: 17 | data = [x.split('\t') for x in f.read().strip().split("\n")] 18 | random.shuffle(data) 19 | data = data[:100] 20 | output = "" 
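    # Each TSV row is assumed to hold the prefix in column 0, the gold suffix in
    # column 1, and a negative sample in the last column; the dd[0] / dd[1] /
    # dd[-1] indexing below relies on that layout.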
21 |     for dd in data:
22 |         output += f"PREFIX = {dd[0]}\n\n"
23 |         output += f"SUFFIX = {dd[1]}\n\n"
24 |         output += f"NEGATIVE = {dd[-1]}\n\n--------------------------\n\n"
25 | 
26 |     export_server(output, os.path.join(args.output_dir, file))
27 | 
28 | random.shuffle(books)
--------------------------------------------------------------------------------
/rankgen/preprocess_story_cloze.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | 
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--dataset', default="data/story_cloze/story_cloze_spring_2016_test.csv")
6 | args = parser.parse_args()
7 | 
8 | data = []
9 | 
10 | with open(args.dataset, 'r') as f:
11 |     reader = csv.reader(f)
12 |     for row in reader:
13 |         data.append(row)
14 | 
15 | header = data[0]
16 | data = data[1:]
17 | 
18 | output = ""
19 | for dd in data:
20 |     prefix = " ".join(dd[1:5])
21 |     if dd[-1] == '1':
22 |         output += f"{prefix}\t{dd[5]}\tplaceholder\tplaceholder\n"
23 |         output += f"{prefix}\t{dd[6]}\tplaceholder\tplaceholder\n"
24 |     elif dd[-1] == '2':
25 |         output += f"{prefix}\t{dd[6]}\tplaceholder\tplaceholder\n"
26 |         output += f"{prefix}\t{dd[5]}\tplaceholder\tplaceholder\n"
27 |     else:
28 |         raise ValueError("Wrong Answer Ending")
29 | 
30 | with open('data/story_cloze/story_cloze_spring_2016_test.tsv', 'w') as f:
31 |     f.write(output)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "rankgen"
3 | version = "0.1.1"
4 | description = "RankGen is a suite of encoder models (100M-1.2B parameters) which map prefixes and generations from any pretrained English language model to a shared vector space. RankGen can be used to rerank multiple full-length samples from an LM, and it can also be incorporated as a scoring function into beam search to significantly improve generation quality (0.85 vs 0.77 MAUVE, 75% preference according to human annotators who are English writers)."
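# A minimal usage sketch of the two classes this package exports (see
# rankgen/__init__.py above; the call signatures follow those used in
# rankgen/test_rankgen_encoder.py and rankgen/rankgen_beam_search.py later in
# this listing, and the prefix text itself is made up):
#
#   from rankgen import RankGenEncoder, RankGenGenerator
#   encoder = RankGenEncoder("kalpeshk2011/rankgen-t5-xl-all")
#   prefix_vecs = encoder.encode(["For many years, the dog slept."], vectors_type="prefix")
#   generator = RankGenGenerator(rankgen_encoder=encoder, language_model="gpt2-medium")
#   texts, scores = generator.beam_search(contexts=["For many years, the dog slept."],
#                                         beam_size=2, top_p=0.9, num_tokens=20,
#                                         num_samples=10, max_length=115)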
5 | authors = ["Kalpesh Krishna, Yapei Chang, John Wieting, Mohit Iyyer"] 6 | license = "Apache License 2.0" 7 | repository = 'https://github.com/martiansideofthemoon/rankgen' 8 | readme = 'README.md' 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.7" 12 | torch = "^1.12.0" 13 | transformers = "^4.20.1" 14 | sentencepiece = "^0.1.96" 15 | gdown = "^4.5.1" 16 | 17 | [tool.poetry.dev-dependencies] 18 | 19 | [build-system] 20 | requires = ["poetry-core>=1.0.0"] 21 | build-backend = "poetry.core.masonry.api" 22 | -------------------------------------------------------------------------------- /rankgen/retriever_score.py: -------------------------------------------------------------------------------- 1 | """Use pre-loaded retriever scores to rank suffixes.""" 2 | 3 | import argparse 4 | import json 5 | import tqdm 6 | import numpy as np 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_pg19_hard.jsonl") 11 | args = parser.parse_args() 12 | 13 | with open(args.dataset, "r") as f: 14 | data = [json.loads(x) for x in f.read().strip().split("\n")] 15 | 16 | avg_score = [] 17 | all_score = [] 18 | 19 | for idx, dd in tqdm.tqdm(enumerate(data), total=len(data)): 20 | prefix = dd['prefix'] 21 | candidates = [dd['suffix']] + dd['negatives'] 22 | overlap_scores = [dd['suffix_score']] + dd['negative_scores'] 23 | assert len(candidates) == 11 24 | avg_score.append(np.mean([overlap_scores[0] > y for y in overlap_scores[1:]])) 25 | all_score.append(all([overlap_scores[0] > y for y in overlap_scores[1:]])) 26 | 27 | if (idx + 1) % 10000 == 0: 28 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") 29 | 30 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") 31 | -------------------------------------------------------------------------------- /rankgen/ab_tests.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import csv 4 | import os 5 | import numpy as np 6 | import random 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_pg19_hard.jsonl") 11 | parser.add_argument('--folder', default="ab_tests/gold_neg_pg19_hard") 12 | parser.add_argument('--num_instances', default=200) 13 | args = parser.parse_args() 14 | 15 | os.makedirs(args.folder, exist_ok=True) 16 | 17 | with open(args.dataset, "r") as f: 18 | data = [json.loads(x) for x in f.read().strip().split("\n")] 19 | 20 | 21 | random.seed(46) 22 | 23 | output = [["Prefix", "First", "Second", "Order", "InstanceNum", "Folder"]] 24 | 25 | for i, dd in enumerate(data[:args.num_instances]): 26 | negative = random.choice(dd["negatives"]) 27 | order = random.random() 28 | if order < 0.5: 29 | output.append([ 30 | dd["prefix"], dd["suffix"], negative, "suffix,negative", i, args.folder 31 | ]) 32 | else: 33 | output.append([ 34 | dd["prefix"], negative, dd["suffix"], "negative,suffix", i, args.folder 35 | ]) 36 | 37 | with open(args.folder + "/input.csv", 'w', newline='') as csvfile: 38 | writer = csv.writer(csvfile) 39 | writer.writerows(output) 40 | -------------------------------------------------------------------------------- /rankgen/export_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import csv 4 | import os 
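# The code below pairs each reranked ("best") generation with the nucleus-sampling
# generation that shares its prefix (a simple quadratic scan over prefixes), then
# writes the prefix, gold suffix, and both generations to a CSV for side-by-side review.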
5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--best', default="data/wiki_beam_2_num_tokens_20_num_samples_10.jsonl") 9 | parser.add_argument('--nucleus', default="data/wiki_beam_1_num_tokens_128_num_samples_1.jsonl") 10 | parser.add_argument('--folder', default="analyze_results/t5_xxl_descartes") 11 | parser.add_argument('--num_instances', default=50) 12 | args = parser.parse_args() 13 | 14 | os.makedirs(args.folder, exist_ok=True) 15 | 16 | with open(args.best, "r") as f: 17 | best = [json.loads(x) for x in f.read().strip().split("\n")] 18 | 19 | with open(args.nucleus, "r") as f: 20 | nuc = [json.loads(x) for x in f.read().strip().split("\n")] 21 | 22 | output = [["Prefix", "Gold Suffix", "Best Suffix", "Nucleus Suffix", "Folder"]] 23 | 24 | for b in best[:args.num_instances]: 25 | for n in nuc: 26 | if n['prefix'] == b['prefix']: 27 | output.append([ 28 | b["prefix"], b["targets"][0], b['t5_xxl_descartes_outputs'][0], n['t5_xxl_descartes_outputs'][0], args.folder 29 | ]) 30 | 31 | with open(args.folder + "/outputs.csv", 'w', newline='', encoding="utf-8") as csvfile: 32 | writer = csv.writer(csvfile) 33 | writer.writerows(output) 34 | -------------------------------------------------------------------------------- /rankgen/ab_tests_generations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import csv 4 | import os 5 | import numpy as np 6 | import random 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_pg19_hard.jsonl") 11 | parser.add_argument('--folder', default="ab_tests/gold_neg_pg19_hard") 12 | parser.add_argument('--num_instances', default=200) 13 | args = parser.parse_args() 14 | 15 | os.makedirs(args.folder, exist_ok=True) 16 | 17 | file_pairs = [] 18 | 19 | with open(args.dataset, "r") as f: 20 | data = [json.loads(x) for x in f.read().strip().split("\n")] 21 | 22 | random.seed(46) 23 | 24 | output = [["Prefix", "First", "Second", "Order", "InstanceNum", "File"]] 25 | 26 | for i, dd in enumerate(data[:args.num_instances]): 27 | negative = random.choice(dd["negatives"]) 28 | order = random.random() 29 | if order < 0.5: 30 | output.append([ 31 | dd["prefix"], dd["suffix"], negative, "suffix,negative", i, args.folder 32 | ]) 33 | else: 34 | output.append([ 35 | dd["prefix"], negative, dd["suffix"], "negative,suffix", i, args.folder 36 | ]) 37 | 38 | with open(args.folder + "/input.csv", 'w', newline='') as csvfile: 39 | writer = csv.writer(csvfile) 40 | writer.writerows(output) 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | bin/ 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Installer logs 24 | pip-log.txt 25 | pip-delete-this-directory.txt 26 | 27 | # Unit test / coverage reports 28 | .tox/ 29 | .coverage 30 | .cache 31 | nosetests.xml 32 | coverage.xml 33 | 34 | # Translations 35 | *.mo 36 | 37 | # Mr Developer 38 | .mr.developer.cfg 39 | .project 40 | .pydevproject 41 | 42 | # Rope 43 | .ropeproject 44 | 45 | # Django stuff: 46 | *.log 47 | *.pot 48 | 49 | # Sphinx documentation 50 | docs/_build/ 51 | 52 | 
style-venv 53 | 54 | .vscode 55 | datasets 56 | 57 | relic-venv 58 | relic_preprocessed 59 | data 60 | scripts/clean.sh 61 | scripts/score_submission_fix.py 62 | crealm-retriever 63 | RELiC 64 | mauve 65 | ab_tests 66 | 67 | scripts/parallel/parallel_logs 68 | scripts/parallel/parallel_schedulers 69 | outputs 70 | .idea 71 | t5x_conversion 72 | outputs_beam 73 | outputs_beam_xl 74 | data_new 75 | 76 | transformers 77 | rankgen/parallel/parallel_logs 78 | rankgen/parallel/parallel_schedulers 79 | gold-beats-neg-outputs 80 | rankgen_data 81 | ContrastiveDecoding 82 | -------------------------------------------------------------------------------- /rankgen/score_ab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import csv 4 | import os 5 | import numpy as np 6 | import random 7 | from datetime import datetime 8 | from collections import Counter 9 | from scipy.stats import kendalltau 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--input', default="ab_tests/gold_neg_wiki_random/input.csv") 14 | parser.add_argument('--dataset', default="ab_tests/gold_neg_wiki_random/Batch_351346_batch_results.csv") 15 | parser.add_argument('--worker_id', default="ADWO0GL6862NA") 16 | args = parser.parse_args() 17 | 18 | data = [] 19 | with open(args.dataset, "r") as f: 20 | spamreader = csv.reader(f) 21 | for row in spamreader: 22 | data.append(row) 23 | 24 | data_input = [] 25 | with open(args.input, "r") as f: 26 | spamreader = csv.reader(f) 27 | for row in spamreader: 28 | data_input.append(row) 29 | 30 | header = data[0] 31 | data = [x for x in data[1:] if x[15] == args.worker_id] 32 | data_dict = {f'{dd[27]} {dd[28]} {dd[29]}': 1 for dd in data} 33 | 34 | verdict = [] 35 | correct = 0 36 | 37 | for dd in data: 38 | if dd[-1] == "Text 1" and dd[-4] == "suffix,negative": 39 | correct += 1 40 | elif dd[-1] == "Text 2" and dd[-4] == "negative,suffix": 41 | correct += 1 42 | 43 | time_taken = [int(dd[23]) for dd in data] 44 | timestamps = [datetime.strptime(dd[18].replace('PDT ', ''), '%a %b %d %H:%M:%S %Y').timestamp() for dd in data] 45 | 46 | print(correct / len(data)) 47 | print(len(data)) 48 | 49 | # for dd in data_input: 50 | # key = f'{dd[0]} {dd[1]} {dd[2]}' 51 | # if key not in data_dict: 52 | # print(dd) 53 | -------------------------------------------------------------------------------- /rankgen/token_overlap_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | import numpy as np 5 | import mauve 6 | import pickle 7 | import matplotlib.pyplot as plt 8 | from nltk import tokenize 9 | from nltk.corpus import stopwords 10 | from utils import f1_score, rep_statistic 11 | import tqdm 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--dataset', default="data/multi_outs/t5_xxl_wiki_t5_xl_gen_inbook_all.jsonl") 16 | parser.add_argument('--output_file', default="data/multi_outs/t5_xxl_wiki_t5_xl_gen_inbook_all.jsonl") 17 | parser.add_argument('--eval_type', default="both") 18 | parser.add_argument('--rep_window', default=20, type=int) 19 | parser.add_argument('--plot_divergence', action='store_true') 20 | parser.add_argument('--eval_mauve', action='store_true') 21 | args = parser.parse_args() 22 | 23 | with open(args.dataset, 'r') as f: 24 | data = [json.loads(x) for x in f.read().strip().split("\n")] 25 | 26 | token_overlaps = { 27 | "human": [], 28 | "random": [], 29 | "best": 
[] 30 | } 31 | rep_scores = { 32 | "human": [], 33 | "random": [], 34 | "best": [] 35 | } 36 | 37 | for i in range(1): 38 | all_human = [] 39 | all_gen = [] 40 | num_tokens = [] 41 | all_max_score = [] 42 | for dd in tqdm.tqdm(data): 43 | all_human.append(dd['prefix'] + ' ' + dd['targets'][0]) 44 | random_gen = random.choice(dd['targets'][1:]) 45 | 46 | token_overlap_scores = [f1_score(x, dd['prefix'], stopwords=stopwords.words('english'))[0] for x in dd['targets']] 47 | 48 | dd['scores'] = token_overlap_scores 49 | 50 | output_txt = "\n".join([json.dumps(x) for x in data]) + "\n" 51 | with open(args.output_file, 'w') as f: 52 | f.write(output_txt) 53 | -------------------------------------------------------------------------------- /rankgen/build_new_ab_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import csv 5 | import os 6 | import glob 7 | import numpy as np 8 | import pickle 9 | import random 10 | import tqdm 11 | from datetime import datetime 12 | from collections import Counter, defaultdict 13 | from scipy.stats import kendalltau 14 | from statsmodels.stats.inter_rater import fleiss_kappa 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset', default="ab_tests/gpt2_medium_nucleus_vs_beam/Batch_355182_scarecrow.csv") 19 | parser.add_argument('--leftover', default="A1HM7F2N458OAM") 20 | parser.add_argument('--replacement', default="AOZAWM1GLMJPC") 21 | args = parser.parse_args() 22 | 23 | data = [] 24 | header = None 25 | files = args.dataset.split(",") 26 | for fl in files: 27 | curr_data = [] 28 | with open(fl, "r") as f: 29 | spamreader = csv.reader(f) 30 | for row in spamreader: 31 | curr_data.append(row) 32 | header = curr_data[0] 33 | data.extend(curr_data[1:]) 34 | 35 | worker_ids = list(set([x[15] for x in data])) 36 | worker_ids.sort() 37 | all_hit_ids = list(set([x[0] for x in data])) 38 | hit_dict = {x[0]: x for x in data} 39 | 40 | hits_worker_leftover = {x[0]: 1 for x in data if x[15] == args.leftover} 41 | hits_replacement = {x[0]: 1 for x in data if x[15] == args.replacement} 42 | 43 | other_hits = [x for x in all_hit_ids if x not in hits_worker_leftover and x not in hits_replacement] 44 | 45 | output = [["Prefix", "First", "Second", "Order", "Dataset", "Model", "Old HIT ID"]] 46 | 47 | for x in other_hits: 48 | orig_hit = hit_dict[x] 49 | output.append(orig_hit[27:33] + [x]) 50 | 51 | with open(os.path.dirname(args.dataset) + "/leftover.csv", 'w', newline='') as csvfile: 52 | writer = csv.writer(csvfile) 53 | writer.writerows(output) 54 | -------------------------------------------------------------------------------- /rankgen/shorten_prefix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import nltk 4 | import numpy as np 5 | from nltk.tokenize import sent_tokenize 6 | from utils import extend_sequence 7 | from transformers import AutoTokenizer 8 | 9 | nltk.download('punkt') 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--source', default="rankgen_vary_lens_splits/wikipedia_eval_512_128_sent_boundary.jsonl") 13 | parser.add_argument('--reference', default="rankgen_vary_lens_splits/wikipedia_eval_64_128_sent_boundary.jsonl") 14 | parser.add_argument('--output_path', default="rankgen_vary_lens_splits/wikipedia_eval_{prefix_length}_128_sent_boundary.jsonl") 15 | parser.add_argument('--prefix_length', default=128, type=int) 16 | args 
= parser.parse_args() 17 | 18 | args.output_path = args.output_path.replace("{prefix_length}", str(args.prefix_length)) 19 | 20 | with open(args.source, "r") as f: 21 | data_source = [json.loads(x) for x in f.read().strip().split("\n")] 22 | 23 | with open(args.reference, "r") as f: 24 | data_ref = [json.loads(x) for x in f.read().strip().split("\n")] 25 | data_ref_dict = {dd['targets'][0]: 1 for dd in data_ref} 26 | 27 | data_src_filt = [dd for dd in data_source if dd['targets'][0] in data_ref_dict] 28 | t5_tokenizer = AutoTokenizer.from_pretrained("t5-large") 29 | 30 | avg_lens = [] 31 | 32 | for dd in data_src_filt: 33 | new_prefix_sents = [] 34 | sents = sent_tokenize(dd['prefix']) 35 | sent_lens = [len(t5_tokenizer.tokenize(x)) for x in sents] 36 | output, _ = extend_sequence(sents, sent_lens, len(sents) - 1, args.prefix_length, False) 37 | dd['prefix'] = output 38 | avg_lens.append(len(t5_tokenizer.tokenize(output))) 39 | print(np.mean(avg_lens)) 40 | 41 | with open(args.output_path, "w") as f: 42 | f.write("\n".join([json.dumps(x) for x in data_src_filt]) + "\n") 43 | -------------------------------------------------------------------------------- /rankgen/shorten_suffix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import nltk 4 | import numpy as np 5 | from nltk.tokenize import sent_tokenize 6 | from utils import extend_sequence 7 | from transformers import AutoTokenizer 8 | 9 | nltk.download('punkt') 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--source', default="rankgen_vary_lens_splits/wikipedia_eval_256_128_sent_boundary.jsonl") 13 | parser.add_argument('--reference', default="rankgen_vary_lens_splits/wikipedia_eval_64_128_sent_boundary.jsonl") 14 | parser.add_argument('--output_path', default="rankgen_vary_lens_splits/wikipedia_eval_256_{suffix_len}_sent_boundary.jsonl") 15 | parser.add_argument('--suffix_length', default=64, type=int) 16 | args = parser.parse_args() 17 | 18 | args.output_path = args.output_path.replace("{suffix_len}", str(args.suffix_length)) 19 | 20 | with open(args.source, "r") as f: 21 | data_source = [json.loads(x) for x in f.read().strip().split("\n")] 22 | 23 | with open(args.reference, "r") as f: 24 | data_ref = [json.loads(x) for x in f.read().strip().split("\n")] 25 | data_ref_dict = {dd['targets'][0]: 1 for dd in data_ref} 26 | 27 | data_src_filt = [dd for dd in data_source if dd['targets'][0] in data_ref_dict] 28 | t5_tokenizer = AutoTokenizer.from_pretrained("t5-large") 29 | 30 | avg_lens = [] 31 | 32 | for dd in data_src_filt: 33 | all_new_targets = [] 34 | for tgt in dd['targets']: 35 | sents = sent_tokenize(tgt) 36 | sent_lens = [len(t5_tokenizer.tokenize(x)) for x in sents] 37 | output, _ = extend_sequence(sents, sent_lens, 0, args.suffix_length, False, 'suffix') 38 | all_new_targets.append(output) 39 | avg_lens.append(len(t5_tokenizer.tokenize(output))) 40 | dd['targets'] = all_new_targets 41 | 42 | print(np.mean(avg_lens)) 43 | 44 | with open(args.output_path, "w") as f: 45 | f.write("\n".join([json.dumps(x) for x in data_src_filt]) + "\n") 46 | -------------------------------------------------------------------------------- /rankgen/plot_divergence_curves.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pickle 3 | import numpy as np 4 | 5 | setting = ['pg19_gpt2_medium', 'wiki_gpt2_medium', 'pg19_gpt2_xl', 'wiki_gpt2_xl', 'pg19_t5_xxl', 'wiki_t5_xxl', 
'pg19_t5_xxl_descartes', 'wiki_t5_xxl_descartes'] 6 | 7 | for s1 in setting: 8 | rankers = [ 9 | (f'data_new/greedy/{s1}.tsv.mauve.pkl', 'max_gen_mauve', 'Greedy'), 10 | (f'data_new/ppl/{s1}.jsonl.mauve.pkl', 'max_gen_mauve', 'PPL-rerank'), 11 | (f'data_new/t5_xl_inbook_gen_all/{s1}.jsonl.mauve.pkl', 'random_gen_mauve', 'Nucleus'), 12 | (f'data_new/t5_xl_inbook_gen_all/{s1}.jsonl.mauve.pkl', 'max_gen_mauve', 'RankGen-rerank') 13 | ] 14 | hatch_styles = ['x', 'O', 'o', '.'] 15 | 16 | all_mauve = [] 17 | plt.rcParams.update({'font.size': 16}) 18 | plt.axis([0.0, 1.0, 0.0, 1.0]) 19 | 20 | for i, rr in enumerate(rankers): 21 | with open(rr[0], 'rb') as f: 22 | mauve1 = pickle.load(f)[rr[1]] 23 | all_mauve.append(mauve1) 24 | 25 | plt.plot(mauve1.divergence_curve[:, 0], mauve1.divergence_curve[:, 1]) 26 | 27 | if i == 0: 28 | plt.fill_between(mauve1.divergence_curve[:, 0], mauve1.divergence_curve[:, 1], hatch=hatch_styles[i], label=rr[2], facecolor='white', edgecolor=plt.rcParams['axes.prop_cycle'].by_key()['color'][i]) 29 | else: 30 | prev_mauve = all_mauve[i - 1] 31 | plt.fill(np.append(prev_mauve.divergence_curve[:, 0], mauve1.divergence_curve[:, 0][::-1]), 32 | np.append(prev_mauve.divergence_curve[:, 1], mauve1.divergence_curve[:, 1][::-1]), 33 | hatch=hatch_styles[i], label=rr[2], facecolor='white', edgecolor=plt.rcParams['axes.prop_cycle'].by_key()['color'][i]) 34 | plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), 35 | fancybox=True, shadow=False, ncol=2) 36 | plt.title(" ".join(s1.split("_")).replace("descartes", "C4")) 37 | # plt.legend(loc='upper right') 38 | plt.xlabel("similarity to Q") 39 | plt.ylabel("similarity to P") 40 | plt.savefig(f'{s1}.plot.pdf', bbox_inches="tight") 41 | plt.clf() 42 | -------------------------------------------------------------------------------- /rankgen/token_overlap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from utils import f1_score 4 | import nltk 5 | import tqdm 6 | import numpy as np 7 | from nltk import tokenize 8 | from nltk.corpus import stopwords 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_pg19_random.jsonl") 13 | parser.add_argument('--num_negatives', default=10, type=int) 14 | args = parser.parse_args() 15 | 16 | avg_score = [] 17 | all_score = [] 18 | 19 | if args.dataset.endswith(".jsonl"): 20 | with open(args.dataset, "r") as f: 21 | data = [json.loads(x) for x in f.read().strip().split("\n")] 22 | for idx, dd in tqdm.tqdm(enumerate(data), total=len(data)): 23 | prefix = dd['prefix'] 24 | candidates = [dd['suffix']] + dd['negatives'] 25 | assert len(candidates) == args.num_negatives + 1 26 | overlap_scores = [f1_score(x, prefix, stopwords=stopwords.words('english'))[0] for x in candidates] 27 | avg_score.append(np.mean([overlap_scores[0] > y for y in overlap_scores[1:]])) 28 | all_score.append(all([overlap_scores[0] > y for y in overlap_scores[1:]])) 29 | 30 | if (idx + 1) % 10000 == 0: 31 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") 32 | 33 | elif args.dataset.endswith(".tsv"): 34 | with open(args.dataset, "r") as f: 35 | data = [x.split("\t") for x in f.read().strip().split("\n")] 36 | for idx in tqdm.tqdm(range(0, len(data), args.num_negatives + 1)): 37 | prefix = data[idx][0] 38 | candidates = [] 39 | for jdx in range(args.num_negatives + 1): 40 | assert data[idx + jdx][0] == prefix 
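        # Rows are grouped in blocks of num_negatives + 1 that share a prefix;
        # the first row of each block is the gold suffix, the rest are negatives.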
41 | candidates.append(data[idx + jdx][1]) 42 | assert len(candidates) == args.num_negatives + 1 43 | overlap_scores = [f1_score(x, prefix, stopwords=stopwords.words('english'))[0] for x in candidates] 44 | avg_score.append(np.mean([overlap_scores[0] > y for y in overlap_scores[1:]])) 45 | all_score.append(all([overlap_scores[0] > y for y in overlap_scores[1:]])) 46 | 47 | if (idx + 1) % 10000 == 0: 48 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") 49 | 50 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") -------------------------------------------------------------------------------- /rankgen/test_rankgen_encoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | import tqdm 5 | import os 6 | import numpy as np 7 | import time 8 | from rankgen import RankGenEncoder, RankGenGenerator 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model_path', default="kalpeshk2011/rankgen-t5-base-all", type=str) 12 | 13 | parser.add_argument('--cache_dir', default=None, type=str) 14 | args = parser.parse_args() 15 | 16 | test_example_file_map = { 17 | "kalpeshk2011/rankgen-t5-base-all": "rankgen_data/test_examples/t5_base_all.jsonl", 18 | "kalpeshk2011/rankgen-t5-large-all": "rankgen_data/test_examples/t5_large_all.jsonl", 19 | "kalpeshk2011/rankgen-t5-xl-all": "rankgen_data/test_examples/t5_xl_all.jsonl", 20 | "kalpeshk2011/rankgen-t5-xl-pg19": "rankgen_data/test_examples/t5_xl_pg19.jsonl" 21 | } 22 | 23 | rankgen_encoder = RankGenEncoder(args.model_path) 24 | 25 | parameters = sum(p.numel() for p in rankgen_encoder.model.parameters()) 26 | 27 | f = open(test_example_file_map[args.model_path], "r") 28 | examples = [json.loads(x) for x in f.read().strip().split("\n")] 29 | 30 | mean_prefix_diff = [] 31 | mean_suffix_diff = [] 32 | 33 | start = time.time() 34 | all_prefix_outs = rankgen_encoder.encode([x["inputs"]["inputs_pretokenized"] for x in examples], vectors_type="prefix", verbose=True, return_input_ids=True) 35 | all_suffix_outs = rankgen_encoder.encode([x["inputs"]["targets_pretokenized"] for x in examples], vectors_type="suffix", verbose=True, return_input_ids=True) 36 | time_taken = time.time() - start 37 | 38 | print(f"Time taken = {time_taken / len(examples)}") 39 | 40 | for eg_num, eg in tqdm.tqdm(enumerate(examples)): 41 | ref_prefix_vec = torch.Tensor(eg['score']['input_embedding']).cuda() 42 | ref_suffix_vec = torch.Tensor(eg['score']['target_embedding']).cuda() 43 | ref_prefix_ids = eg['inputs_processed']['prefix_ids'] 44 | ref_suffix_ids = eg['inputs_processed']['suffix_ids'] 45 | 46 | mean_prefix_diff.append(torch.mean(torch.abs(all_prefix_outs['embeddings'][eg_num] - ref_prefix_vec)).item()) 47 | mean_suffix_diff.append(torch.mean(torch.abs(all_suffix_outs['embeddings'][eg_num] - ref_suffix_vec)).item()) 48 | 49 | for x, y in zip(ref_prefix_ids, all_prefix_outs['input_ids'][eg_num]): 50 | assert x == y 51 | 52 | for x, y in zip(ref_suffix_ids, all_suffix_outs['input_ids'][eg_num]): 53 | assert x == y 54 | 55 | # Expected to be close to 10e-3 56 | print(np.mean(mean_prefix_diff)) 57 | print(np.mean(mean_suffix_diff)) 58 | -------------------------------------------------------------------------------- /rankgen/parallel/schedule.py: 
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import datetime
4 | import time
5 | import subprocess
6 | import socket
7 | 
8 | # example to cancel jobs
9 | # squeue -u $USER | grep "job_6" | awk '{print $1}' | tail -n +2 | xargs scancel
10 | 
11 | 
12 | def get_run_id():
13 |     filename = "rankgen/parallel/parallel_logs/expts.txt"
14 |     if os.path.isfile(filename) is False:
15 |         with open(filename, 'w') as f:
16 |             f.write("")
17 |         return 0
18 |     else:
19 |         with open(filename, 'r') as f:
20 |             expts = f.readlines()
21 |         run_id = len(expts) / 5
22 |         print(len(expts))
23 |         return run_id
24 | 
25 | 
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--command', default="python rankgen/gpt2_generate.py --model_size medium --output_file outputs/wiki_gpt2_medium_typical_p90.tsv --num_samples 20 --typical_p 0.9")
28 | parser.add_argument('--num_shards', default=20, type=int)
29 | parser.add_argument('--start_shard', default=None, type=int)
30 | parser.add_argument('--end_shard', default=None, type=int)
31 | parser.add_argument('--partition_type', default="gypsum-2080ti", type=str)
32 | args = parser.parse_args()
33 | 
34 | script_command = args.command
35 | exp_id = int(get_run_id())
36 | print(script_command)
37 | 
38 | TOTAL = args.num_shards
39 | start_to_schedule = args.start_shard or 0
40 | end_to_schedule = args.end_shard or args.num_shards
41 | 
42 | print(exp_id)
43 | gpu_list = [args.partition_type for i in range(40)]
44 | 
45 | template = "rankgen/parallel/parallel_template_gpu.sh"
46 | 
47 | print(template)
48 | 
49 | with open(template, "r") as f:
50 |     schedule_template = f.read()
51 | 
52 | for i in range(start_to_schedule, end_to_schedule):
53 | 
54 |     curr_gpu = gpu_list[i % len(gpu_list)]
55 | 
56 |     os.makedirs("rankgen/parallel/parallel_schedulers/schedulers_exp_%d" % exp_id, exist_ok=True)
57 |     os.makedirs("rankgen/parallel/parallel_logs/logs_exp_%d" % exp_id, exist_ok=True)
58 | 
59 |     curr_template = schedule_template.replace("<num_shards>", str(TOTAL)).replace("<local_rank>", str(i))
60 |     curr_template = curr_template.replace("<exp_id>", str(exp_id)).replace("<command>", script_command)
61 |     curr_template = curr_template.replace("<partition>", curr_gpu)
62 | 
63 |     with open("rankgen/parallel/parallel_schedulers/schedulers_exp_%d/schedule_%d.sh" % (exp_id, i), "w") as f:
64 |         f.write(curr_template + "\n")
65 | 
66 |     command = "sbatch rankgen/parallel/parallel_schedulers/schedulers_exp_%d/schedule_%d.sh" % (exp_id, i)
67 |     print(subprocess.check_output(command, shell=True))
68 |     time.sleep(0.2)
69 | 
70 | output = f"Experiment ID {exp_id}\n" + \
71 |     "Script Command = " + script_command + "\n" + \
72 |     datetime.datetime.now().strftime("%Y-%m-%d %H:%M") + "\n" + \
73 |     "{:d} shards, {:d} - {:d} scheduled".format(TOTAL, start_to_schedule, end_to_schedule) + "\n" + \
74 |     "" + "\n\n"
75 | 
76 | with open("rankgen/parallel/parallel_logs/expts.txt", "a") as f:
77 |     f.write(output)
78 | 
--------------------------------------------------------------------------------
/rankgen/human_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import csv
4 | import os
5 | import numpy as np
6 | import random
7 | from utils import truncate
8 | 
9 | 
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--folder', default="ab_tests/t5_xxl_descartes_nucleus_vs_beam")
12 | parser.add_argument('--num_instances', default=50)
13 | args = parser.parse_args()
14 | 
15 | BASE_DIR = "files_human_eval"
16 | 
17 | 
files = [ 18 | # GPT2-medium 19 | # {"nucleus": ["pg19", "gpt2_medium", "pg19_gpt2_medium.jsonl"], 20 | # "beam": ["pg19", "gpt2_medium", "pg19_t5_xl_beam_2_tokens_20_samples_10.jsonl"]}, 21 | # {"nucleus": ["wiki", "gpt2_medium", "wiki_gpt2_medium.jsonl"], 22 | # "beam": ["wiki", "gpt2_medium", "wiki_t5_xl_beam_2_tokens_20_samples_10.jsonl"]}, 23 | # T5-XXL descartes 24 | {"nucleus": ["pg19", "t5_xxl_descartes", "pg19_t5_xxl_descartes.jsonl"], 25 | "beam": ["pg19", "t5_xxl_descartes", "pg19_beam_2_num_tokens_20_num_samples_10.jsonl"]}, 26 | {"nucleus": ["wiki", "t5_xxl_descartes", "wiki_t5_xxl_descartes.jsonl"], 27 | "beam": ["wiki", "t5_xxl_descartes", "wiki_beam_2_num_tokens_20_num_samples_10.jsonl"]}, 28 | ] 29 | 30 | random.seed(43) 31 | 32 | os.makedirs(args.folder, exist_ok=True) 33 | 34 | output = [["Prefix", "First", "Second", "Order", "Dataset", "Model"]] 35 | nucleus_tokens = [] 36 | beam_tokens = [] 37 | 38 | for file_pair in files: 39 | model = file_pair["nucleus"][1] 40 | dataset = file_pair["nucleus"][0] 41 | 42 | nucleus_file = file_pair["nucleus"][2] 43 | with open(f"{BASE_DIR}/{model}/{nucleus_file}", "r") as f: 44 | nucleus_data = [json.loads(x) for x in f.read().strip().split("\n")] 45 | 46 | beam_file = file_pair["beam"][2] 47 | with open(f"{BASE_DIR}/{model}/{beam_file}", "r") as f: 48 | beam_data = [json.loads(x) for x in f.read().strip().split("\n")] 49 | 50 | random.shuffle(nucleus_data) 51 | 52 | for i, nucleus_instance in enumerate(nucleus_data[:args.num_instances]): 53 | beam_instance = None 54 | for j, dd2 in enumerate(beam_data): 55 | if dd2['prefix'] == nucleus_instance['prefix']: 56 | beam_instance = dd2 57 | break 58 | 59 | assert beam_instance["prefix"] == nucleus_instance["prefix"] 60 | 61 | nucleus_gen = random.choice(nucleus_instance['targets'][1:]) 62 | order = random.random() 63 | 64 | if model == "gpt2_medium": 65 | beam_gen = beam_instance["targets"][1] 66 | elif model == "t5_xxl_descartes": 67 | beam_gen = beam_instance["t5_xxl_descartes_outputs"][0] 68 | beam_gen = truncate(beam_gen) 69 | 70 | nucleus_tokens.append(len(nucleus_gen.split())) 71 | beam_tokens.append(len(beam_gen.split())) 72 | 73 | if order < 0.5: 74 | output.append([ 75 | nucleus_instance["prefix"], beam_gen, nucleus_gen, "beam,nucleus", dataset, model 76 | ]) 77 | else: 78 | output.append([ 79 | nucleus_instance["prefix"], nucleus_gen, beam_gen, "nucleus,beam", dataset, model 80 | ]) 81 | 82 | with open(args.folder + "/input.csv", 'w', newline='') as csvfile: 83 | writer = csv.writer(csvfile) 84 | writer.writerows(output) 85 | 86 | print(np.mean(beam_tokens)) 87 | print(np.mean(nucleus_tokens)) 88 | -------------------------------------------------------------------------------- /rankgen/rankgen_beam_search.py: -------------------------------------------------------------------------------- 1 | from sys import prefix 2 | from transformers import T5Tokenizer, T5EncoderModel 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | import tqdm 7 | import os 8 | import torch 9 | import random 10 | import json 11 | from functools import partial 12 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 13 | from utils import form_partitions 14 | from rankgen import RankGenEncoder, RankGenGenerator 15 | from utils import truncate 16 | from transformers.utils import logging 17 | 18 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset', default="rankgen_data/wiki.jsonl", type=str) 22 | 
parser.add_argument('--num_samples', default=10, type=int)
23 | parser.add_argument('--beam_size', default=2, type=int)
24 | parser.add_argument('--num_tokens', default=20, type=int)
25 | parser.add_argument('--max_length', default=115, type=int)
26 | parser.add_argument('--top_p', default=0.9, type=float)
27 | parser.add_argument('--model_size', default='medium', type=str)
28 | parser.add_argument('--cache_dir', default=None, type=str)
29 | parser.add_argument('--rankgen_encoder', default='kalpeshk2011/rankgen-t5-xl-all', type=str)
30 | parser.add_argument('--num_shards', default=1, type=int)
31 | parser.add_argument('--local_rank', default=0, type=int)
32 | parser.add_argument('--output_file', default=None, type=str)
33 | args = parser.parse_args()
34 | 
35 | with open(args.dataset, "r") as f:
36 |     data = [json.loads(x) for x in f.read().strip().split("\n")]
37 | 
38 | if args.num_shards > 1:
39 |     partitions = form_partitions(data, args.num_shards)
40 |     data = partitions[args.local_rank]
41 |     args.output_file = f'{args.output_file}.shard_{args.local_rank}'
42 | 
43 | rankgen_encoder = RankGenEncoder(model_path=args.rankgen_encoder, cache_dir=args.cache_dir)
44 | 
45 | random.seed(49)
46 | random.shuffle(data)
47 | 
48 | random.seed(442)
49 | random.shuffle(data)
50 | 
51 | folder_name = "token_bs_t5x"
52 | 
53 | rankgen_generator = RankGenGenerator(rankgen_encoder=rankgen_encoder, language_model=f"gpt2-{args.model_size}", cache_dir=args.cache_dir)
54 | 
55 | outputs = []
56 | 
57 | target_seq_len = []
58 | gen_seq_len = []
59 | 
60 | logging.set_verbosity_error()
61 | 
62 | if os.path.exists(args.output_file):
63 |     with open(args.output_file, "r") as f:
64 |         outputs = f.read().strip().split("\n")
65 | 
66 | for kk, instance in tqdm.tqdm(enumerate(data), total=len(data)):
67 |     if kk < len(outputs):
68 |         continue
69 |     token_beam_text, token_beam_scores = rankgen_generator.beam_search(contexts=[instance["prefix"]],
70 |                                                                        beam_size=args.beam_size,
71 |                                                                        top_p=args.top_p,
72 |                                                                        num_tokens=args.num_tokens,
73 |                                                                        num_samples=args.num_samples,
74 |                                                                        max_length=args.max_length)
75 | 
76 |     token_beam_text = token_beam_text[0]
77 |     token_beam_text = [truncate(" ".join(x.split())) for x in token_beam_text]
78 |     if "scores" not in instance:
79 |         instance["scores"] = [1.0]
80 |     outputs.append(json.dumps({
81 |         "prefix": instance["prefix"],
82 |         "targets": instance["targets"][0:1] + token_beam_text,
83 |         "scores": instance["scores"][0:1] + token_beam_scores[0].cpu().tolist()
84 |     }))
85 |     target_seq_len.append(len(instance["targets"][0].split()))
86 |     gen_seq_len.append(len(token_beam_text[0].split()))
87 | 
88 |     if (kk + 1) % 100 == 0:
89 |         print(f"Avg lens ({kk + 1} instances) = {np.mean(gen_seq_len)} generation, {np.mean(target_seq_len)} target")
90 |         print("Saving file...")
91 |         with open(args.output_file, "w") as f:
92 |             f.write("\n".join(outputs) + "\n")
93 | 
94 | with open(args.output_file, "w") as f:
95 |     f.write("\n".join(outputs) + "\n")
96 | 
--------------------------------------------------------------------------------
/rankgen/gpt3_score.py:
--------------------------------------------------------------------------------
1 | # Functions for receiving and evaluating GPT-3 response
2 | import openai
3 | import math
4 | import numpy as np
5 | import os
6 | import json
7 | import argparse
8 | import random
9 | 
10 | from utils import pickle_load, pickle_dump
11 | 
12 | openai.api_key = os.environ['OPENAI_API_KEY']
13 | 
14 | def get_response(prompt: str, max_tokens = 150, temperature = 0.7, top_p = 1, n = 1, 
logprobs = 1, stop = None, echo = True): 15 | response = openai.Completion.create(engine="davinci", 16 | prompt=prompt, 17 | max_tokens=max_tokens, 18 | temperature = temperature, 19 | top_p=top_p, 20 | n=n, 21 | logprobs=logprobs, 22 | stop=stop, 23 | echo=echo) 24 | return response 25 | 26 | def perplexity(log_probs): 27 | N = len(log_probs) 28 | return math.exp((-1/N) * np.sum(log_probs)) 29 | 30 | # Use max_tokens value passed to response to extract response PPL 31 | def evaluate_response(response, max_tokens): 32 | response_dict = dict(response['choices'][0]) 33 | text = response_dict['text'] 34 | 35 | log_probs = response_dict['logprobs']['token_logprobs'][1:] 36 | log_probs_prompt = log_probs[:-max_tokens] 37 | log_probs_response = log_probs[-max_tokens:] 38 | 39 | ppl_prompt = perplexity(log_probs_prompt) 40 | ppl_response = perplexity(log_probs_response) 41 | ppl_total = perplexity(log_probs) 42 | 43 | return { 44 | 'prompt_ppl': ppl_prompt, 45 | 'response_ppl': ppl_response, 46 | 'overall_ppl': ppl_total, 47 | 'text': text, 48 | } 49 | 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_wiki_random.jsonl") 52 | args = parser.parse_args() 53 | 54 | with open(args.dataset, "r") as f: 55 | data = [json.loads(x) for x in f.read().strip().split("\n")] 56 | 57 | gold_beats_neg_avg = [] 58 | gold_beats_neg_all = [] 59 | gold_beats_neg_any = [] 60 | 61 | for dd in data[:100]: 62 | if "gold_gpt3" in dd: 63 | print("skipping API call") 64 | write = False 65 | else: 66 | write = True 67 | gold = get_response(dd['prefix'].strip() + " " + dd['suffix'].strip(), 0) 68 | negs = [] 69 | for nn in dd['negatives']: 70 | negs.append( 71 | get_response(dd['prefix'].strip() + " " + nn.strip(), 0) 72 | ) 73 | dd['gold_gpt3'] = gold['choices'][0]['logprobs'] 74 | dd['negs_gpt3'] = [nn['choices'][0]['logprobs'] for nn in negs] 75 | 76 | prefix_len = 0 77 | prefix_so_far = "" 78 | for token in dd['gold_gpt3']['tokens']: 79 | prefix_so_far += token 80 | prefix_len += 1 81 | if prefix_so_far == dd['prefix']: 82 | break 83 | if prefix_len > len(dd['gold_gpt3']['tokens']) - 10: 84 | continue 85 | suffix_len = len(dd['gold_gpt3']['tokens']) - prefix_len 86 | 87 | gold_ppl = perplexity(dd['gold_gpt3']['token_logprobs'][prefix_len:]) 88 | neg_ppls = [perplexity(nn['token_logprobs'][prefix_len:]) for nn in dd['negs_gpt3']] 89 | 90 | # gold_ppl = perplexity(dd['gold_gpt3']['token_logprobs'][-1 * suffix_len:]) 91 | # neg_ppls = [perplexity(nn['token_logprobs'][-1 * suffix_len:]) for nn in dd['negs_gpt3']] 92 | print(gold_ppl) 93 | print(neg_ppls) 94 | 95 | gold_beats_neg_avg.extend( 96 | [gold_ppl < nppl for nppl in neg_ppls] 97 | ) 98 | gold_beats_neg_all.append( 99 | all([gold_ppl < nppl for nppl in neg_ppls]) 100 | ) 101 | gold_beats_neg_any.append( 102 | any([gold_ppl < nppl for nppl in neg_ppls]) 103 | ) 104 | 105 | print(f"Avg = {np.mean(gold_beats_neg_avg)} ({len(gold_beats_neg_avg)} instances)") 106 | print(f"All = {np.mean(gold_beats_neg_all)} ({len(gold_beats_neg_all)} instances)") 107 | print(f"Any = {np.mean(gold_beats_neg_any)} ({len(gold_beats_neg_any)} instances)") 108 | 109 | if write: 110 | output = "\n".join([json.dumps(x) for x in data]) + "\n" 111 | with open(args.dataset, "w") as f: 112 | f.write(output) 113 | -------------------------------------------------------------------------------- /rankgen/score_multi_beam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 
import random
4 | import numpy as np
5 | from torch import trunc
6 | import mauve
7 | import pickle
8 | import os
9 | from utils import truncate
10 | 
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--dataset', default="outputs_beam/wiki_t5_large_beam_1_tokens_115_samples_1.jsonl")
13 | parser.add_argument('--domain', default="wiki")
14 | parser.add_argument('--gen_key_type', default="second_idx")
15 | parser.add_argument('--data_length', default=7713, type=int)
16 | parser.add_argument('--max_mauve_length', default=768, type=int)
17 | parser.add_argument('--truncate', default=None, type=int)
18 | parser.add_argument('--refresh', action='store_true')
19 | args = parser.parse_args()
20 | 
21 | with open(args.dataset, 'r') as f:
22 |     data = [json.loads(x) for x in f.read().strip().split("\n")]
23 | 
24 | data_dict = {x["prefix"]: x for x in data}
25 | 
26 | mauve_output_key = "random_gen_mauve" if "random" in args.gen_key_type else "max_gen_mauve"
27 | 
28 | 
29 | if args.domain == "wiki":
30 |     with open("data/multi_outs/t5_xxl_descartes_wiki_ppl.jsonl", "r") as f:
31 |         raw_inp_data = [json.loads(x) for x in f.read().strip().split("\n")]
32 |     for rid in raw_inp_data:
33 |         assert rid["prefix"] in data_dict
34 |         assert rid["targets"][0] == data_dict[rid["prefix"]]["targets"][0]
35 | elif args.domain == "pg19":
36 |     with open("data_new/ppl/pg19_t5_xxl.jsonl", "r") as f:
37 |         raw_inp_data = [json.loads(x) for x in f.read().strip().split("\n")]
38 |     for rid in raw_inp_data:
39 |         assert rid["prefix"] in data_dict
40 |         assert rid["targets"][0] == data_dict[rid["prefix"]]["targets"][0]
41 | elif args.domain != "None":
42 |     with open(args.domain, "r") as f:
43 |         raw_inp_data = [json.loads(x) for x in f.read().strip().split("\n")]
44 |     for rid in raw_inp_data:
45 |         assert rid["prefix"] in data_dict
46 |         assert rid["targets"][0] == data_dict[rid["prefix"]]["targets"][0]
47 | else:
48 |     raw_inp_data = [None for _ in range(7711)]
49 | # print(len(data_dict))
50 | # assert len(data) == len(raw_inp_data)
51 | # assert len(data_dict) == len(raw_inp_data)
52 | 
53 | all_human = []
54 | all_gen = []
55 | num_tokens = []
56 | 
57 | output_file = args.dataset + ".mauve.pkl"
58 | if args.truncate:
59 |     data = data[:args.truncate]
60 |     output_file += ".truncate"
61 | 
62 | for dd in data:
63 |     all_human.append(dd['prefix'] + ' ' + dd['targets'][0])
64 |     if args.gen_key_type == "second_idx":
65 |         # assert len(dd['targets']) != 21
66 |         all_gen.append(dd['prefix'] + ' ' + dd['targets'][1])
67 |         num_tokens.append(len(dd['targets'][1].split()))
68 |     elif args.gen_key_type == "random":
69 |         random_gen = random.choice(dd['targets'][1:])
70 |         all_gen.append(dd['prefix'] + ' ' + random_gen)
71 |         num_tokens.append(len(random_gen.split()))
72 |     elif args.gen_key_type.startswith("random_"):
73 |         gen_key_type = args.gen_key_type.replace("random_", "")
74 |         random_gen = random.choice(dd[gen_key_type])
75 |         random_gen = " ".join(random_gen.split())
76 |         if "descartes" in args.gen_key_type:
77 |             random_gen = truncate(random_gen)
78 |         all_gen.append(dd['prefix'] + ' ' + random_gen)
79 |         num_tokens.append(len(random_gen.split()))
80 |     else:
81 |         generation = dd[args.gen_key_type][0]
82 |         generation = " ".join(generation.split())
83 |         if "descartes" in args.gen_key_type:
84 |             generation = truncate(generation)
85 |         num_tokens.append(len(generation.split()))
86 |         all_gen.append(dd['prefix'] + ' ' + generation)
87 | 
88 | print(np.mean(num_tokens))
89 | 
90 | if os.path.exists(output_file):
91 |     with
open(output_file, "rb") as f: 92 | mauve_data = pickle.load(f) 93 | else: 94 | mauve_data = {} 95 | 96 | if mauve_output_key in mauve_data and not args.refresh: 97 | print(f"Generation score mauve = {mauve_data[mauve_output_key].mauve}") 98 | else: 99 | mauve1 = mauve.compute_mauve(p_text=all_gen, q_text=all_human, device_id=0, max_text_length=args.max_mauve_length, verbose=False) 100 | print(f"Generation score mauve = {mauve1.mauve}") 101 | mauve_data[mauve_output_key] = mauve1 102 | 103 | if args.truncate is None: 104 | with open(args.dataset + ".mauve.pkl", "wb") as f: 105 | pickle.dump(mauve_data, f) 106 | -------------------------------------------------------------------------------- /rankgen/gpt2_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import numpy as np 4 | import tqdm 5 | import json 6 | import torch 7 | import os 8 | import random 9 | 10 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 11 | from utils import execute_gpt2, cudafy_tokens, form_partitions, truncate 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_wiki_random.jsonl") 15 | parser.add_argument('--output_file', default="data/wiki_gpt2_medium_p90_multi.tsv") 16 | parser.add_argument('--model_size', default="medium") 17 | parser.add_argument('--num_instances', default=7713, type=int) 18 | parser.add_argument('--num_samples', default=1, type=int) 19 | parser.add_argument('--max_new_tokens', default=115, type=int) 20 | parser.add_argument('--top_k', default=None, type=int) 21 | parser.add_argument('--top_p', default=None, type=float) 22 | parser.add_argument('--typical_p', default=None, type=float) 23 | parser.add_argument('--truncate_fraction', default=0.0, type=float) 24 | parser.add_argument('--num_shards', default=1, type=int) 25 | parser.add_argument('--local_rank', default=0, type=int) 26 | args = parser.parse_args() 27 | 28 | with open(args.dataset, "r") as f: 29 | data = [json.loads(x) for x in f.read().strip().split("\n")] 30 | 31 | tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2-{args.model_size}") 32 | tokenizer.pad_token = tokenizer.eos_token 33 | model = GPT2LMHeadModel.from_pretrained(f"gpt2-{args.model_size}") 34 | model.cuda() 35 | model.eval() 36 | 37 | avg_score = [] 38 | all_score = [] 39 | random.seed(43) 40 | device = "cuda" if torch.cuda.is_available() else "cpu" 41 | 42 | 43 | output = "" 44 | suffix_lens = [] 45 | gen_lens = [] 46 | 47 | def postprocess(outputs): 48 | return tokenizer.batch_decode(outputs, skip_special_tokens=True) 49 | 50 | def truncate(text): 51 | """Truncate text to the last full sentence.""" 52 | last_punc = 0 53 | if "." in text: 54 | last_punc = max(last_punc, text.rindex(".")) 55 | if "?" in text: 56 | last_punc = max(last_punc, text.rindex("?")) 57 | if "!" 
in text: 58 | last_punc = max(last_punc, text.rindex("!")) 59 | if last_punc != 0: 60 | text = text[:last_punc + 1] 61 | return text 62 | 63 | if args.num_shards > 1: 64 | partitions = form_partitions(data, args.num_shards) 65 | data = partitions[args.local_rank] 66 | args.output_file = f'{args.output_file}.shard_{args.local_rank}' 67 | 68 | for idx, dd in tqdm.tqdm(enumerate(data), total=min(len(data), args.num_instances)): 69 | if len(suffix_lens) >= args.num_instances: 70 | break 71 | prefix = dd['prefix'] 72 | batch = tokenizer(prefix, truncation=True, padding="longest", return_tensors="pt", max_length=1024 - args.max_new_tokens).to(device) 73 | num_tokens = len(batch['input_ids'][0]) 74 | if num_tokens >= 1024 - args.max_new_tokens - 3: 75 | print("long sequence detected") 76 | with torch.no_grad(): 77 | generation = model.generate(**batch, 78 | do_sample=True, 79 | output_scores=True, 80 | return_dict_in_generate=True, 81 | max_new_tokens=args.max_new_tokens, 82 | top_k=args.top_k, 83 | typical_p=args.typical_p, 84 | top_p=args.top_p, 85 | num_return_sequences=args.num_samples) 86 | gen_text = postprocess(generation['sequences'][:, num_tokens:]) 87 | gen_text = [" ".join(x.split()) for x in gen_text] 88 | gen_text = [truncate(x) for x in gen_text] 89 | 90 | for i in range(len(gen_text)): 91 | if random.random() < args.truncate_fraction: 92 | gen_text[i] = truncate(gen_text[i][:-1]) 93 | 94 | if "suffix" in dd: 95 | suffix_str = dd['suffix'] 96 | else: 97 | suffix_str = dd['targets'][0] 98 | 99 | suffix_lens.append(len(suffix_str.split())) 100 | for x in gen_text: 101 | gen_lens.append(len(x.split())) 102 | output += f"{prefix}\t{suffix_str}\tplaceholder\tplaceholder\n" 103 | for x in gen_text: 104 | output += f"{prefix}\t{x}\tplaceholder\tplaceholder\n" 105 | 106 | if (idx + 1) % 100 == 0: 107 | print(f"Avg suffix length = {np.mean(suffix_lens):.4f} ({len(suffix_lens)} samples), avg gen length = {np.mean(gen_lens):.4f} ({len(gen_lens)} samples)") 108 | with open(args.output_file, "w") as f: 109 | f.write(output) 110 | 111 | with open(args.output_file, "w") as f: 112 | f.write(output) 113 | -------------------------------------------------------------------------------- /rankgen/gpt2_generate_contrastive_search.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | from lib2to3.pgen2 import token 4 | import numpy as np 5 | import tqdm 6 | import json 7 | import torch 8 | import os 9 | import random 10 | import nltk 11 | 12 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, StoppingCriteriaList, MaxLengthCriteria 13 | from utils import execute_gpt2, cudafy_tokens, form_partitions, truncate 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_wiki_random.jsonl") 17 | parser.add_argument('--output_file', default="data_new/contrastive_search/wiki_gpt2_medium_k_5_alpha_0.6.tsv") 18 | parser.add_argument('--model_size', default="medium") 19 | parser.add_argument('--num_instances', default=7713, type=int) 20 | parser.add_argument('--num_samples', default=1, type=int) 21 | parser.add_argument('--max_new_tokens', default=115, type=int) 22 | parser.add_argument('--top_k', default=5, type=int) 23 | parser.add_argument('--penalty_alpha', default=0.6, type=float) 24 | parser.add_argument('--truncate_fraction', default=0.0, type=float) 25 | parser.add_argument('--num_shards', default=1, type=int) 26 | parser.add_argument('--local_rank', default=0, type=int) 27 
| args = parser.parse_args() 28 | 29 | with open(args.dataset, "r") as f: 30 | data = [json.loads(x) for x in f.read().strip().split("\n")] 31 | 32 | tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2-{args.model_size}") 33 | tokenizer.pad_token = tokenizer.eos_token 34 | model = GPT2LMHeadModel.from_pretrained(f"gpt2-{args.model_size}") 35 | model.cuda() 36 | model.eval() 37 | 38 | avg_score = [] 39 | all_score = [] 40 | random.seed(43) 41 | device = "cuda" if torch.cuda.is_available() else "cpu" 42 | 43 | output = "" 44 | suffix_lens = [] 45 | gen_lens = [] 46 | 47 | def postprocess(outputs): 48 | return tokenizer.batch_decode(outputs, skip_special_tokens=True) 49 | 50 | def truncate(text): 51 | """Truncate text to the last full sentence.""" 52 | last_punc = 0 53 | if "." in text: 54 | last_punc = max(last_punc, text.rindex(".")) 55 | if "?" in text: 56 | last_punc = max(last_punc, text.rindex("?")) 57 | if "!" in text: 58 | last_punc = max(last_punc, text.rindex("!")) 59 | if last_punc != 0: 60 | text = text[:last_punc + 1] 61 | return text 62 | 63 | if args.num_shards > 1: 64 | partitions = form_partitions(data, args.num_shards) 65 | data = partitions[args.local_rank] 66 | args.output_file = f'{args.output_file}.shard_{args.local_rank}' 67 | 68 | for idx, dd in tqdm.tqdm(enumerate(data), total=min(len(data), args.num_instances)): 69 | if len(suffix_lens) >= args.num_instances: 70 | break 71 | prefix = dd['prefix'] 72 | batch = tokenizer(prefix, truncation=True, padding="longest", return_tensors="pt", max_length=1024 - args.max_new_tokens).to(device) 73 | num_tokens = len(batch['input_ids'][0]) 74 | if num_tokens >= 1024 - args.max_new_tokens - 3: 75 | print("long sequence detected") 76 | stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=num_tokens + args.max_new_tokens)]) 77 | with torch.inference_mode(): 78 | generation = model.contrastive_search(**batch, 79 | penalty_alpha=args.penalty_alpha, 80 | top_k=args.top_k, 81 | output_scores=True, 82 | return_dict_in_generate=True, 83 | stopping_criteria=stopping_criteria, 84 | eos_token_id=tokenizer.eos_token_id, 85 | pad_token_id=tokenizer.pad_token_id) 86 | gen_text = postprocess(generation['sequences'][:, num_tokens:]) 87 | gen_text = [" ".join(x.split()) for x in gen_text] 88 | gen_text = [truncate(x) for x in gen_text] 89 | 90 | for i in range(len(gen_text)): 91 | if random.random() < args.truncate_fraction: 92 | gen_text[i] = truncate(gen_text[i][:-1]) 93 | 94 | if "suffix" in dd: 95 | suffix_str = dd['suffix'] 96 | else: 97 | suffix_str = dd['targets'][0] 98 | 99 | suffix_lens.append(len(suffix_str.split())) 100 | for x in gen_text: 101 | gen_lens.append(len(x.split())) 102 | output += f"{prefix}\t{suffix_str}\tplaceholder\tplaceholder\n" 103 | for x in gen_text: 104 | output += f"{prefix}\t{x}\tplaceholder\tplaceholder\n" 105 | 106 | if (idx + 1) % 100 == 0: 107 | print(f"Avg suffix length = {np.mean(suffix_lens):.4f} ({len(suffix_lens)} samples), avg gen length = {np.mean(gen_lens):.4f} ({len(gen_lens)} samples)") 108 | with open(args.output_file, "w") as f: 109 | f.write(output) 110 | 111 | with open(args.output_file, "w") as f: 112 | f.write(output) 113 | -------------------------------------------------------------------------------- /rankgen/score_ab_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import csv 5 | import os 6 | import glob 7 | import numpy as np 8 | import pickle 9 | import random 10 | from 
regex import E 11 | import tqdm 12 | from datetime import datetime 13 | from collections import Counter, defaultdict 14 | from scipy.stats import kendalltau 15 | from statsmodels.stats.inter_rater import fleiss_kappa 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--dataset', default="human-eval-data/*") 20 | parser.add_argument('--split', default=None) 21 | parser.add_argument('--model', default=None) 22 | args = parser.parse_args() 23 | 24 | def get_annotation(hit): 25 | text = hit[33] 26 | if text.strip().lower().startswith("text 1"): 27 | return "text 1" 28 | if text.strip().lower().startswith("test 1"): 29 | return "text 1" 30 | elif text.strip().lower().startswith("text 2"): 31 | return "text 2" 32 | elif text.startswith("Neither of them sound like a good continuation to me, but I choose Text 1"): 33 | return "text 1" 34 | elif text == "Xcdsds" or text == "adfwegsdgsdgsdg": 35 | return None 36 | else: 37 | import pdb; pdb.set_trace() 38 | pass 39 | 40 | def get_multi_annotations(hits): 41 | anns = [get_annotation(x) for x in hits] 42 | return [x for x in anns if x] 43 | 44 | def most_common(lst): 45 | return max(set(lst), key=lst.count) 46 | 47 | def print_counter(x): 48 | x = Counter(x) 49 | total = sum([v for v in x.values()]) 50 | for k, v in x.items(): 51 | print(f"{k} = {v / total:.4f} ({v} / {total})") 52 | 53 | data = [] 54 | header = None 55 | files = glob.glob(args.dataset) 56 | for fl in files: 57 | curr_data = [] 58 | with open(fl, "r") as f: 59 | spamreader = csv.reader(f) 60 | for row in spamreader: 61 | curr_data.append(row) 62 | header = curr_data[0] 63 | data.extend(curr_data[1:]) 64 | 65 | if args.split is not None: 66 | data = [x for x in data if x[31] == args.split] 67 | 68 | if args.model is not None: 69 | data = [x for x in data if x[32] == args.model] 70 | 71 | worker_ids = list(set([x[15] for x in data])) 72 | worker_ids.sort() 73 | hit_ids = list(set([x[0] for x in data])) 74 | scarecrow_beam = defaultdict(list) 75 | 76 | all_workers = [] 77 | for worker in worker_ids: 78 | verdicts = [] 79 | data_small = [x for x in data if x[15] == worker] 80 | for dd in data_small: 81 | annotation = get_annotation(dd) 82 | text1, text2 = dd[30].split(",") 83 | if annotation == "text 1": 84 | verdicts.append(text1) 85 | elif annotation == "text 2": 86 | verdicts.append(text2) 87 | 88 | if annotation is not None and verdicts[-1] == "beam": 89 | # compute scarecrow stats for cases beam search triumphs 90 | scarecrow_beam[worker].append(dd[34]) 91 | 92 | all_workers.extend(verdicts) 93 | if verdicts: 94 | print(f"{worker} results:") 95 | print(Counter(verdicts)) 96 | 97 | 98 | # Agreement between annotators 99 | annotations = [] 100 | unique = [] 101 | table = [] 102 | majority = [] 103 | for hit_id in hit_ids: 104 | curr_entry = [0, 0] 105 | data_small = [x for x in data if x[0] == hit_id] 106 | workers = [x[15] for x in data_small] 107 | text1, text2 = data_small[0][30].split(",") 108 | 109 | anns = get_multi_annotations(data_small) 110 | 111 | unique.append(len(set(anns))) 112 | 113 | for ann in anns[:3]: 114 | if ann == "text 1": 115 | curr_entry[0] += 1 116 | elif ann == "text 2": 117 | curr_entry[1] += 1 118 | 119 | vote = most_common(anns) 120 | if vote == "text 1": 121 | majority.append(text1) 122 | elif vote == "text 2": 123 | majority.append(text2) 124 | table.append(curr_entry) 125 | 126 | 127 | table = np.array(table) 128 | 129 | print("") 130 | 131 | print(f"Fleiss ({len(table)} pairs) = {fleiss_kappa(table)}") 132 | 
print_counter(unique) 133 | print("") 134 | 135 | print("Majority vote accuracy ---") 136 | print_counter(majority) 137 | print("") 138 | 139 | print("Absolute accuracy ---") 140 | print_counter(all_workers) 141 | print("") 142 | 143 | # Scarecrow statistics 144 | 145 | def process_scarecrow(sc_anns): 146 | totals = defaultdict(int) 147 | for sca in sc_anns: 148 | types = [x.strip() for x in sca.split(", ")] 149 | types = [x for x in types if x != "equal" and x.strip()] 150 | for tp in types: 151 | totals[tp] += 1 / len(types) 152 | return totals 153 | 154 | scarecrow_list = [] 155 | for k, v in scarecrow_beam.items(): 156 | scarecrow_list.extend(v) 157 | 158 | print("All annotators --- ") 159 | scarecrow_list = [x for x in scarecrow_list if x] 160 | scarecrow_all = process_scarecrow(scarecrow_list) 161 | for k, v in scarecrow_all.items(): 162 | print(f"{k} = {v * 100 / len(scarecrow_list):.1f}") 163 | -------------------------------------------------------------------------------- /rankgen/rankgen_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | 5 | class RankGenGenerator(): 6 | def __init__(self, rankgen_encoder, language_model="gpt2-medium", cache_dir=None): 7 | self.rankgen_encoder = rankgen_encoder 8 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 9 | 10 | self.tokenizer = AutoTokenizer.from_pretrained(language_model, cache_dir=cache_dir) 11 | self.tokenizer.pad_token = self.tokenizer.eos_token 12 | self.language_model = AutoModelForCausalLM.from_pretrained(language_model, cache_dir=cache_dir) 13 | self.language_model.to(self.device) 14 | self.language_model.eval() 15 | 16 | def rankgen_scorer(self, prefix, suffixes, prefix_vector=None): 17 | rankgen_model = self.rankgen_encoder 18 | if prefix_vector is None: 19 | prefix_vector = rankgen_model.encode(prefix, vectors_type="prefix")["embeddings"] 20 | suffix_vectors = rankgen_model.encode(suffixes, vectors_type="suffix")["embeddings"] 21 | similarities = torch.matmul(prefix_vector, suffix_vectors.t()).squeeze(dim=0) 22 | return similarities, prefix_vector, suffix_vectors 23 | 24 | def postprocess(self, outputs): 25 | return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 26 | 27 | def generate_single(self, contexts, temperature=1.0, top_p=0.9, num_samples=10, max_length=115): 28 | return self.beam_search(contexts=contexts, 29 | beam_size=1, 30 | temperature=temperature, 31 | top_p=top_p, 32 | num_tokens=max_length, 33 | num_samples=1, 34 | max_length=max_length) 35 | 36 | def overgenerate_rerank(self, contexts, temperature=1.0, top_p=0.9, num_samples=10, max_length=115): 37 | return self.beam_search(contexts=contexts, 38 | beam_size=1, 39 | temperature=temperature, 40 | top_p=top_p, 41 | num_tokens=max_length, 42 | num_samples=num_samples, 43 | max_length=max_length) 44 | 45 | def beam_search(self, contexts, beam_size=2, temperature=1.0, top_p=0.9, num_tokens=20, num_samples=10, max_length=115): 46 | final_outputs = [] 47 | final_scores = [] 48 | total_generated_tokens = 0 49 | for ctx in contexts: 50 | if beam_size == 1 and num_samples == 1: 51 | prefix_vector = None 52 | else: 53 | _, prefix_vector, _ = self.rankgen_scorer(prefix=ctx, suffixes=[ctx]) 54 | beams = [{ 55 | "text": "", 56 | "eos": False 57 | } for _ in range(beam_size)] 58 | while True: 59 | all_outs = [] 60 | max_new_tokens = min(num_tokens, max_length - total_generated_tokens) 61 | for beam in beams: 62 | 
# if a beam has ended, add it to all_outs 63 | if beam["eos"]: 64 | all_outs.append(beam) 65 | continue 66 | # otherwise generate the next n tokens 67 | inputs = self.tokenizer(ctx + beam['text'], truncation=True, padding="longest", 68 | return_tensors="pt", max_length=1024 - max_new_tokens).to(self.device) 69 | num_input_tokens = len(inputs['input_ids'][0]) 70 | with torch.inference_mode(): 71 | curr_outs = self.language_model.generate(**inputs, do_sample=True, output_scores=True, 72 | return_dict_in_generate=True, 73 | max_new_tokens=max_new_tokens, top_k=None, top_p=top_p, 74 | num_return_sequences=num_samples, temperature=temperature) 75 | is_eos = [] 76 | for curr_out in curr_outs['sequences']: 77 | if self.tokenizer.eos_token_id in curr_out: 78 | is_eos.append(True) 79 | else: 80 | is_eos.append(False) 81 | curr_outs_text = self.postprocess(curr_outs['sequences'][:, num_input_tokens:]) 82 | for text, eos in zip(curr_outs_text, is_eos): 83 | # update all_outs 84 | all_outs.append({ 85 | "text": beam["text"] + text, 86 | "eos": eos 87 | }) 88 | # Each beam has total_generated_tokens length 89 | total_generated_tokens += max_new_tokens 90 | if len(all_outs) > 1: 91 | # skip beam scoring if only one output to choose from 92 | scores, _, _ = self.rankgen_scorer(prefix=ctx, suffixes=[x["text"] for x in all_outs], prefix_vector=prefix_vector) 93 | top_scores, top_indices = torch.topk(scores, k=beam_size) 94 | beams = [all_outs[x] for x in top_indices] # only track the top k beams 95 | else: 96 | top_scores = torch.Tensor([1.0]) 97 | top_scores = top_scores.to(self.device)  # .cuda() returns a copy, so the result must be assigned 98 | beams = all_outs 99 | 100 | for beam in beams: 101 | if len(self.tokenizer.tokenize(beam["text"])) >= max_length: 102 | beam["eos"] = True 103 | 104 | if all([x["eos"] for x in beams]): 105 | final_outputs.append([x["text"] for x in beams]) 106 | final_scores.append(top_scores) 107 | break 108 | return final_outputs, final_scores -------------------------------------------------------------------------------- /rankgen/gpt2_generate_contrastive_decoding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tqdm 4 | import json 5 | import torch 6 | import os 7 | import random 8 | 9 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 10 | from utils import form_partitions  # truncate is redefined locally below 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--dataset', default="data/t5_xl_all_domains_wiki_random.jsonl") 14 | parser.add_argument('--output_file', default="data_new/contrastive_decoding/wiki_gpt2_medium_ignore_prefix.tsv") 15 | parser.add_argument('--model_size', default="medium") 16 | parser.add_argument('--num_instances', default=7713, type=int) 17 | parser.add_argument('--num_samples', default=1, type=int) 18 | parser.add_argument('--max_new_tokens', default=115, type=int) 19 | parser.add_argument('--top_k', default=5, type=int) 20 | parser.add_argument('--penalty_alpha', default=0.6, type=float) 21 | parser.add_argument('--truncate_fraction', default=0.0, type=float) 22 | parser.add_argument('--num_shards', default=1, type=int) 23 | parser.add_argument('--local_rank', default=0, type=int) 24 | args = parser.parse_args() 25 | 26 | with open(args.dataset, "r") as f: 27 | data = [json.loads(x) for x in f.read().strip().split("\n")] 28 | 29 | def ignore_prefix_prepare_inputs_for_generation(input_ids, past=None, **kwargs): 30 | 31 | token_type_ids = kwargs.get("token_type_ids", None) 32 | # only last token for input_ids if past is defined in
kwargs 33 | input_ids = input_ids[:, -1].unsqueeze(-1) 34 | if token_type_ids is not None: 35 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 36 | 37 | attention_mask = kwargs.get("attention_mask", None) 38 | position_ids = kwargs.get("position_ids", None) 39 | 40 | if attention_mask is not None and position_ids is None: 41 | # create position_ids on the fly for batch generation 42 | position_ids = attention_mask.long().cumsum(-1) - 1 43 | position_ids.masked_fill_(attention_mask == 0, 1) 44 | position_ids = position_ids[:, -1].unsqueeze(-1) 45 | else: 46 | position_ids = None 47 | 48 | return { 49 | "input_ids": input_ids, 50 | "past_key_values": past, 51 | "use_cache": kwargs.get("use_cache"), 52 | "position_ids": position_ids, 53 | "attention_mask": attention_mask, 54 | "token_type_ids": token_type_ids, 55 | } 56 | 57 | 58 | student_lm = GPT2LMHeadModel.from_pretrained(f"gpt2") 59 | tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2-{args.model_size}") 60 | tokenizer.pad_token = tokenizer.eos_token 61 | model = GPT2LMHeadModel.from_pretrained(f"gpt2-{args.model_size}") 62 | model.cuda() 63 | model.eval() 64 | student_lm.cuda() 65 | 66 | student_lm.prepare_inputs_for_generation = ignore_prefix_prepare_inputs_for_generation 67 | 68 | avg_score = [] 69 | all_score = [] 70 | random.seed(43) 71 | device = "cuda" if torch.cuda.is_available() else "cpu" 72 | 73 | output = "" 74 | suffix_lens = [] 75 | gen_lens = [] 76 | 77 | def postprocess(outputs): 78 | return tokenizer.batch_decode(outputs, skip_special_tokens=True) 79 | 80 | def truncate(text): 81 | """Truncate text to the last full sentence.""" 82 | last_punc = 0 83 | if "." in text: 84 | last_punc = max(last_punc, text.rindex(".")) 85 | if "?" in text: 86 | last_punc = max(last_punc, text.rindex("?")) 87 | if "!" 
in text: 88 | last_punc = max(last_punc, text.rindex("!")) 89 | if last_punc != 0: 90 | text = text[:last_punc + 1] 91 | return text 92 | 93 | if args.num_shards > 1: 94 | partitions = form_partitions(data, args.num_shards) 95 | data = partitions[args.local_rank] 96 | args.output_file = f'{args.output_file}.shard_{args.local_rank}' 97 | 98 | for idx, dd in tqdm.tqdm(enumerate(data), total=min(len(data), args.num_instances)): 99 | if len(suffix_lens) >= args.num_instances: 100 | break 101 | prefix = dd['prefix'] 102 | batch = tokenizer(prefix, truncation=True, padding="longest", return_tensors="pt", max_length=1024 - args.max_new_tokens).to(device) 103 | num_tokens = len(batch['input_ids'][0]) 104 | if num_tokens >= 1024 - args.max_new_tokens - 3: 105 | print("long sequence detected") 106 | 107 | with torch.inference_mode(): 108 | generation = model.generate( 109 | **batch, 110 | temperature=1.0, 111 | top_k=0, 112 | top_p=1.0, 113 | min_prob=0.0, 114 | do_sample=False, 115 | num_beams=5, 116 | max_length=num_tokens + args.max_new_tokens, 117 | num_return_sequences=1, 118 | student_lm=student_lm, 119 | teacher_student=True, 120 | model_kwargs_student={}, 121 | st_coef=1.0, 122 | tokenizer=tokenizer, # analysis 123 | student_min_prob=0.0, 124 | student_temperature=0.5, 125 | use_cap_student=False, #cap student debug 126 | use_switch=False 127 | ) 128 | gen_text = postprocess(generation[:, num_tokens:]) 129 | gen_text = [" ".join(x.split()) for x in gen_text] 130 | gen_text = [truncate(x) for x in gen_text] 131 | 132 | for i in range(len(gen_text)): 133 | if random.random() < args.truncate_fraction: 134 | gen_text[i] = truncate(gen_text[i][:-1]) 135 | 136 | if "suffix" in dd: 137 | suffix_str = dd['suffix'] 138 | else: 139 | suffix_str = dd['targets'][0] 140 | 141 | suffix_lens.append(len(suffix_str.split())) 142 | for x in gen_text: 143 | gen_lens.append(len(x.split())) 144 | output += f"{prefix}\t{suffix_str}\tplaceholder\tplaceholder\n" 145 | for x in gen_text: 146 | output += f"{prefix}\t{x}\tplaceholder\tplaceholder\n" 147 | 148 | if (idx + 1) % 100 == 0: 149 | print(f"Avg suffix length = {np.mean(suffix_lens):.4f} ({len(suffix_lens)} samples), avg gen length = {np.mean(gen_lens):.4f} ({len(gen_lens)} samples)") 150 | with open(args.output_file, "w") as f: 151 | f.write(output) 152 | 153 | with open(args.output_file, "w") as f: 154 | f.write(output) 155 | -------------------------------------------------------------------------------- /rankgen/rankgen_encoder.py: -------------------------------------------------------------------------------- 1 | from transformers import T5Tokenizer, T5EncoderModel, AutoModel 2 | import pickle 3 | import argparse 4 | import numpy as np 5 | import tqdm 6 | import os 7 | import torch 8 | 9 | class RankGenEncoder(): 10 | def __init__(self, model_path, max_batch_size=32, model_size=None, cache_dir=None): 11 | assert model_path in ["kalpeshk2011/rankgen-t5-xl-all", "kalpeshk2011/rankgen-t5-xl-pg19", "kalpeshk2011/rankgen-t5-base-all", "kalpeshk2011/rankgen-t5-large-all"] 12 | self.max_batch_size = max_batch_size 13 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | if model_size is None: 15 | if "t5-large" in model_path or "t5_large" in model_path: 16 | self.model_size = "large" 17 | elif "t5-xl" in model_path or "t5_xl" in model_path: 18 | self.model_size = "xl" 19 | else: 20 | self.model_size = "base" 21 | else: 22 | self.model_size = model_size 23 | 24 | self.tokenizer = T5Tokenizer.from_pretrained(f"google/t5-v1_1-{self.model_size}", 
cache_dir=cache_dir) 25 | self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True) 26 | self.model.to(self.device) 27 | self.model.eval() 28 | 29 | def encode(self, inputs, vectors_type="prefix", verbose=False, return_input_ids=False): 30 | tokenizer = self.tokenizer 31 | max_batch_size = self.max_batch_size 32 | if isinstance(inputs, str): 33 | inputs = [inputs] 34 | if vectors_type == 'prefix': 35 | inputs = ['pre ' + input for input in inputs] 36 | max_length = 512 37 | else: 38 | inputs = ['suffi ' + input for input in inputs] 39 | max_length = 128 40 | 41 | all_embeddings = [] 42 | all_input_ids = [] 43 | for i in tqdm.tqdm(range(0, len(inputs), max_batch_size), total=(len(inputs) // max_batch_size) + 1, disable=not verbose, desc=f"Encoding {vectors_type} inputs:"): 44 | tokenized_inputs = tokenizer(inputs[i:i + max_batch_size], return_tensors="pt", padding=True) 45 | for k, v in tokenized_inputs.items(): 46 | tokenized_inputs[k] = v[:, :max_length] 47 | tokenized_inputs = tokenized_inputs.to(self.device) 48 | with torch.inference_mode(): 49 | batch_embeddings = self.model(**tokenized_inputs) 50 | all_embeddings.append(batch_embeddings) 51 | if return_input_ids: 52 | all_input_ids.extend(tokenized_inputs.input_ids.cpu().tolist()) 53 | return { 54 | "embeddings": torch.cat(all_embeddings, dim=0), 55 | "input_ids": all_input_ids 56 | } 57 | 58 | 59 | class T5XEmbeddingGeneratorLegacy(): 60 | '''This class is deprecated, use RankGenEncoder.''' 61 | 62 | def __init__(self, max_batch_size=32, model_path='.', model_size=None, cache_dir=None): 63 | self.max_batch_size = max_batch_size 64 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 65 | if model_size is None: 66 | if "t5_large" in model_path: 67 | self.model_size = "large" 68 | elif "t5_xl" in model_path: 69 | self.model_size = "xl" 70 | else: 71 | self.model_size = "base" 72 | else: 73 | self.model_size = model_size 74 | 75 | with open(os.path.join(model_path, 'state_dict.pickle'), 'rb') as handle: 76 | state_dict = pickle.load(handle) 77 | 78 | state_dict_new = {} 79 | 80 | for k, v in state_dict.items(): 81 | if k != "encoder.embed_tokens.weight": 82 | v = np.transpose(v) 83 | state_dict_new[k] = torch.Tensor(v) 84 | else: 85 | state_dict_new[k] = torch.Tensor(v) 86 | state_dict_new["shared.weight"] = torch.Tensor(v) 87 | 88 | with open(os.path.join(model_path, 'projection.pickle'), 'rb') as handle: 89 | self.projection = torch.Tensor(pickle.load(handle)) # (1024, 1024), numpy array 90 | 91 | self.projection = self.projection.to(self.device) 92 | self.tokenizer = T5Tokenizer.from_pretrained(f"google/t5-v1_1-{self.model_size}", cache_dir=cache_dir) 93 | model = T5EncoderModel.from_pretrained(f"google/t5-v1_1-{self.model_size}", cache_dir=cache_dir) 94 | state_dict_keys = [k for k in state_dict_new.keys()] 95 | self.model, _, _, _, _ = T5EncoderModel._load_pretrained_model(model, state_dict_new, state_dict_keys, None, f"google/t5-v1_1-{self.model_size}") 96 | self.model.to(self.device) 97 | self.model.eval() 98 | 99 | def encode(self, inputs, vectors_type="prefix", verbose=False, return_input_ids=False): 100 | tokenizer = self.tokenizer 101 | max_batch_size = self.max_batch_size 102 | if isinstance(inputs, str): 103 | inputs = [inputs] 104 | if vectors_type == 'prefix': 105 | inputs = ['pre ' + input for input in inputs] 106 | max_length = 512 107 | else: 108 | inputs = ['suffi ' + input for input in inputs] 109 | max_length = 128 110 | 111 | all_embeddings = [] 112 | all_input_ids = [] 113 | for i in 
tqdm.tqdm(range(0, len(inputs), max_batch_size), total=(len(inputs) // max_batch_size) + 1, disable=not verbose, desc=f"Encoding {vectors_type} inputs:"): 114 | tokenized_inputs = tokenizer(inputs[i:i + max_batch_size], return_tensors="pt", padding=True) 115 | for k, v in tokenized_inputs.items(): 116 | tokenized_inputs[k] = v[:, :max_length] 117 | tokenized_inputs = tokenized_inputs.to(self.device) 118 | with torch.no_grad(): 119 | hidden_states = self.model(**tokenized_inputs).last_hidden_state 120 | hidden_states = hidden_states[:, 0, :] 121 | batch_embeddings = torch.matmul(hidden_states, self.projection) 122 | all_embeddings.append(batch_embeddings) 123 | if return_input_ids: 124 | all_input_ids.extend(tokenized_inputs.input_ids.cpu().tolist()) 125 | return { 126 | "embeddings": torch.cat(all_embeddings, dim=0), 127 | "input_ids": all_input_ids 128 | } 129 | -------------------------------------------------------------------------------- /rankgen/score_multi_tsv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | import random 5 | import numpy as np 6 | import mauve 7 | import glob 8 | import os 9 | import tqdm 10 | import pickle 11 | from nltk import tokenize 12 | import spacy 13 | from nltk.corpus import stopwords 14 | from utils import f1_score, rep_statistic 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dataset', default="data/multi_outs/t5_xxl_descartes_wiki_greedy.tsv") 18 | parser.add_argument('--num_samples', default=None, type=int) 19 | parser.add_argument('--num_runs', default=1, type=int) 20 | parser.add_argument('--rep_window', default=20, type=int) 21 | parser.add_argument('--num_instances', default=7713, type=int) 22 | parser.add_argument('--eval_mauve', action='store_true') 23 | parser.add_argument('--eval_pos_overlap', action='store_true') 24 | args = parser.parse_args() 25 | 26 | 27 | files = glob.glob(args.dataset) 28 | 29 | base_dir = os.path.dirname(files[0]) 30 | assert all([os.path.dirname(x) == base_dir for x in files]) 31 | files = ['pg19_gpt2_medium.tsv', 'wiki_gpt2_medium.tsv', 'pg19_gpt2_xl.tsv', 'wiki_gpt2_xl.tsv', 32 | 'pg19_t5_xxl.tsv', 'wiki_t5_xxl.tsv', 'pg19_t5_xxl_descartes.tsv', 'wiki_t5_xxl_descartes.tsv'] 33 | files = [os.path.join(base_dir, f) for f in files] 34 | 35 | latex_token_overlap = [] 36 | latex_rep_score = [] 37 | latex_token_overlap_ents = [] 38 | latex_mauve = [] 39 | 40 | 41 | if args.eval_pos_overlap: 42 | nlp = spacy.load("en_core_web_sm") 43 | 44 | for file in files: 45 | if not os.path.exists(file): 46 | continue 47 | print(file) 48 | with open(file, 'r') as f: 49 | data = [x.split('\t') for x in f.read().strip().split("\n")] 50 | data_dict = {dd[0]: dd[1] for dd in data} 51 | 52 | if "wiki_" in file: 53 | with open("data_new/ppl/wiki_t5_xxl.jsonl", "r") as f: 54 | raw_inp_data = [json.loads(x) for x in f.read().strip().split("\n")] 55 | for rid in tqdm.tqdm(raw_inp_data): 56 | assert rid["prefix"] in data_dict 57 | elif "pg19_" in file: 58 | with open("data_new/ppl/pg19_t5_xxl.jsonl", "r") as f: 59 | raw_inp_data = [json.loads(x) for x in f.read().strip().split("\n")] 60 | for rid in tqdm.tqdm(raw_inp_data): 61 | assert rid["prefix"] in data_dict 62 | 63 | 64 | if args.num_samples is None: 65 | args.num_samples = (len(data) // args.num_instances) - 1 66 | 67 | assert len(data) % (args.num_samples + 1) == 0 68 | 69 | token_overlaps = { 70 | "human": [], 71 | "random": [] 72 | } 73 | token_overlaps_ents = { 74 | 
"human": [], 75 | "random": [] 76 | } 77 | rep_scores = { 78 | "human": [], 79 | "random": [] 80 | } 81 | 82 | all_mauve = [] 83 | for idx in range(args.num_runs): 84 | all_human = [] 85 | all_gen = [] 86 | for i in tqdm.tqdm(range(0, len(data), args.num_samples + 1)): 87 | gen_suffices = [] 88 | for j in range(1, args.num_samples + 1): 89 | assert data[i][0] == data[i + j][0] 90 | gen_suffices.append(data[i + j][1]) 91 | 92 | random_gen = random.choice(gen_suffices) 93 | all_human.append(data[i][0] + ' ' + data[i][1]) 94 | all_gen.append(data[i][0] + ' ' + random_gen) 95 | 96 | token_overlaps["human"].append( 97 | f1_score(data[i][1], data[i][0], stopwords=stopwords.words('english'))[0] 98 | ) 99 | token_overlaps["random"].append( 100 | f1_score(random_gen, data[i][0], stopwords=stopwords.words('english'))[0] 101 | ) 102 | rep_scores["human"].append(rep_statistic(data[i][0], data[i][1], window=args.rep_window)) 103 | rep_scores["random"].append(rep_statistic(data[i][0], random_gen, window=args.rep_window)) 104 | 105 | if args.eval_pos_overlap and not os.path.exists(file + ".ent_overlap.pkl"): 106 | prefix_nlp = nlp(data[i][0]) 107 | best_nlp = nlp(random_gen) 108 | prefix_ents = " ".join([x.lemma_.lower() for x in prefix_nlp if x.pos_ in ["PROPN", "NUM", "NOUN"]]) 109 | best_ents = " ".join([x.lemma_.lower() for x in best_nlp if x.pos_ in ["PROPN", "NUM", "NOUN"]]) 110 | 111 | token_overlaps_ents["random"].append( 112 | f1_score(best_ents, prefix_ents, stopwords=stopwords.words('english'))[0] 113 | ) 114 | 115 | print(f"Results for {file}...") 116 | print(f"token overlap = {np.mean(token_overlaps['human']):.3f} human, {np.mean(token_overlaps['random']):.3f} random") 117 | print(f"rep = {np.mean(rep_scores['human']):.3f} human, {np.mean(rep_scores['random']):.3f} random") 118 | 119 | latex_token_overlap.append(np.mean(token_overlaps['random'])) 120 | latex_rep_score.append(np.mean(rep_scores['random'])) 121 | 122 | if args.eval_pos_overlap and not os.path.exists(file + ".ent_overlap.pkl"): 123 | latex_token_overlap_ents.append(np.mean(token_overlaps_ents['random'])) 124 | with open(file + ".ent_overlap.pkl", "wb") as f: 125 | pickle.dump(token_overlaps_ents, f) 126 | else: 127 | with open(file + ".ent_overlap.pkl", "rb") as f: 128 | token_overlaps_ents = pickle.load(f) 129 | latex_token_overlap_ents.append(np.mean(token_overlaps_ents['random'])) 130 | 131 | if args.eval_mauve: 132 | mauve_file = file + ".mauve.pkl" 133 | if idx > 0: 134 | mauve_file += f".{idx}" 135 | if os.path.exists(mauve_file): 136 | with open(mauve_file, "rb") as f: 137 | mauve_data = pickle.load(f) 138 | mauve1 = mauve_data["max_gen_mauve"] 139 | else: 140 | mauve1 = mauve.compute_mauve(p_text=all_gen, q_text=all_human, device_id=0, max_text_length=768, verbose=False) 141 | outputs = { 142 | "max_gen_mauve": mauve1 143 | } 144 | with open(mauve_file, "wb") as f: 145 | pickle.dump(outputs, f) 146 | # print(mauve1.mauve) 147 | all_mauve.append(mauve1.mauve) 148 | 149 | if args.eval_mauve: 150 | print(np.mean(all_mauve)) 151 | latex_mauve.append(np.mean(all_mauve)) 152 | 153 | print(f"Latex token overlap = {' & '.join([f'{100 * x:.1f}' for x in latex_token_overlap])} & {np.mean(latex_token_overlap):.3f}") 154 | print(f"Latex rep = {' & '.join([f'{100 * x:.1f}' for x in latex_rep_score])} & {np.mean(latex_rep_score):.3f}") 155 | print(f"Latex token overlap ents = {' & '.join([f'{100 * x:.1f}' for x in latex_token_overlap_ents])} & {np.mean(latex_token_overlap_ents):.3f}") 156 | print(f"Mauve latex = {' & 
'.join([f'{x:.3f}' for x in latex_mauve])} & {np.mean(latex_mauve):.3f}") 157 | -------------------------------------------------------------------------------- /rankgen/rankgen_beam_search_choose_eg.py: -------------------------------------------------------------------------------- 1 | from sys import prefix 2 | from transformers import T5Tokenizer, T5EncoderModel 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | import tqdm 7 | import os 8 | import torch 9 | import random 10 | import time 11 | import json 12 | from functools import partial 13 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 14 | from utils import form_partitions 15 | from rankgen_encoder import RankGenEncoder 16 | from utils import truncate, export_server 17 | from transformers.utils import logging 18 | 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--dataset', default="data_new/t5_xxl_ppl/pg19_gpt2_medium.jsonl", type=str) 23 | parser.add_argument('--num_samples', default=10, type=int) 24 | parser.add_argument('--beam_size', default=1, type=int) 25 | parser.add_argument('--num_tokens', default=40, type=int) 26 | parser.add_argument('--max_length', default=115, type=int) 27 | parser.add_argument('--top_p', default=0.9, type=float) 28 | parser.add_argument('--model_size', default='xl', type=str) 29 | parser.add_argument('--cache_dir', default=None, type=str) 30 | parser.add_argument('--retriever_model_path', default='t5x_conversion/t5_xl_all', type=str) 31 | parser.add_argument('--num_shards', default=1, type=int) 32 | parser.add_argument('--local_rank', default=0, type=int) 33 | parser.add_argument('--output_file', default=None, type=str) 34 | args = parser.parse_args() 35 | 36 | with open(args.dataset, "r") as f: 37 | data = [json.loads(x) for x in f.read().strip().split("\n")] 38 | 39 | if args.num_shards > 1: 40 | partitions = form_partitions(data, args.num_shards) 41 | data = partitions[args.local_rank] 42 | args.output_file = f'{args.output_file}.shard_{args.local_rank}' 43 | 44 | t5x_embedder = RankGenEncoder(model_path=args.retriever_model_path, cache_dir=args.cache_dir) 45 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 46 | 47 | random.seed(49) 48 | random.shuffle(data) 49 | 50 | random.seed(442) 51 | random.shuffle(data) 52 | 53 | folder_name = f"token_bs_t5x" 54 | 55 | device = "cuda" if torch.cuda.is_available() else "cpu" 56 | 57 | tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2-{args.model_size}", cache_dir=args.cache_dir) 58 | tokenizer.pad_token = tokenizer.eos_token 59 | model = GPT2LMHeadModel.from_pretrained(f"gpt2-{args.model_size}", cache_dir=args.cache_dir) 60 | model.to(device) 61 | model.eval() 62 | 63 | 64 | def postprocess(outputs): 65 | return tokenizer.batch_decode(outputs, skip_special_tokens=True) 66 | 67 | 68 | def scorer_t5x(t5x_embedder, prefix, suffixes, prefix_vector=None): 69 | if prefix_vector is None: 70 | prefix_vector = t5x_embedder.encode(prefix, vectors_type="prefix")["embeddings"] 71 | suffix_vectors = t5x_embedder.encode(suffixes, vectors_type="suffix")["embeddings"] 72 | similarities = torch.matmul(prefix_vector, suffix_vectors.t()).squeeze(dim=0) 73 | return similarities, prefix_vector, suffix_vectors 74 | 75 | 76 | def token_beam_search(contexts, scorer, beam_size=3, temperature=1.0, top_p=0.9, num_tokens=5, num_samples=10, max_length=115): 77 | final_outputs = [] 78 | final_scores = [] 79 | total_generated_tokens = 0 80 | output_str = "" 81 | for ctx in contexts: 
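# Token-level beam search with RankGen re-ranking: every `num_tokens` tokens, each
# unfinished beam is extended with `num_samples` sampled continuations, each candidate
# suffix is embedded and scored against the cached prefix vector by dot product, and
# only the `beam_size` highest-scoring candidates survive to the next round.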
82 | if beam_size == 1 and num_samples == 1: 83 | prefix_vector = None 84 | else: 85 | _, prefix_vector, _ = scorer(prefix=ctx, suffixes=[ctx]) 86 | beams = [{ 87 | "text": "", 88 | "eos": False 89 | } for _ in range(beam_size)] 90 | while True: 91 | all_outs = [] 92 | max_new_tokens = min(num_tokens, max_length - total_generated_tokens) 93 | for beam in beams: 94 | # if a beam has ended, add it to all_outs 95 | if beam["eos"]: 96 | all_outs.append(beam) 97 | continue 98 | # otherwise generate the next n tokens 99 | inputs = tokenizer(ctx + beam['text'], truncation=True, padding="longest", 100 | return_tensors="pt", max_length=1024 - max_new_tokens).to(device) 101 | num_input_tokens = len(inputs['input_ids'][0]) 102 | curr_outs = model.generate(**inputs, do_sample=True, output_scores=True, 103 | return_dict_in_generate=True, 104 | max_new_tokens=max_new_tokens, top_k=None, top_p=top_p, 105 | num_return_sequences=num_samples, temperature=temperature) 106 | is_eos = [] 107 | for curr_out in curr_outs['sequences']: 108 | if tokenizer.eos_token_id in curr_out: 109 | is_eos.append(True) 110 | else: 111 | is_eos.append(False) 112 | curr_outs_text = postprocess(curr_outs['sequences'][:, num_input_tokens:]) 113 | for text, eos in zip(curr_outs_text, is_eos): 114 | # update all_outs 115 | all_outs.append({ 116 | "text": beam["text"] + text, 117 | "eos": eos 118 | }) 119 | # Each beam has total_generated_tokens length 120 | total_generated_tokens += max_new_tokens 121 | if len(all_outs) > 1: 122 | # skip beam scoring if only one output to choose from 123 | scores, _, _ = scorer(prefix=ctx, suffixes=[x["text"] for x in all_outs], prefix_vector=prefix_vector) 124 | top_scores, top_indices = torch.topk(scores, k=beam_size) 125 | if len(beams[0]["text"]) == 0: 126 | output_str = f"Prefix = {contexts[0]}\n\n" 127 | for sc, text in zip(scores, all_outs): 128 | output_str += f"{sc} = {text['text']}\n\n" 129 | return output_str 130 | beams = [all_outs[x] for x in top_indices] # only track the top k beams 131 | else: 132 | top_scores = torch.Tensor([1.0]) 133 | top_scores = top_scores.to(device)  # .cuda() returns a copy, so the result must be assigned 134 | beams = all_outs 135 | 136 | for beam in beams: 137 | if len(tokenizer.tokenize(beam["text"])) >= max_length: 138 | beam["eos"] = True 139 | 140 | if all([x["eos"] for x in beams]): 141 | final_outputs.append([x["text"] for x in beams]) 142 | final_scores.append(top_scores) 143 | break 144 | return output_str  # always return a string, since the caller passes this to export_server 145 | scorer_fn = partial(scorer_t5x, t5x_embedder=t5x_embedder) 146 | 147 | outputs = [] 148 | 149 | target_seq_len = [] 150 | gen_seq_len = [] 151 | 152 | logging.set_verbosity_error() 153 | 154 | random.seed(time.time()) 155 | random.shuffle(data) 156 | 157 | for kk, instance in tqdm.tqdm(enumerate(data), total=len(data)): 158 | 159 | output_str = token_beam_search(contexts=[instance["prefix"]], scorer=scorer_fn, beam_size=args.beam_size, 160 | top_p=args.top_p, num_tokens=args.num_tokens, num_samples=args.num_samples, 161 | max_length=args.max_length) 162 | 163 | export_server(output_str, f"rankgen_beam_samples/pg19_xl/{kk}") 164 | -------------------------------------------------------------------------------- /rankgen/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | import string 4 | import pickle 5 | import collections as cll 6 | import torch 7 | import numpy as np 8 | import subprocess 9 | 10 | 11 | def cudafy_tokens(tokens): 12 | for x, y in tokens.items(): 13 | tokens[x] = y.cuda() 14 | return tokens 15 | 16 | 17 | def extend_sequence(sents, sent_lens,
start, limit, exceed_len=False, 18 | direction='prefix', skip_sentences=None): 19 | """Extend a sequence by adding more sentences in prefix or suffix.""" 20 | curr_value = start 21 | total_length = sent_lens[curr_value] 22 | full_sequence = sents[curr_value] 23 | assert len(sents) == len(sent_lens) 24 | 25 | if direction == 'prefix': 26 | increment = -1 27 | concat_fn = lambda curr, extra: extra + ' ' + curr 28 | continue_fn = lambda x: x >= 0 29 | else: 30 | increment = 1 31 | concat_fn = lambda curr, extra: curr + ' ' + extra 32 | continue_fn = lambda x: x < len(sents) 33 | 34 | while total_length < limit and continue_fn(curr_value + increment): 35 | proposed_length = total_length + sent_lens[curr_value + increment] 36 | if not exceed_len and proposed_length > limit: 37 | break 38 | if skip_sentences and (curr_value + increment) in skip_sentences: 39 | break 40 | curr_value += increment 41 | full_sequence = concat_fn(curr=full_sequence, extra=sents[curr_value]) 42 | total_length += sent_lens[curr_value] 43 | 44 | if direction == 'prefix': 45 | assert curr_value <= start 46 | else: 47 | assert curr_value >= start 48 | 49 | return full_sequence, curr_value 50 | 51 | 52 | class Bcolors: 53 | HEADER = '\033[95m' 54 | OKBLUE = '\033[94m' 55 | OKGREEN = '\033[92m' 56 | WARNING = '\033[93m' 57 | FAIL = '\033[91m' 58 | ENDC = '\033[0m' 59 | BOLD = '\033[1m' 60 | UNDERLINE = '\033[4m' 61 | 62 | @classmethod 63 | def postprocess(cls, input_str): 64 | input_str = input_str.replace("", cls.HEADER) 65 | input_str = input_str.replace("", cls.OKBLUE) 66 | input_str = input_str.replace("", cls.OKGREEN) 67 | input_str = input_str.replace("", cls.WARNING) 68 | input_str = input_str.replace("", cls.FAIL) 69 | input_str = input_str.replace("", cls.ENDC) 70 | input_str = input_str.replace("", cls.BOLD) 71 | input_str = input_str.replace("", cls.UNDERLINE) 72 | input_str = input_str.replace("", "") 73 | return input_str 74 | 75 | 76 | def export_server(output, filename): 77 | with open("{}.txt".format(filename), "w") as f: 78 | f.write(Bcolors.postprocess(output) + "\n") 79 | subprocess.check_output("cat {0}.txt | ansi2html.sh --palette=linux --bg=dark > {0}.html".format(filename), shell=True) 80 | subprocess.check_output("rm {}.txt".format(filename), shell=True) 81 | 82 | 83 | def form_partitions(dataset, num_shards): 84 | p_indices = np.round(np.linspace(0, len(dataset), num_shards + 1)) 85 | p_indices = [int(x) for x in p_indices] 86 | partitions = [dataset[p_indices[i]:p_indices[i + 1]] for i in range(len(p_indices) - 1)] 87 | assert len(partitions) == num_shards 88 | return partitions 89 | 90 | 91 | def truncate(text): 92 | last_punc = 0 93 | if "." in text: 94 | last_punc = max(last_punc, text.rindex(".")) 95 | if "?" in text: 96 | last_punc = max(last_punc, text.rindex("?")) 97 | if "!" 
in text: 98 | last_punc = max(last_punc, text.rindex("!")) 99 | if ";" in text: 100 | last_punc = max(last_punc, text.rindex(";")) 101 | if last_punc != 0: 102 | text = text[:last_punc + 1] 103 | return text 104 | 105 | 106 | def execute_gpt2(relevant_window, text_token_ids, tokenizer, model, output_hidden_states=False): 107 | num_ans_tokens = len(text_token_ids[0]) 108 | inputs = tokenizer(" " + relevant_window, return_tensors="pt") 109 | inputs = cudafy_tokens(inputs) 110 | assert torch.equal( 111 | inputs["input_ids"][0, -1 * num_ans_tokens:], 112 | text_token_ids[0] 113 | ) 114 | 115 | with torch.no_grad(): 116 | outputs = model(**inputs, labels=inputs["input_ids"], output_hidden_states=output_hidden_states) 117 | 118 | text_logits = outputs["logits"][0, -1 * num_ans_tokens - 1:-1, :] 119 | 120 | text_softmax = torch.nn.functional.softmax(text_logits, dim=1) 121 | softmax_ranks = torch.argsort(text_softmax, dim=1, descending=True) 122 | text_probs = torch.gather(text_softmax, 1, text_token_ids.t()) 123 | log_probs = torch.log(text_probs) 124 | ppl = torch.exp(-1 * log_probs.sum() / num_ans_tokens).item() 125 | 126 | ranks = [softmax_ranks[i].tolist().index(text_token_ids[0][i].item()) + 1 for i in range(num_ans_tokens)] 127 | 128 | if output_hidden_states: 129 | return outputs["hidden_states"][-1][0, -1 * num_ans_tokens - 1:-1, :] 130 | else: 131 | return ranks, text_probs.squeeze().cpu().numpy(), log_probs.squeeze().cpu().numpy(), ppl 132 | 133 | def normalize_answer(s): 134 | """Lower text and remove punctuation, articles and extra whitespace.""" 135 | 136 | def remove_articles(text): 137 | return re.sub(r'\b(a|an|the)\b', ' ', text) 138 | 139 | def white_space_fix(text): 140 | return ' '.join(text.split()) 141 | 142 | def remove_punc(text): 143 | exclude = set(string.punctuation) 144 | return ''.join(ch for ch in text if ch not in exclude) 145 | 146 | def lower(text): 147 | return text.lower() 148 | 149 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 150 | 151 | 152 | def rep_statistic(prefix, suffix, window=20): 153 | prefix_tokens = normalize_answer(prefix).split() 154 | suffix_tokens = normalize_answer(suffix).split() 155 | start_pos = len(prefix_tokens) 156 | tokens = prefix_tokens + suffix_tokens 157 | reps = [tokens[i] in tokens[i - window:i] for i in range(start_pos, len(tokens))] 158 | if len(reps) == 0: 159 | return 0.0 160 | else: 161 | return np.mean(reps) 162 | 163 | 164 | def f1_score(prediction, ground_truth, gram=1, stopwords=None): 165 | """Calculate word level F1 score.""" 166 | prediction = normalize_answer(prediction) 167 | ground_truth = normalize_answer(ground_truth) 168 | prediction_tokens = prediction.split() 169 | ground_truth_tokens = ground_truth.split() 170 | prediction_tokens = [ 171 | " ".join(prediction_tokens[i:i + gram]) 172 | for i in range(0, len(prediction_tokens) - gram + 1) 173 | ] 174 | ground_truth_tokens = [ 175 | " ".join(ground_truth_tokens[i:i + gram]) 176 | for i in range(0, len(ground_truth_tokens) - gram + 1) 177 | ] 178 | 179 | if stopwords: 180 | prediction_tokens = [x for x in prediction_tokens if x not in stopwords] 181 | ground_truth_tokens = [x for x in ground_truth_tokens if x not in stopwords] 182 | 183 | if not prediction_tokens and not ground_truth_tokens: 184 | return 1.0, 1.0, 1.0 185 | common = cll.Counter(prediction_tokens) & cll.Counter(ground_truth_tokens) 186 | num_same = sum(common.values()) 187 | if num_same == 0: 188 | return 0, 0, 0 189 | precision = 1.0 * num_same / len(prediction_tokens) 190 | 
recall = 1.0 * num_same / len(ground_truth_tokens) 191 | f1 = (2 * precision * recall) / (precision + recall) 192 | return precision, recall, f1 193 | 194 | def pickle_load(file): 195 | with open(file, "rb") as f: 196 | data = pickle.load(f) 197 | return data 198 | 199 | 200 | def pickle_dump(file, data): 201 | with open(file, "wb") as f: 202 | pickle.dump(data, f) 203 | -------------------------------------------------------------------------------- /rankgen/gpt2_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | from lib2to3.pgen2 import token 4 | import numpy as np 5 | import tqdm 6 | import json 7 | import torch 8 | import os 9 | 10 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 11 | from utils import execute_gpt2, cudafy_tokens 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--dataset', default="rankgen_data/t5_xl_all_domains_pg19_random.jsonl") 15 | parser.add_argument('--model_size', default="medium") 16 | parser.add_argument('--metric', default="avg_conditional") 17 | parser.add_argument('--num_negatives', default=10, type=int) 18 | parser.add_argument('--max_examples', default=7713, type=int) 19 | parser.add_argument('--batch_size', default=6, type=int) 20 | args = parser.parse_args() 21 | 22 | tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2-{args.model_size}") 23 | tokenizer.pad_token = tokenizer.eos_token 24 | model = GPT2LMHeadModel.from_pretrained(f"gpt2-{args.model_size}") 25 | model.cuda() 26 | model.eval() 27 | 28 | args.output_path = f"gold-beats-neg-outputs/{os.path.basename(args.dataset)}_{args.model_size}_{args.metric}.txt" 29 | 30 | avg_score = [] 31 | all_score = [] 32 | 33 | def compute_gpt2(sequences, prefix=None, length_normalize=True): 34 | with torch.inference_mode(): 35 | inputs = cudafy_tokens(tokenizer(sequences, return_tensors="pt", padding=True, truncation=True)) 36 | outputs = model(**inputs) 37 | out_log_probs = torch.nn.functional.log_softmax(outputs["logits"], dim=-1) 38 | gold_log_probs = torch.gather(out_log_probs[:, :-1, :], 2, inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze() 39 | token_mask = inputs['input_ids'][:, 1:] != tokenizer.pad_token_id 40 | 41 | if prefix: 42 | num_prefix_toks = len(tokenizer(prefix)['input_ids']) 43 | gold_log_probs = gold_log_probs[:, num_prefix_toks - 1:] 44 | token_mask = token_mask[:, num_prefix_toks - 1:] 45 | 46 | gold_log_probs = gold_log_probs * token_mask 47 | if length_normalize: 48 | perplexities = torch.exp(-1 * gold_log_probs.sum(dim=1) / token_mask.sum(dim=1)) 49 | else: 50 | perplexities = -1 * gold_log_probs.sum(dim=1) 51 | perplexities = perplexities.cpu().tolist() 52 | return perplexities 53 | 54 | def compute_pmi(sequences, suffixes): 55 | with torch.inference_mode(): 56 | inputs = cudafy_tokens(tokenizer(sequences, return_tensors="pt", padding=True, truncation=True)) 57 | outputs = model(**inputs) 58 | out_log_probs = torch.nn.functional.log_softmax(outputs["logits"], dim=-1) 59 | gold_log_probs = torch.gather(out_log_probs[:, :-1, :], 2, inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze() 60 | token_mask = inputs['input_ids'][:, 1:] != tokenizer.pad_token_id 61 | gold_log_probs = gold_log_probs * token_mask 62 | 63 | inputs = cudafy_tokens(tokenizer(suffixes, return_tensors="pt", padding=True, truncation=True)) 64 | outputs = model(**inputs) 65 | out_log_probs = torch.nn.functional.log_softmax(outputs["logits"], dim=-1) 66 | gold_log_probs2 = torch.gather(out_log_probs[:, :-1, :], 2, 
inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze() 67 | token_mask2 = inputs['input_ids'][:, 1:] != tokenizer.pad_token_id 68 | gold_log_probs2 = gold_log_probs2 * token_mask2 69 | 70 | perplexities = gold_log_probs2.sum(dim=1) - gold_log_probs.sum(dim=1) 71 | perplexities = perplexities.cpu().tolist() 72 | return perplexities 73 | 74 | if args.dataset.endswith(".jsonl"): 75 | with open(args.dataset, "r") as f: 76 | data = [json.loads(x) for x in f.read().strip().split("\n")] 77 | 78 | if os.path.exists(args.output_path): 79 | with open(args.output_path, "r") as f: 80 | outputs = [x for x in f.read().strip().split("\n")] 81 | else: 82 | outputs = [] 83 | 84 | for idx, dd in tqdm.tqdm(enumerate(data[:args.max_examples]), total=args.max_examples): 85 | if idx < len(outputs): 86 | continue 87 | prefix = dd['prefix'] 88 | if 'targets' in dd: 89 | candidates = dd['targets'] 90 | else: 91 | candidates = [dd['suffix']] + dd['negatives'] 92 | assert len(candidates) == args.num_negatives + 1 93 | sequences = [prefix.strip() + " " + x.strip() for x in candidates] 94 | perplexities = [] 95 | for i in range(0, len(sequences), args.batch_size): 96 | batch_seq = sequences[i:i + args.batch_size] 97 | batch_suffix = [x.strip() for x in candidates[i:i + args.batch_size]] 98 | if args.metric == "avg_conditional": 99 | perplexities += compute_gpt2(batch_seq, dd['prefix']) 100 | elif args.metric == "pmi": 101 | perplexities += compute_pmi(batch_seq, batch_suffix) 102 | elif args.metric == "avg_unconditional": 103 | perplexities += compute_gpt2(batch_seq) 104 | elif args.metric == "conditional": 105 | perplexities += compute_gpt2(batch_seq, prefix, length_normalize=False) 106 | 107 | avg_score.append(np.mean([perplexities[0] < y for y in perplexities[1:]])) 108 | all_score.append(all([perplexities[0] < y for y in perplexities[1:]])) 109 | 110 | if (idx + 1) % 100 == 0: 111 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") 112 | 113 | outputs.append(json.dumps({ 114 | "prefix": prefix, 115 | "targets": candidates, 116 | "scores": [-1 * x for x in perplexities] 117 | })) 118 | 119 | if idx % 100 == 0: 120 | with open(args.output_path, "w") as f: 121 | f.write("\n".join(outputs) + "\n") 122 | 123 | with open(args.output_path, "w") as f: 124 | f.write("\n".join(outputs) + "\n") 125 | 126 | elif args.dataset.endswith(".tsv"): 127 | with open(args.dataset, "r") as f: 128 | data = [x.split("\t") for x in f.read().strip().split("\n")] 129 | 130 | outputs = [] 131 | if args.output_path: 132 | output_path = args.output_path 133 | else: 134 | output_path = args.dataset + ".ppl_scores" 135 | 136 | if os.path.exists(output_path): 137 | with open(output_path, 'r') as f: 138 | outputs = [x for x in f.read().strip().split("\n")] 139 | 140 | for dd in outputs: 141 | dd = json.loads(dd) 142 | avg_score.append(np.mean([dd['scores'][0] > y for y in dd['scores'][1:]])) 143 | all_score.append(all([dd['scores'][0] > y for y in dd['scores'][1:]])) 144 | 145 | for idx in tqdm.tqdm(range(len(outputs) * (args.num_negatives + 1), len(data), args.num_negatives + 1)): 146 | prefix = data[idx][0] 147 | candidates = [] 148 | for jdx in range(args.num_negatives + 1): 149 | assert data[idx + jdx][0] == prefix 150 | candidates.append(data[idx + jdx][1]) 151 | assert len(candidates) == args.num_negatives + 1 152 | sequences = [prefix.strip() + " " + x.strip() for x in candidates] 153 | 154 | perplexities = [] 155 | for i in range(0, len(sequences), 
args.batch_size): 156 | batch_seq = sequences[i:i + args.batch_size] 157 | batch_suffix = [x.strip() for x in candidates[i:i + args.batch_size]] 158 | if args.metric == "avg_conditional": 159 | perplexities += compute_gpt2(batch_seq, prefix) 160 | elif args.metric == "pmi": 161 | perplexities += compute_pmi(batch_seq, batch_suffix) 162 | elif args.metric == "avg_unconditional": 163 | perplexities += compute_gpt2(batch_seq) 164 | elif args.metric == "conditional": 165 | perplexities += compute_gpt2(batch_seq, prefix, length_normalize=False) 166 | 167 | avg_score.append(np.mean([perplexities[0] < y for y in perplexities[1:]])) 168 | all_score.append(all([perplexities[0] < y for y in perplexities[1:]])) 169 | 170 | assert len(candidates) == len(perplexities) 171 | outputs.append(json.dumps({ 172 | "prefix": prefix, 173 | "targets": candidates, 174 | "scores": [-1 * x for x in perplexities] 175 | })) 176 | 177 | if len(avg_score) % 100 == 0: 178 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") 179 | with open(output_path, "w") as f: 180 | f.write("\n".join(outputs) + "\n") 181 | 182 | with open(output_path, "w") as f: 183 | f.write("\n".join(outputs) + "\n") 184 | 185 | print(f"{np.mean(avg_score):.4f} average ({len(avg_score)} instances), {np.mean(all_score):.4f} all ({len(all_score)} instances)") 186 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## RankGen - Improving Text Generation with Large Ranking Models 2 | 3 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) 4 | [![arxiv](https://img.shields.io/badge/arXiv-2205.09726-b31b1b.svg)](https://arxiv.org/abs/2205.09726) 5 | [![PyPI version rankgen](https://badge.fury.io/py/rankgen.svg)](https://pypi.python.org/pypi/rankgen/) [![License: Apache 2.0](https://img.shields.io/badge/License-Apache--2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 6 | [![Downloads](https://pepy.tech/badge/rankgen)](https://pepy.tech/project/rankgen) 7 | 8 | This is the official repository for our EMNLP 2022 paper, [RankGen - Improving Text Generation with Large Ranking Models](https://arxiv.org/abs/2205.09726). RankGen is a 1.2-billion-parameter encoder model which maps prefixes and generations from any pretrained English language model to a shared vector space. RankGen can be used to rerank multiple full-length samples from an LM, and it can also be incorporated as a scoring function into beam search to significantly improve generation quality (0.85 vs 0.77 [MAUVE](https://arxiv.org/abs/2102.01454), 75% preference according to human annotators who are English writers). RankGen can also be used like a dense retriever, and achieves state-of-the-art performance on [literary retrieval](https://relic.cs.umass.edu/leaderboard.html). 9 | 10 | This repository contains human evaluation data, links to HuggingFace-compatible model checkpoints, and code to integrate RankGen in beam search on HuggingFace models. RankGen is trained by fine-tuning the T5-XL encoder using the [T5X library](https://github.com/google-research/t5x). 11 | 12 | ### Updates 13 | 14 | * (Mar 2023) The training data for RankGen is now available (PG19 and Wiki splits)!
You can get them on Google Cloud (`gs://gresearch/rankgen/rankgen_pp_wiki_v1.zip`, `gs://gresearch/rankgen/rankgen_pp_pg19_v1.zip`) 15 | * (Nov 2022) We have updated our [arXiv version](https://arxiv.org/abs/2205.09726) to show that RankGen beats newer decoding strategies like contrastive search, contrastive decoding and eta sampling! 16 | * (July 2022) RankGen is now a [PyPI package](https://pypi.org/project/rankgen), just run `pip install rankgen` to use it! 17 | * (July 2022) RankGen checkpoints are now available on the HuggingFace Model Hub ([link](https://huggingface.co/kalpeshk2011))! 18 | 19 | ### Model checkpoints 20 | 21 | All RankGen checkpoints are available on the HuggingFace Model Hub - [link](https://huggingface.co/kalpeshk2011) 22 | 23 | We recommend using `RankGen-XL-all`. 24 | 25 | | Checkpoint | Size | Hub Model Name | HF Hub Link | 26 | |-------------------|------|-----------------------------------|------------------------------------------------------------------| 27 | | RankGen-base-all | 0.1B | kalpeshk2011/rankgen-t5-base-all | [link](https://huggingface.co/kalpeshk2011/rankgen-t5-base-all) | 28 | | RankGen-large-all | 0.3B | kalpeshk2011/rankgen-t5-large-all | [link](https://huggingface.co/kalpeshk2011/rankgen-t5-large-all) | 29 | | RankGen-XL-all | 1.2B | kalpeshk2011/rankgen-t5-xl-all | [link](https://huggingface.co/kalpeshk2011/rankgen-t5-xl-all) | 30 | | RankGen-XL-PG19 | 1.2B | kalpeshk2011/rankgen-t5-xl-pg19 | [link](https://huggingface.co/kalpeshk2011/rankgen-t5-xl-pg19) | 31 | 32 | *Older versions of the checkpoints*: 33 | 34 | RankGen XL checkpoints compatible with `T5XEmbeddingGeneratorLegacy` - [here](https://drive.google.com/drive/folders/1m8ujkAqkBBWYAJISZigz1Lw4tQGbZXaY?usp=sharing) 35 | 36 | T5X JAX checkpoints (base, large, XL) - [here](https://github.com/google-research/google-research/tree/master/rankgen) 37 | 38 | ### Setup 39 | 40 | **Requirements** (`pip` will install these dependencies for you) 41 | 42 | Python 3.7+, `torch` (CUDA recommended), `transformers` 43 | 44 | **Installation** 45 | 46 | (from PyPI) 47 | 48 | ``` 49 | python3.7 -m virtualenv rankgen-venv 50 | source rankgen-venv/bin/activate 51 | pip install rankgen 52 | ``` 53 | 54 | (from source) 55 | 56 | ``` 57 | python3.7 -m virtualenv rankgen-venv 58 | source rankgen-venv/bin/activate 59 | git clone https://github.com/martiansideofthemoon/rankgen 60 | cd rankgen 61 | pip install --editable . 62 | ``` 63 | 64 | **Data Download / Test** 65 | 66 | Get the data [here](https://drive.google.com/drive/folders/1DRG2ess7fK3apfB-6KoHb_azMuHbsIv4?usp=sharing) and place folder in root directory. Alternatively, use `gdown` as shown below, 67 | 68 | ``` 69 | gdown --folder https://drive.google.com/drive/folders/1DRG2ess7fK3apfB-6KoHb_azMuHbsIv4 70 | ``` 71 | 72 | Run the test script to make sure the RankGen checkpoint has loaded correctly, 73 | 74 | ``` 75 | python -m rankgen.test_rankgen_encoder --model_path kalpeshk2011/rankgen-t5-base-all 76 | 77 | ### Expected output 78 | 0.0009239262409127233 79 | 0.0011521980725477804 80 | ``` 81 | 82 | ### Using RankGen 83 | 84 | Loading RankGen is simple using the HuggingFace APIs, but we suggest using [`RankGenEncoder`](rankgen/rankgen_encoder.py), which is a small wrapper around the HuggingFace APIs for correctly preprocessing data and doing tokenization automatically. Please see [`rankgen/test_rankgen_encoder.py`](rankgen/test_rankgen_encoder.py) for an example of the usage or see below. 
85 | 
86 | ```
87 | from rankgen import RankGenEncoder, RankGenGenerator
88 | 
89 | rankgen_encoder = RankGenEncoder("kalpeshk2011/rankgen-t5-xl-all")
90 | ```
91 | 
92 | **Encoding text to prefix/suffix vectors**
93 | 
94 | ```
95 | prefix_vectors = rankgen_encoder.encode(["This is a prefix sentence."], vectors_type="prefix")
96 | suffix_vectors = rankgen_encoder.encode(["This is a suffix sentence."], vectors_type="suffix")
97 | ```
98 | 
99 | **Generating text**
100 | 
101 | ```
102 | # use a HuggingFace compatible language model
103 | generator = RankGenGenerator(rankgen_encoder=rankgen_encoder, language_model="gpt2-medium")
104 | 
105 | inputs = ["Whatever might be the nature of the tragedy it would be over with long before this, and those moving black spots away yonder to the west, that he had discerned from the bluff, were undoubtedly the departing raiders. There was nothing left for Keith to do except determine the fate of the unfortunates, and give their bodies decent burial. That any had escaped, or yet lived, was altogether unlikely, unless, perchance, women had been in the party, in which case they would have been borne away prisoners."]
106 | 
107 | # Baseline nucleus sampling
108 | print(generator.generate_single(inputs, top_p=0.9)[0][0])
109 | # Over-generate and re-rank
110 | print(generator.overgenerate_rerank(inputs, top_p=0.9, num_samples=10)[0][0])
111 | # Beam search
112 | print(generator.beam_search(inputs, top_p=0.9, num_samples=10, beam_size=2)[0][0])
113 | ```
114 | 
115 | ### Reproducing experiments in the paper
116 | 
117 | **Running beam search with RankGen**
118 | 
119 | The main file is [`rankgen/rankgen_beam_search.py`](rankgen/rankgen_beam_search.py). To execute it,
120 | 
121 | ```
122 | python rankgen/rankgen_beam_search.py \
123 |     --dataset rankgen_data/wiki.jsonl \
124 |     --rankgen_encoder kalpeshk2011/rankgen-t5-xl-all \
125 |     --num_tokens 20 --num_samples 10 --beam_size 2 \
126 |     --output_file outputs_beam/wiki_t5_xl_beam_2_tokens_20_samples_10.jsonl
127 | ```
128 | 
129 | Evaluating using MAUVE (make sure the JSONL file has several thousand generations for reliable MAUVE scores; we used 7713 in our experiments),
130 | 
131 | ```
132 | python rankgen/score_multi_beam.py --dataset outputs_beam/wiki_t5_xl_beam_2_tokens_20_samples_10.jsonl
133 | ```
134 | 
135 | **Suffix Identification with GPT2**
136 | 
137 | The main file is [`rankgen/gpt2_score.py`](rankgen/gpt2_score.py). To execute it,
138 | 
139 | ```
140 | mkdir gold-beats-neg-outputs
141 | python rankgen/gpt2_score.py \
142 |     --dataset rankgen_data/hellaswag_val.tsv \
143 |     --model_size xl \
144 |     --metric avg_conditional \
145 |     --num_negatives 3
146 | ```
147 | 
148 | The corresponding data files can be found in the same Google Drive [folder](https://drive.google.com/drive/folders/1DRG2ess7fK3apfB-6KoHb_azMuHbsIv4?usp=sharing).
149 | 
150 | ### Human evaluation data
151 | 
152 | We conducted our human evaluation on Upwork, hiring English teachers and writers. We performed blind A/B testing between RankGen and nucleus sampling. We also asked our annotators to provide a 1-3 sentence explanation. You can find all the 600 annotations across two files in [`human-eval-data`](human-eval-data).
To compute the evaluation scores run,
153 | 
154 | ```
155 | python rankgen/score_ab_text.py
156 | ```
157 | 
158 | ### Citation Information
159 | If you use RankGen, please cite it as follows:
160 | ```
161 | @inproceedings{rankgen22,
162 | author = {Kalpesh Krishna and Yapei Chang and John Wieting and Mohit Iyyer},
163 | booktitle = {Empirical Methods in Natural Language Processing},
164 | year = {2022},
165 | title = {RankGen: Improving Text Generation with Large Ranking Models},
166 | }
167 | ```
168 | 
--------------------------------------------------------------------------------
/rankgen/rankgen_beam_search_shifting.py:
--------------------------------------------------------------------------------
1 | 
2 | from transformers import T5Tokenizer, T5EncoderModel
3 | import pickle
4 | import argparse
5 | import numpy as np
6 | import tqdm
7 | import os
8 | import torch
9 | import random
10 | import json
11 | import nltk
12 | from nltk.tokenize import sent_tokenize
13 | from functools import partial
14 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
15 | from utils import form_partitions
16 | from rankgen_encoder import RankGenEncoder
17 | 
18 | from transformers.utils import logging
19 | 
20 | nltk.download('punkt')
21 | 
22 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
23 | 
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--dataset', default="data/multi_outs/t5_xxl_descartes_wiki_ppl.jsonl", type=str)
26 | parser.add_argument('--num_samples', default=10, type=int)
27 | parser.add_argument('--beam_size', default=2, type=int)
28 | parser.add_argument('--num_tokens', default=20, type=int)
29 | parser.add_argument('--top_p', default=0.9, type=float)
30 | parser.add_argument('--model_size', default='medium', type=str)
31 | parser.add_argument('--cache_dir', default=None, type=str)
32 | parser.add_argument('--retriever_model_path', default='t5x_conversion/t5_xl_all', type=str)
33 | parser.add_argument('--num_shards', default=1, type=int)
34 | parser.add_argument('--local_rank', default=0, type=int)
35 | parser.add_argument('--output_file', default=None, type=str)
36 | args = parser.parse_args()
37 | 
38 | with open(args.dataset, "r") as f:
39 |     data = [json.loads(x) for x in f.read().strip().split("\n")]
40 | 
41 | if args.num_shards > 1:
42 |     partitions = form_partitions(data, args.num_shards)
43 |     data = partitions[args.local_rank]
44 |     args.output_file = f'{args.output_file}.shard_{args.local_rank}'
45 | 
46 | t5x_embedder = RankGenEncoder(model_path=args.retriever_model_path, cache_dir=args.cache_dir)
47 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
48 | 
49 | random.seed(49)
50 | random.shuffle(data)
51 | 
52 | random.seed(442)
53 | random.shuffle(data)
54 | 
55 | folder_name = "token_bs_t5x"
56 | 
57 | 
58 | 
59 | tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2-{args.model_size}", cache_dir=args.cache_dir)
60 | tokenizer.pad_token = tokenizer.eos_token
61 | model = GPT2LMHeadModel.from_pretrained(f"gpt2-{args.model_size}", cache_dir=args.cache_dir)
62 | model.to(device)
63 | model.eval()
64 | 
65 | 
66 | def postprocess(outputs):
67 |     return tokenizer.batch_decode(outputs, skip_special_tokens=True)
68 | 
69 | 
70 | def truncate(text):  # truncate a generation at its last sentence-ending punctuation mark
71 |     last_punc = 0
72 |     if "." in text:
73 |         last_punc = max(last_punc, text.rindex("."))
74 |     if "?" in text:
75 |         last_punc = max(last_punc, text.rindex("?"))
76 |     if "!"
in text: 77 | last_punc = max(last_punc, text.rindex("!")) 78 | if last_punc != 0: 79 | text = text[:last_punc + 1] 80 | return text 81 | 82 | 83 | def scorer_t5x(t5x_embedder, prefixes, suffixes, prefix_vectors=None): 84 | if prefix_vectors is None: 85 | prefix_vectors = t5x_embedder.encode(prefixes, vectors_type="prefix")["embeddings"] 86 | suffix_vectors = t5x_embedder.encode(suffixes, vectors_type="suffix")["embeddings"] 87 | matmul = torch.matmul(prefix_vectors, suffix_vectors.t()).squeeze(dim=0) 88 | similarities = torch.from_numpy(np.diagonal(matmul.cpu().numpy())) 89 | return similarities, prefix_vectors, suffix_vectors 90 | 91 | 92 | def token_beam_search(contexts, scorer, beam_size=3, temperature=1.0, top_p=0.9, num_tokens=5, num_samples=10, 93 | max_length=115): 94 | final_outputs = [] 95 | final_scores = [] 96 | total_generated_tokens = 0 97 | for ctx in contexts: 98 | ctx_tokens = tokenizer(ctx, truncation=True, padding="longest", return_tensors="pt")['input_ids'] 99 | beams = [{ 100 | "prefix_text": ctx, 101 | "prefix_tokens": ctx_tokens, 102 | "text": "", 103 | "eos": False 104 | } for _ in range(beam_size)] 105 | while True: 106 | all_outs = [] 107 | max_new_tokens = min(num_tokens, max_length - total_generated_tokens) 108 | for beam in beams: 109 | # if a beam has ended, add it to all_outs 110 | if beam["eos"]: 111 | all_outs.append(beam) 112 | continue 113 | # otherwise generate the next n tokens 114 | inputs = beam['prefix_tokens'].to(device) 115 | num_input_tokens = inputs.size()[1] 116 | curr_outs = model.generate(inputs, do_sample=True, output_scores=True, 117 | return_dict_in_generate=True, 118 | max_new_tokens=max_new_tokens, top_k=None, top_p=top_p, 119 | num_return_sequences=num_samples, temperature=temperature) 120 | is_eos = [] 121 | for curr_out in curr_outs['sequences']: 122 | if tokenizer.eos_token_id in curr_out: 123 | is_eos.append(True) 124 | else: 125 | is_eos.append(False) 126 | curr_outs_text = postprocess(curr_outs['sequences'][:, num_input_tokens:]) 127 | for tokens, text, eos in zip(curr_outs['sequences'][:, num_input_tokens:], curr_outs_text, is_eos): 128 | curr_outs_sents = sent_tokenize(text) 129 | # if a full sentence has been generated 130 | if len(curr_outs_sents) > 1: 131 | # remove first sentence from prefix and append generated sentence to prefix 132 | prefix_sents = nltk.sent_tokenize(beam["prefix_text"])[1:] 133 | prefix_sents.append(curr_outs_sents[0]) 134 | while len(tokenizer(' '.join(prefix_sents))['input_ids']) > 256: 135 | prefix_sents.pop(0) 136 | prefix_text = ' '.join(prefix_sents) 137 | prefix_tokens = tokenizer(prefix_text, truncation=True, padding="longest", return_tensors="pt")[ 138 | 'input_ids'] 139 | prefix_tokens = prefix_tokens.to(device) 140 | all_outs.append({ 141 | "prefix_text": prefix_text, 142 | "prefix_tokens": prefix_tokens, 143 | "text": beam["text"] + text, 144 | "eos": eos 145 | }) 146 | else: 147 | all_outs.append({ 148 | "prefix_text": beam['prefix_text'], 149 | "prefix_tokens": beam['prefix_tokens'], 150 | "text": beam["text"] + text, 151 | "eos": eos 152 | }) 153 | 154 | # Each beam has total_generated_tokens length 155 | total_generated_tokens += max_new_tokens 156 | if len(all_outs) > 1: 157 | # skip beam scoring if only one output to choose from 158 | scores, _, _ = scorer(prefixes=[x["prefix_text"] for x in all_outs], suffixes=[x["text"] for x in all_outs], prefix_vectors=None) 159 | top_scores, top_indices = torch.topk(scores, k=beam_size) 160 | beams = [all_outs[x] for x in top_indices] # only track 
the top k beams
161 |             else:
162 |                 top_scores = torch.Tensor([1.0])
163 |                 top_scores = top_scores.cuda()  # move to GPU (torch.Tensor.cuda() returns a copy, so reassign)
164 |                 beams = all_outs
165 | 
166 |             for beam in beams:
167 |                 if len(tokenizer.tokenize(beam["text"])) >= max_length:
168 |                     beam["eos"] = True
169 | 
170 |             if all([x["eos"] for x in beams]):
171 |                 final_outputs.append([x["text"] for x in beams])
172 |                 final_scores.append(top_scores)
173 |                 break
174 |     return final_outputs, final_scores
175 | 
176 | 
177 | scorer_fn = partial(scorer_t5x, t5x_embedder=t5x_embedder)
178 | 
179 | outputs = []
180 | 
181 | target_seq_len = []
182 | gen_seq_len = []
183 | 
184 | logging.set_verbosity_error()
185 | 
186 | if os.path.exists(args.output_file):
187 |     with open(args.output_file, "r") as f:
188 |         outputs = f.read().strip().split("\n")
189 | 
190 | for kk, instance in tqdm.tqdm(enumerate(data), total=len(data)):
191 |     if kk < len(outputs):  # resume from a partially written output file
192 |         continue
193 |     token_beam_text, token_beam_scores = token_beam_search(contexts=[instance["prefix"]], scorer=scorer_fn,
194 |                                                            beam_size=args.beam_size,
195 |                                                            top_p=args.top_p, num_tokens=args.num_tokens,
196 |                                                            num_samples=args.num_samples)
197 | 
198 |     token_beam_text = token_beam_text[0]
199 |     token_beam_text = [truncate(" ".join(x.split())) for x in token_beam_text]
200 |     outputs.append(json.dumps({
201 |         "prefix": instance["prefix"],
202 |         "targets": instance["targets"][0:1] + token_beam_text,  # gold target first, then the beam outputs
203 |         "scores": instance["scores"][0:1] + token_beam_scores[0].cpu().tolist()
204 |     }))
205 |     target_seq_len.append(len(instance["targets"][0].split()))
206 |     gen_seq_len.append(len(token_beam_text[0].split()))
207 | 
208 |     if (kk + 1) % 100 == 0:
209 |         print(f"Avg lens ({kk + 1} instances) = {np.mean(gen_seq_len)} generation, {np.mean(target_seq_len)} target")
210 |         print("Saving file...")
211 |         with open(args.output_file, "w") as f:
212 |             f.write("\n".join(outputs) + "\n")
213 | 
214 | with open(args.output_file, "w") as f:
215 |     f.write("\n".join(outputs) + "\n")
216 | 
--------------------------------------------------------------------------------
/rankgen/score_multi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import random
4 | import numpy as np
5 | import mauve
6 | import pickle
7 | import glob
8 | import os
9 | import matplotlib.pyplot as plt
10 | from nltk import tokenize
11 | from nltk.corpus import stopwords
12 | from utils import f1_score, rep_statistic
13 | import spacy
14 | import tqdm
15 | 
16 | 
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--dataset', default="data/multi_outs/t5_xxl_wiki_t5_xl_gen_inbook_all.jsonl")
19 | parser.add_argument('--eval_type', default="max")
20 | parser.add_argument('--gram', default=1, type=int)
21 | parser.add_argument('--rep_window', default=20, type=int)
22 | parser.add_argument('--plot_divergence', action='store_true')
23 | parser.add_argument('--eval_mauve', action='store_true')
24 | parser.add_argument('--eval_pos_overlap', action='store_true')
25 | args = parser.parse_args()
26 | 
27 | files = glob.glob(args.dataset)  # only used to locate the base directory; a fixed list of result files is scored below
28 | 
29 | base_dir = os.path.dirname(files[0])
30 | assert all([os.path.dirname(x) == base_dir for x in files])
31 | files = ['pg19_t5_xxl_descartes.jsonl', 'wiki_t5_xxl_descartes.jsonl',
32 |          'pg19_gpt2_medium.jsonl', 'wiki_gpt2_medium.jsonl',
33 |          'pg19_gpt2_xl.jsonl', 'wiki_gpt2_xl.jsonl',
34 |          'pg19_t5_xxl.jsonl', 'wiki_t5_xxl.jsonl']
35 | files = [os.path.join(base_dir, f) for f in files]
36 | 
37 | if args.eval_pos_overlap:
38 |     nlp = spacy.load("en_core_web_sm")
39 | 
latex_token_overlap = [] 40 | latex_rep_score = [] 41 | random_latex_token_overlap = [] 42 | random_latex_rep_score = [] 43 | latex_gold_beats_gen = [] 44 | latex_mauve = [] 45 | 46 | random_latex_token_overlap_ents = [] 47 | latex_token_overlap_ents = [] 48 | 49 | for file in files: 50 | if not os.path.exists(file): 51 | continue 52 | with open(file, 'r') as f: 53 | data = [json.loads(x) for x in f.read().strip().split("\n")] 54 | 55 | data_dict = {x["prefix"]: x for x in data} 56 | if "wiki_" in file: 57 | with open("data/multi_outs/t5_xxl_descartes_wiki_ppl.jsonl", "r") as f: 58 | raw_inp_data = [json.loads(x) for x in f.read().strip().split("\n")] 59 | for rid in raw_inp_data: 60 | assert rid["prefix"] in data_dict 61 | assert rid["targets"][0] == data_dict[rid["prefix"]]["targets"][0] 62 | elif "pg19_" in file: 63 | with open("data_new/ppl/pg19_t5_xxl.jsonl", "r") as f: 64 | raw_inp_data = [json.loads(x) for x in f.read().strip().split("\n")] 65 | for rid in raw_inp_data: 66 | assert rid["prefix"] in data_dict 67 | assert rid["targets"][0] == data_dict[rid["prefix"]]["targets"][0] 68 | 69 | 70 | token_overlaps = { 71 | "human": [], 72 | "random": [], 73 | "best": [] 74 | } 75 | if os.path.exists(file + ".ent_overlap.pkl"): 76 | with open(file + ".ent_overlap.pkl", "rb") as f: 77 | token_overlaps_ents = pickle.load(f) 78 | else: 79 | token_overlaps_ents = { 80 | "human": [], 81 | "random": [], 82 | "best": [] 83 | } 84 | rep_scores = { 85 | "human": [], 86 | "random": [], 87 | "best": [] 88 | } 89 | gold_beats_gen = [] 90 | 91 | for i in range(1): 92 | all_human = [] 93 | all_gen = [] 94 | num_tokens_random = [] 95 | num_tokens = [] 96 | all_max_score = [] 97 | for dd in tqdm.tqdm(data): 98 | all_human.append(dd['prefix'] + ' ' + dd['targets'][0]) 99 | random_gen = random.choice(dd['targets'][1:]) 100 | best_gen_idx = np.argmax(dd['scores'][1:]) + 1 101 | best_gen = dd['targets'][best_gen_idx] 102 | all_gen.append(dd['prefix'] + ' ' + random_gen) 103 | all_max_score.append(dd['prefix'] + ' ' + best_gen) 104 | num_tokens.append(len(best_gen.split())) 105 | num_tokens_random.append(len(random_gen.split())) 106 | 107 | token_overlaps["human"].append( 108 | f1_score(dd['targets'][0], dd['prefix'], stopwords=stopwords.words('english'), gram=args.gram)[0] 109 | ) 110 | token_overlaps["random"].append( 111 | f1_score(random_gen, dd['prefix'], stopwords=stopwords.words('english'), gram=args.gram)[0] 112 | ) 113 | token_overlaps["best"].append( 114 | f1_score(best_gen, dd['prefix'], stopwords=stopwords.words('english'), gram=args.gram)[0] 115 | ) 116 | rep_scores["human"].append(rep_statistic(dd['prefix'], dd['targets'][0], window=args.rep_window)) 117 | rep_scores["random"].append(rep_statistic(dd['prefix'], random_gen, window=args.rep_window)) 118 | rep_scores["best"].append(rep_statistic(dd['prefix'], best_gen, window=args.rep_window)) 119 | 120 | gold_beats_gen.extend([ 121 | dd['scores'][0] > x for x in dd['scores'][1:] 122 | ]) 123 | 124 | if args.eval_pos_overlap and not os.path.exists(file + ".ent_overlap.pkl"): 125 | prefix_nlp = nlp(dd['prefix']) 126 | best_nlp = nlp(best_gen) 127 | prefix_ents = " ".join([x.lemma_.lower() for x in prefix_nlp if x.pos_ in ["PROPN", "NUM", "NOUN"]]) 128 | best_ents = " ".join([x.lemma_.lower() for x in best_nlp if x.pos_ in ["PROPN", "NUM", "NOUN"]]) 129 | 130 | token_overlaps_ents["best"].append( 131 | f1_score(best_ents, prefix_ents, stopwords=stopwords.words('english'), gram=args.gram)[0] 132 | ) 133 | 134 | print(f"Results for {file}...") 135 | 
print(f"Best gen num tokens = {np.mean(num_tokens)}") 136 | print(f"Random gen num tokens = {np.mean(num_tokens_random)}") 137 | print(f"Human token overlap = {np.mean(token_overlaps['human']):.3f}") 138 | print(f"Random token overlap = {np.mean(token_overlaps['random']):.3f}") 139 | print(f"Best gen token overlap = {np.mean(token_overlaps['best']):.3f}") 140 | print(f"Best gen token overlap entities = {np.mean(token_overlaps_ents['best']):.3f}") 141 | 142 | print(f"Human rep = {np.mean(rep_scores['human']):.3f}") 143 | print(f"Random rep = {np.mean(rep_scores['random']):.3f}") 144 | print(f"Best gen rep = {np.mean(rep_scores['best']):.3f}") 145 | 146 | print(f"Gold beats generation = {np.mean(gold_beats_gen)}") 147 | 148 | latex_token_overlap.append(np.mean(token_overlaps['best'])) 149 | latex_rep_score.append(np.mean(rep_scores['best'])) 150 | 151 | random_latex_token_overlap.append(np.mean(token_overlaps['random'])) 152 | random_latex_rep_score.append(np.mean(rep_scores['random'])) 153 | 154 | latex_gold_beats_gen.append(np.mean(gold_beats_gen)) 155 | 156 | if args.eval_pos_overlap: 157 | latex_token_overlap_ents.append(np.mean(token_overlaps_ents['best'])) 158 | with open(file + ".ent_overlap.pkl", "wb") as f: 159 | pickle.dump(token_overlaps_ents, f) 160 | 161 | 162 | if i == 0 and args.eval_mauve: 163 | if os.path.exists(file + ".mauve.pkl"): 164 | with open(file + ".mauve.pkl", "rb") as f: 165 | mauve_data = pickle.load(f) 166 | mauve2 = mauve_data["max_gen_mauve"] 167 | if "random_gen_mauve" in mauve_data: 168 | mauve1 = mauve_data["random_gen_mauve"] 169 | else: 170 | mauve1 = None 171 | else: 172 | mauve1 = None 173 | mauve2 = mauve.compute_mauve(p_text=all_max_score, q_text=all_human, device_id=0, max_text_length=768, verbose=False) 174 | print(f"Max score mauve = {mauve2.mauve}") 175 | latex_mauve.append(mauve2.mauve) 176 | 177 | if args.eval_mauve and args.eval_type == "both" and mauve1 is None: 178 | mauve1 = mauve.compute_mauve(p_text=all_gen, q_text=all_human, device_id=0, max_text_length=768, verbose=False) 179 | 180 | if args.eval_mauve and mauve1 is not None: 181 | print(f"Random gen mauve = {mauve1.mauve}") 182 | 183 | if i == 0 and args.plot_divergence: 184 | plt.rcParams.update({'font.size': 16}) 185 | plt.axis([0.0, 1.0, 0.0, 1.0]) 186 | plt.plot(mauve1.divergence_curve[:, 0], mauve1.divergence_curve[:, 1]) 187 | plt.plot(mauve2.divergence_curve[:, 0], mauve2.divergence_curve[:, 1]) 188 | plt.fill_between(mauve1.divergence_curve[:, 0], mauve1.divergence_curve[:, 1], hatch='o', label="Nucleus", facecolor='white', edgecolor=plt.rcParams['axes.prop_cycle'].by_key()['color'][0]) 189 | plt.fill(np.append(mauve1.divergence_curve[:, 0], mauve2.divergence_curve[:, 0][::-1]), 190 | np.append(mauve1.divergence_curve[:, 1], mauve2.divergence_curve[:, 1][::-1]), 191 | hatch='/', label="RankGen", facecolor='white', edgecolor=plt.rcParams['axes.prop_cycle'].by_key()['color'][1]) 192 | # plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), 193 | # fancybox=True, shadow=False, ncol=1) 194 | plt.legend(loc='upper right') 195 | plt.xlabel("similarity to Q") 196 | plt.ylabel("similarity to P") 197 | plt.savefig(f'{file}.plot.pdf', bbox_inches="tight") 198 | plt.clf() 199 | 200 | if args.eval_mauve: 201 | outputs = { 202 | "max_gen_mauve": mauve2, 203 | "random_gen_mauve": mauve1 204 | } 205 | with open(file + ".mauve.pkl", "wb") as f: 206 | pickle.dump(outputs, f) 207 | 208 | print("Random gen latex --- ") 209 | print(f"Latex token overlap = {' & '.join([f'{100 * x:.1f}' for x in 
random_latex_token_overlap])} & {100 * np.mean(random_latex_token_overlap):.1f}") 210 | # print(f"Latex token overlap entities = {' & '.join([f'{100 * x:.1f}' for x in random_latex_token_overlap_ents])} & {100 * np.mean(random_latex_token_overlap_ents):.1f}") 211 | print(f"Latex rep = {' & '.join([f'{100 * x:.1f}' for x in random_latex_rep_score])} & {100 * np.mean(random_latex_rep_score):.1f}") 212 | 213 | 214 | print("Best gen latex --- ") 215 | print(f"Latex token overlap = {' & '.join([f'{100 * x:.1f}' for x in latex_token_overlap])} & {100 * np.mean(latex_token_overlap):.1f}") 216 | print(f"Latex token overlap entities = {' & '.join([f'{100 * x:.1f}' for x in latex_token_overlap_ents])} & {100 * np.mean(latex_token_overlap_ents):.1f}") 217 | print(f"Latex rep = {' & '.join([f'{100 * x:.1f}' for x in latex_rep_score])} & {100 * np.mean(latex_rep_score):.1f}") 218 | 219 | print(f"Gold beats gen latex = {' & '.join([f'{100 * x:.1f}' for x in latex_gold_beats_gen])} & {100 * np.mean(latex_gold_beats_gen):.1f}") 220 | 221 | print(" ") 222 | print(f"Mauve latex = {' & '.join([f'{x:.3f}' for x in latex_mauve])} & {np.mean(latex_mauve):.3f}") 223 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/rankgen/convert_checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import subprocess
4 | import time
5 | import functools
6 | import sys
7 | import pdb
8 | from typing import Any, Dict, Iterable, MutableMapping, Mapping, Optional, Sequence, Tuple, List
9 | 
10 | import asyncio
11 | import pickle
12 | import transformers
13 | from absl import logging
14 | from flax import optim
15 | from flax import serialization
16 | from flax import traverse_util
17 | import jax
18 | import jax.numpy as jnp
19 | import numpy as np
20 | from t5x import checkpoint_importer
21 | from t5x import multihost_utils
22 | from t5x import state_utils
23 | from t5x import train_state as train_state_lib
24 | import tensorflow as tf
25 | from tensorflow.io import gfile
26 | import tensorstore as ts
27 | 
28 | import typing_extensions
29 | from tensorboard.backend.event_processing import directory_watcher
30 | from tensorboard.backend.event_processing import event_file_loader
31 | from tensorboard.backend.event_processing import io_wrapper
32 | 
33 | from t5x.checkpoints import _maybe_update_ts_from_file_to_gcs, _maybe_update_ts_from_gcs_to_file, \
34 |     RestoreStateTransformationFn, _update_ts_path_from_relative_to_absolute
35 | 
36 | PyTreeDef = type(jax.tree_structure(None))
37 | LazyArray = checkpoint_importer.LazyArray
38 | LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray
39 | LazyThreadPoolArray = checkpoint_importer.LazyThreadPoolArray
40 | 
41 | VERSION = 3
42 | _DESIRED_CHUNK_SIZE_BYTES = 64 * 1024 * 1024
43 | 
44 | 
45 | class _ParameterInfo:
46 |     name: str
47 |     ts_spec: Optional[ts.Spec]
48 | 
49 |     def __init__(self, name, ts_spec):
50 |         self.name = name
51 |         self.ts_spec = ts_spec
52 |         self.local_chunk_info = None  # no partitioned reads in this script; _read_ts checks this attribute
53 | 
54 | def _get_optimizer_state_dict(
55 |         ckpt_contents: PyTreeDef, optimizer_state: Mapping[str, Any],
56 |         state_transformation_fns: Sequence[RestoreStateTransformationFn]):
57 |     version = ckpt_contents.get('version', 0)
58 |     if version == 0:
59 |         ckpt_optimizer_state = ckpt_contents
60 |     else:
61 |         ckpt_optimizer_state = ckpt_contents['optimizer']
62 | 
63 |     if version >= 2:
64 |         for fn in state_transformation_fns:
65 |             ckpt_optimizer_state = fn(ckpt_optimizer_state, optimizer_state)
66 |         return ckpt_optimizer_state
67 |     else:
68 |         raise ValueError('Checkpoint versions earlier than 2 are not supported. '
69 |                          f'Got version: {version}')
70 | 
71 | 
72 | def _cast(target: PyTreeDef, dtype: jnp.dtype):
73 |     def maybe_cast(x):
74 |         if isinstance(x, (int, str)):
75 |             # Ignore common non-array types that shouldn't be cast.
76 | return x 77 | elif x.dtype == dtype: 78 | return x 79 | elif isinstance(x, jax.ShapeDtypeStruct): 80 | return jax.ShapeDtypeStruct(x.shape, dtype) 81 | else: 82 | return x.astype(dtype) 83 | 84 | return jax.tree_map(maybe_cast, target) 85 | 86 | 87 | def _get_state_dict_for_save(state_dict: Dict[str, Any], 88 | lazy_load: bool = True) -> Mapping[str, Any]: 89 | def _lazy_load_device_array(arr): 90 | if isinstance(arr, jax.xla.DeviceArray): 91 | return LazyThreadPoolArray(arr.shape, arr.dtype, lambda: np.array(arr)) 92 | return arr 93 | 94 | if lazy_load: 95 | state_dict = jax.tree_map(_lazy_load_device_array, state_dict) 96 | state_dict['target'] = _cast(state_dict['target'], np.float32) 97 | return state_dict 98 | 99 | 100 | def _get_parameter_infos(ckpt_state_dict): 101 | def _get_param_info(name: str, arr: Any): 102 | # print("PRINTING") 103 | # print(name) 104 | # print(arr) 105 | # print(type(arr)) 106 | # if type(arr) == 'numpy.ndarray': 107 | return _ParameterInfo(name=name, ts_spec=None) 108 | 109 | param_names = traverse_util.unflatten_dict({ 110 | k: '/'.join(k) for k in traverse_util.flatten_dict( 111 | ckpt_state_dict, keep_empty_nodes=True) 112 | }) 113 | 114 | # print(param_names) 115 | # print(_get_state_dict_for_save(ckpt_state_dict)) 116 | 117 | return jax.tree_map( 118 | _get_param_info, param_names, 119 | _get_state_dict_for_save(ckpt_state_dict)) 120 | 121 | 122 | async def _read_ts(param_info: _ParameterInfo, maybe_tspec: Any, 123 | ckpt_path: str): 124 | # If saved as a numpy array, but a partitioned read is requested, return a 125 | # slice of the array for that host. Otherwise, return the whole thing. 126 | if isinstance(maybe_tspec, np.ndarray) and param_info: 127 | return maybe_tspec 128 | # If we have anything else that isn't a tensorstore spec just return it. 129 | elif not isinstance(maybe_tspec, ts.Spec): 130 | return maybe_tspec 131 | 132 | tmp_ts_spec_dict = maybe_tspec.to_json() 133 | # Remove non-required params so that we can open Tensorstore 134 | # that was created with a different set of params. 135 | del tmp_ts_spec_dict['metadata']['chunks'] 136 | del tmp_ts_spec_dict['metadata']['compressor'] 137 | 138 | # Convert the relative path in the spec to a path based on the checkpoint 139 | # location. Path and gcs bucket (if applicable) information is updated 140 | # in-place. 141 | _update_ts_path_from_relative_to_absolute( 142 | os.path.dirname(ckpt_path), tmp_ts_spec_dict) 143 | 144 | # if param_info.shape is not None: 145 | # ts_spec_arr_shape = tuple(tmp_ts_spec_dict['metadata']['shape']) 146 | # # Check that the shapes of the array on disk match the expected shape based 147 | # # on the optimizer that is being restored. 148 | # if ts_spec_arr_shape != param_info.shape: 149 | # raise ValueError(f'Shape of `{param_info.name}` in checkpoint ' 150 | # f'{ts_spec_arr_shape} does not match expected ' 151 | # f'{param_info.shape}.') 152 | # Read the array. 153 | t = await ts.open(tmp_ts_spec_dict, open=True) 154 | if param_info.local_chunk_info is not None: 155 | # Just read the subsection we care about. 156 | t = t[param_info.local_chunk_info.slice] 157 | arr = await t.read() 158 | # Assume we had to cast bfloat16 to uint16 to store with zarr. 159 | # TODO(ndl): remove this bitcast, as well as related bitcasts in PW code, 160 | # once we're ready to deprecate T5X checkpoints with "legacy" bfloat16 161 | # support. 
162 | if arr.dtype == np.uint16: 163 | arr = arr.view(jnp.bfloat16) 164 | return arr 165 | 166 | 167 | def _create_lazy_awaitable_array(param_info: _ParameterInfo, maybe_ts_spec: Any, 168 | ckpt_path: str) -> LazyAwaitableArray: 169 | get_fn = functools.partial( 170 | _read_ts, param_info, maybe_ts_spec, ckpt_path=ckpt_path) 171 | if isinstance(maybe_ts_spec, ts.Spec) or isinstance(maybe_ts_spec, np.ndarray): 172 | return LazyAwaitableArray.from_tensor_store_spec_or_array( 173 | maybe_ts_spec, get_fn) 174 | 175 | 176 | def _read_state_from_tensorstore( 177 | ckpt_path: str, 178 | parameter_infos: _ParameterInfo, 179 | written_state_dict: Mapping[str, Any], 180 | restore_parameter_infos: Optional[Mapping[str, Any]] = None, 181 | lazy_parameters: bool = False 182 | ) -> Mapping[str, Any]: 183 | if restore_parameter_infos is None: 184 | restore_parameter_infos = parameter_infos 185 | 186 | # Replace TensorStore Specs with the lazy array values. 187 | state_dict = jax.tree_multimap( 188 | functools.partial(_create_lazy_awaitable_array, ckpt_path=ckpt_path), 189 | restore_parameter_infos, written_state_dict) 190 | 191 | if not lazy_parameters: 192 | future_state_dict = jax.tree_map(lambda x: x.get_async(), state_dict) 193 | state_dict = _run_future_tree(future_state_dict) 194 | 195 | state_dict['target'] = _cast(state_dict['target'], np.float32) 196 | 197 | return state_dict 198 | 199 | 200 | def _run_future_tree(future_tree): 201 | """Block until all futures are resolved on this host.""" 202 | future_leaves, treedef = jax.tree_flatten(future_tree) 203 | 204 | # TODO(adarob): Use asyncio.run in py3.7+. 205 | loop = asyncio.get_event_loop() 206 | leaves = loop.run_until_complete(asyncio.gather(*future_leaves)) 207 | return jax.tree_unflatten(treedef, leaves) 208 | 209 | 210 | def restore( 211 | path: Optional[str] = None, 212 | fallback_state: Optional[Mapping[str, Any]] = None, 213 | lazy_parameters: bool = False) -> train_state_lib.TrainState: 214 | ckpt_path = path 215 | 216 | if gfile.isdir(ckpt_path): 217 | ckpt_dir = ckpt_path 218 | ckpt_path = os.path.join(ckpt_path, 'checkpoint') 219 | else: 220 | ckpt_dir = os.path.dirname(ckpt_path) 221 | 222 | if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path): 223 | raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}') 224 | 225 | logging.info('Restoring from checkpoint: %s', ckpt_path) 226 | 227 | with gfile.GFile(ckpt_path, 'rb') as fp: 228 | raw_contents = fp.read() 229 | if raw_contents.startswith(b'model_checkpoint_path'): 230 | raise ValueError( 231 | 'Attempting to restore a TensorFlow checkpoint as a native T5X ' 232 | 'checkpoint. Use `restore_from_tf_checkpoint` instead. 
Path: ' +
233 |             ckpt_path)
234 | 
235 |     ckpt_contents = serialization.msgpack_restore(raw_contents)
236 | 
237 |     if ckpt_dir.startswith('gs://'):
238 |         ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
239 |     else:
240 |         ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents)
241 | 
242 |     ckpt_state_dict = _get_optimizer_state_dict(ckpt_contents, [], [])
243 | 
244 |     # print(ckpt_state_dict)
245 | 
246 |     dummy_spec = ts.Spec({'driver': 'zarr', 'kvstore': {'driver': 'memory'}})
247 | 
248 |     parameter_infos = _get_parameter_infos(ckpt_state_dict)
249 |     # print(parameter_infos)
250 |     dummy_written_state_dict = jax.tree_map(
251 |         lambda x: x.ts_spec or dummy_spec,
252 |         parameter_infos,
253 |     )
254 | 
255 |     if fallback_state is None:
256 |         restore_parameter_infos = parameter_infos
257 |     else:
258 |         dummy_written_state_dict = state_utils.intersect_state(
259 |             dummy_written_state_dict, ckpt_state_dict)
260 |         restore_parameter_infos = state_utils.intersect_state(
261 |             parameter_infos, ckpt_state_dict)
262 | 
263 |     restore_parameter_infos_flat = state_utils.flatten_state_dict(
264 |         restore_parameter_infos)
265 |     for key in restore_parameter_infos_flat.keys():
266 |         logging.info('Restoring key from ckpt: %s', key)
267 | 
268 |     written_state_dict = serialization.from_state_dict(dummy_written_state_dict,
269 |                                                        ckpt_state_dict)
270 |     state_dict = _read_state_from_tensorstore(
271 |         ckpt_path,
272 |         parameter_infos,
273 |         written_state_dict,
274 |         restore_parameter_infos=restore_parameter_infos,
275 |         lazy_parameters=lazy_parameters)
276 | 
277 |     if fallback_state is not None:
278 |         state_dict = state_utils.merge_state(state_dict, fallback_state)
279 | 
280 |     for key in state_utils.flatten_state_dict(state_dict).keys():
281 |         if key not in restore_parameter_infos_flat:
282 |             logging.info('Not restoring key from ckpt: %s', key)
283 | 
284 |     return ckpt_state_dict  # NOTE: returns the raw state dict; large arrays remain tensorstore spec dicts, read on demand via read_array below
285 | 
286 | 
287 | def read_array(data):
288 |     path = data['kvstore']['path']
289 |     if not path.startswith('t5x/pre_suf_retriever/checkpoint_1100000/'):
290 |         data['kvstore']['path'] = 't5x/pre_suf_retriever/checkpoint_1100000/' + path
291 |     dataset = ts.open(data).result()
292 |     return np.array(dataset)
293 | 
294 | 
295 | ckpt = restore(path="t5x/pre_suf_retriever/checkpoint_1100000")
296 | state_dict = {}  # HuggingFace T5 encoder state_dict keys, filled from the T5X parameter tree
297 | 
298 | state_dict['encoder.final_layer_norm.weight'] = ckpt['target']['encoder']['encoder_norm']['scale']
299 | state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'] = ckpt['target']['encoder']['relpos_bias']['rel_embedding']
300 | state_dict['encoder.embed_tokens.weight'] = read_array(ckpt['target']['token_embedder']['embedding'])
301 | 
302 | for key, value in ckpt['target']['encoder'].items():
303 |     if key.startswith('layers_'):
304 |         n = key[7:]  # layer index, e.g. '5' from 'layers_5'
305 |         state_dict[f'encoder.block.{n}.layer.0.layer_norm.weight'] = value['pre_attention_layer_norm']['scale']
306 |         state_dict[f'encoder.block.{n}.layer.0.SelfAttention.k.weight'] = read_array(
307 |             value['attention']['key']['kernel'])
308 |         state_dict[f'encoder.block.{n}.layer.0.SelfAttention.q.weight'] = read_array(
309 |             value['attention']['query']['kernel'])
310 |         state_dict[f'encoder.block.{n}.layer.0.SelfAttention.v.weight'] = read_array(
311 |             value['attention']['value']['kernel'])
312 |         state_dict[f'encoder.block.{n}.layer.0.SelfAttention.o.weight'] = read_array(
313 |             value['attention']['out']['kernel'])
314 |         state_dict[f'encoder.block.{n}.layer.1.layer_norm.weight'] = value['pre_mlp_layer_norm']['scale']
315 |         
state_dict[f'encoder.block.{n}.layer.1.DenseReluDense.wi_0.weight'] = read_array( 316 | value['mlp']['wi_0']['kernel']) 317 | state_dict[f'encoder.block.{n}.layer.1.DenseReluDense.wi_1.weight'] = read_array( 318 | value['mlp']['wi_1']['kernel']) 319 | state_dict[f'encoder.block.{n}.layer.1.DenseReluDense.wo.weight'] = read_array( 320 | value['mlp']['wo']['kernel']) 321 | 322 | with open('state_dict.pickle', 'wb') as handle: 323 | pickle.dump(state_dict, handle) 324 | 325 | projection = ckpt['target']['encoder']['encoder_projection_layer']['kernel'] 326 | with open('projection.pickle', 'wb') as handle: 327 | pickle.dump(projection, handle) --------------------------------------------------------------------------------
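As a quick sanity check of the exported files, the sketch below reloads the two pickles written by `convert_checkpoint.py` above and prints each parameter name with its array shape. It is a minimal sketch: it assumes only the `state_dict.pickle` and `projection.pickle` files produced by this script, and that the projection kernel was stored inline as an array (as the script expects).

```
import pickle

import numpy as np

# Reload the exported encoder weights and projection kernel.
with open('state_dict.pickle', 'rb') as handle:
    state_dict = pickle.load(handle)
with open('projection.pickle', 'rb') as handle:
    projection = pickle.load(handle)

# Print each HuggingFace-style parameter name with its shape and dtype.
for name, value in sorted(state_dict.items()):
    arr = np.asarray(value)
    print(f"{name}: {arr.shape} {arr.dtype}")
print(f"encoder projection kernel: {np.asarray(projection).shape}")
```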