├── src
│   ├── eval_metrics
│   │   ├── __init__.py
│   │   ├── spice.py
│   │   ├── cider.py
│   │   └── bleu.py
│   ├── select_sentence.py
│   ├── eval.py
│   ├── gpt_sample_data.py
│   ├── hmm_lvd.py
│   ├── gpt_finetune.py
│   ├── hmm_train.jl
│   ├── decode.py
│   └── hmm_model.py
├── .gitignore
├── scripts
│   ├── download_eval_dependencies.sh
│   ├── 7_evaluate.sh
│   ├── 6_select_sentence.sh
│   ├── 2_sample_training_data.sh
│   ├── 1_finetune_gpt.sh
│   ├── 4_train_hmm.sh
│   ├── 3_lvd_hmm.sh
│   └── 5_decode.sh
├── Project.toml
└── README.md
/src/eval_metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | data/*
3 | models/*
4 | output/*
5 | logs/*
6 | src/eval_metrics/lib/*
7 | src/eval_metrics/cache/*
8 | Manifest.toml
9 |
--------------------------------------------------------------------------------
/scripts/download_eval_dependencies.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | LIB=src/eval_metrics/lib
4 | mkdir -p $LIB
5 |
6 | echo "Downloading..."
7 |
8 | # spice
9 | SPICE=SPICE-1.0.zip
10 | wget https://panderson.me/images/$SPICE
11 | unzip $SPICE -d $LIB/
12 | rm -f $SPICE
13 |
14 | # download the Stanford CoreNLP models required by SPICE
15 | bash src/eval_metrics/lib/SPICE-1.0/get_stanford_models.sh
--------------------------------------------------------------------------------
/scripts/7_evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # for unsupervised setting
4 | python src/eval.py \
5 | --result_file output/common-gen_validation_unsupervised_selected.json \
6 | --target_file data/common-gen_validation.json
7 |
8 | # for supervised setting
9 | # python src/eval.py \
10 | # --result_file output/common-gen_validation_supervised_selected.json \
11 | # --target_file data/common-gen_validation.json
--------------------------------------------------------------------------------
/Project.toml:
--------------------------------------------------------------------------------
1 | [deps]
2 | ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
3 | CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
4 | CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5 | DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
6 | Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
7 | ProbabilisticCircuits = "2396afbe-23d7-11ea-1e05-f1aa98e17a44"
8 | PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
9 |
10 | [compat]
11 | ProbabilisticCircuits = "0.4.1"
12 |
--------------------------------------------------------------------------------
/scripts/6_select_sentence.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # for unsupervised setting
4 | python src/select_sentence.py --device cuda --cuda_core 0 \
5 | --rerank --gpt_model_path models/gpt2-large_unsupervised/checkpoint-1 \
6 | --input_file output/common-gen_validation_unsupervised_output.json \
7 | --output_file output/common-gen_validation_unsupervised_selected.json
8 |
9 |
10 | # for supervised setting
11 | # python src/select_sentence.py --device cuda --cuda_core 0 \
12 | # --rerank --gpt_model_path models/gpt2-large_unsupervised/checkpoint-1 \
13 | # --input_file output/common-gen_validation_supervised_output.json \
14 | # --output_file output/common-gen_validation_supervised_selected.json
--------------------------------------------------------------------------------
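For reference, `scripts/7_evaluate.sh` above compares a result file against a target file, and `scripts/6_select_sentence.sh` produces that result file. Below is a minimal sketch of the two JSON layouts, as far as they can be inferred from `src/eval.py` and `src/select_sentence.py` further down in this repository; the field names come from those files, while the concrete values are made up for illustration only.

```
# Illustrative only -- field layout inferred from src/eval.py and src/select_sentence.py.

# --target_file (e.g. data/common-gen_validation.json): one entry per reference
# sentence, so the same concept set may appear several times with different targets.
target_entry = {
    "concept_set_idx": 0,
    "concepts": ["dog", "frisbee", "catch", "throw"],
    "target": "A dog leaps to catch a thrown frisbee.",
}

# --result_file (e.g. output/common-gen_validation_unsupervised_selected.json):
# one entry per concept set, holding the single sentence kept after re-ranking.
result_entry = {
    "concept_set_idx": 0,
    "concepts": ["dog", "frisbee", "catch", "throw"],
    "sentence": "A dog leaps to catch a thrown frisbee.",
}
```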
/scripts/2_sample_training_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for unsupervised setting 4 | mkdir -p ./data/unsupervised/ 5 | for idx in {1..40} 6 | do 7 | python src/gpt_sample_data.py --device cuda --cuda_core 0 \ 8 | --model_file ./models/gpt2-large_unsupervised/checkpoint-1 \ 9 | --sample_num 200000 --max_sample_length 32 --batch_size 1024 \ 10 | --output_file ./data/unsupervised/common-gen.train.${idx} 11 | done 12 | 13 | # for supervised setting 14 | # mkdir -p ./data/supervised/ 15 | # for idx in {1..40} 16 | # do 17 | # python src/gpt_sample_data.py --device cuda --cuda_core $cuda_core \ 18 | # --model_file ./models/gpt2-large_supervised/checkpoint-3 \ 19 | # --sample_num 200000 --max_sample_length 32 --batch_size 1024 \ 20 | # --output_file ./data/supervised/common-gen.train.${idx} 21 | # done -------------------------------------------------------------------------------- /scripts/1_finetune_gpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p logs 4 | 5 | # for unsupervised setting 6 | 7 | python src/gpt_finetune.py --device cuda --cuda_core 0 \ 8 | --max_epoch 5 --batch_size 128 --lr 1e-6 \ 9 | --train_data_file ./data/common-gen_train.json \ 10 | --validation_data_file ./data/common-gen_validation.json \ 11 | --model_path ./models/gpt2-large_unsupervised/ \ 12 | --log_file ./logs/1_finetune_gpt_unsupervised.log 13 | 14 | 15 | # for supervised setting 16 | # python src/gpt_finetune.py --device cuda --cuda_core 0 \ 17 | # --seq2seq \ 18 | # --max_epoch 5 --batch_size 128 --lr 1e-6 \ 19 | # --train_data_file ./data/common-gen_train.json \ 20 | # --validation_data_file ./data/common-gen_validation.json \ 21 | # --model_path ./models/gpt2-large_supervised/ \ 22 | # --log_file ./logs/1_finetune_gpt_supervised.log -------------------------------------------------------------------------------- /scripts/4_train_hmm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | hidden_states=4096 4 | cuda_core=0 5 | 6 | # for unsupervised setting 7 | mkdir -p models/hmm_${hidden_states}_unsupervised 8 | 9 | julia --project src/hmm_train.jl --cuda_id $cuda_core \ 10 | --model_path models/hmm_${hidden_states}_unsupervised/ \ 11 | --checkpoint 0 --max_epochs 40 --sample_length 32 \ 12 | --hidden_states $hidden_states --vocab_size 50257 --batch_size 2048 \ 13 | --pseudocount 0.1 \ 14 | --log_file logs/3_train_hmm_unsupervised.log \ 15 | --train_data_file data/unsupervised/common-gen.train 16 | 17 | 18 | # for supervised setting 19 | # julia --project src/hmm_train.jl --cuda_id $cuda_core \ 20 | # --model_path models/hmm_${hidden_states}_supervised/ \ 21 | # --checkpoint 0 --max_epochs 40 --sample_length 32 \ 22 | # --hidden_states $hidden_states --vocab_size 50257 --batch_size 2048 \ 23 | # --pseudocount 0.1 \ 24 | # --log_file logs/3_train_hmm_supervised.log \ 25 | # --train_data_file data/supervised/common-gen.train -------------------------------------------------------------------------------- /scripts/3_lvd_hmm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | hidden_states=4096 4 | cuda_core=0 5 | 6 | # for unsupervised setting 7 | mkdir -p models/hmm_${hidden_states}_unsupervised 8 | 9 | python src/hmm_lvd.py --device cuda --cuda_core $cuda_core \ 10 | --teacher_model_checkpoint ./models/gpt2-large_unsupervised/checkpoint-1 \ 11 | 
--sample_num 500000 --max_sample_length 32 --batch_size 512 \
12 | --hidden_states $hidden_states --vocab_size 50257 --kmeans_iterations 200 --pseudocount 0.001 \
13 | --output_file models/hmm_${hidden_states}_unsupervised/checkpoint-0.weight
14 |
15 | # for supervised setting
16 | # mkdir -p models/hmm_${hidden_states}_supervised
17 |
18 | # python src/hmm_lvd.py --device cuda --cuda_core $cuda_core \
19 | # --teacher_model_checkpoint ./models/gpt2-large_supervised/checkpoint-3 \
20 | # --sample_num 500000 --max_sample_length 32 --batch_size 512 \
21 | # --hidden_states $hidden_states --vocab_size 50257 --kmeans_iterations 200 --pseudocount 0.001 \
22 | # --output_file models/hmm_${hidden_states}_supervised/checkpoint-0.weight
--------------------------------------------------------------------------------
/scripts/5_decode.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p output
4 |
5 | # for unsupervised setting
6 | python src/decode.py --device cuda --cuda_core 0 \
7 | --hmm_batch_size 128 --seq2seq 0 \
8 | --min_sample_length 5 --max_sample_length 32 \
9 | --num_beams 128 --length_penalty 0.2 \
10 | --hmm_model_path models/hmm_4096_unsupervised/checkpoint-40.weight.th \
11 | --gpt_model_path models/gpt2-large_unsupervised/checkpoint-1 \
12 | --dataset_file data/common-gen_validation.json \
13 | --output_file output/common-gen_validation_unsupervised_output.json
14 |
15 |
16 | # for supervised setting
17 | # python src/decode.py --device cuda --cuda_core 0 \
18 | # --hmm_batch_size 128 --seq2seq 2 --w 0.3 \
19 | # --min_sample_length 5 --max_sample_length 32 \
20 | # --num_beams 128 --length_penalty 0.2 \
21 | # --hmm_model_path models/hmm_4096_supervised/checkpoint-40.weight.th \
22 | # --gpt_model_path models/gpt2-large_supervised/checkpoint-3 \
23 | # --dataset_file data/common-gen_validation.json \
24 | # --output_file output/common-gen_validation_supervised_output.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GeLaTo
2 |
3 | This is the source code for the paper ["Tractable Control for Autoregressive Language Generation"](https://arxiv.org/abs/2304.07438) (ICML 2023).
4 |
5 |
6 | ## Requirements
7 | We suggest using conda to set up the environment.
8 |
9 | ```
10 | conda create --name gelato python=3.8
11 | conda activate gelato
12 | ```
13 |
14 | For PyTorch & Transformers:
15 | ```
16 | pip3 install torch torchvision torchaudio transformers==4.21.3 datasets lemminflect
17 | conda install -c pytorch faiss-gpu
18 | ```
19 |
20 | To train HMMs with Juice.jl, you need to download Julia:
21 | ```
22 | https://julialang.org/downloads/
23 | ```
24 |
25 | For evaluation:
26 | ```
27 | pip3 install evaluate rouge_score
28 | pip3 install -U spacy
29 | python -m spacy download en_core_web_sm
30 | ```
31 |
32 | ## Models & Outputs
33 | We release checkpoints for the base models (GPT2-large finetuned on CommonGen) and the distilled HMMs for reproducibility. We also release the generated examples.
34 |
35 | ```
36 | https://drive.google.com/drive/folders/1cagRWGrGQ6HNes0z7Li2dHo2PfcuuZEl?usp=sharing
37 | ```
38 |
39 | ## Running the GeLaTo Pipeline
40 |
41 | We use CommonGen (unsupervised setting) as an example to illustrate how to run the GeLaTo pipeline. See the contents of the scripts for the full command lines.
42 |
43 | ### 1.
finetuning the base model 44 | ``` 45 | bash scripts/1_finetune_gpt.sh 46 | ``` 47 | 48 | 49 | ### 2. training the HMMs 50 | To train an HMM that approximates the base model, there are three steps: 51 | 52 | * sampling training data from the base model 53 | ``` 54 | bash scripts/2_sample_training_data.sh 55 | ``` 56 | 57 | * using latent variable distillation (LVD) to initialize HMM parameters 58 | ``` 59 | bash scripts/3_lvd_hmm.sh 60 | ``` 61 | 62 | * train HMM with EM (need Julia installation) 63 | ``` 64 | bash scripts/4_train_hmm.sh 65 | ``` 66 | 67 | ### 3. generation 68 | ``` 69 | bash scripts/5_decode.sh 70 | ``` 71 | 72 | ### 4. re-ranking the generated sentences 73 | ``` 74 | bash scripts/6_select_sentence.sh 75 | ``` 76 | 77 | ### 5. evaluation 78 | ``` 79 | bash scripts/download_eval_dependencies.sh 80 | bash scripts/7_evaluate.sh 81 | ``` 82 | 83 | 84 | -------------------------------------------------------------------------------- /src/select_sentence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | import torch 8 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 9 | 10 | device = 'cuda' 11 | 12 | def init(): 13 | global device 14 | global CUDA_CORE 15 | 16 | arg_parser = argparse.ArgumentParser() 17 | arg_parser.add_argument('--device', default='cuda', type=str) 18 | arg_parser.add_argument('--cuda_core', default='1', type=str) 19 | 20 | arg_parser.add_argument('--rerank', action='store_true') 21 | arg_parser.add_argument('--gpt_model_path', default='gpt2', type=str) 22 | arg_parser.add_argument('--input_file', default='', type=str) 23 | arg_parser.add_argument('--output_file', default='', type=str) 24 | 25 | args = arg_parser.parse_args() 26 | 27 | device = args.device 28 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_core 29 | 30 | return args 31 | 32 | 33 | def loglikelihood(model, tokenizer, texts): 34 | inputs = tokenizer(texts, padding=True, return_tensors='pt') 35 | input_ids = inputs['input_ids'].to(device) 36 | attention_mask = inputs['attention_mask'].to(device) 37 | 38 | n, d = input_ids.shape 39 | with torch.no_grad(): 40 | logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:,:-1,:] 41 | log_probs = torch.log_softmax(logits, dim=-1) 42 | log_probs = log_probs[torch.arange(0, n).unsqueeze(-1), 43 | torch.arange(0, d-1).unsqueeze(0), input_ids[:,1:]] 44 | log_probs *= attention_mask[:,1:] 45 | 46 | lls = log_probs.sum(dim=-1) 47 | 48 | return lls.tolist() 49 | 50 | 51 | def main(): 52 | args = init() 53 | 54 | if args.rerank: 55 | print(f'loading gpt2 from {args.gpt_model_path} ...') 56 | gpt_model = GPT2LMHeadModel.from_pretrained(args.gpt_model_path) 57 | gpt_model.config.pad_token_id = gpt_model.config.eos_token_id 58 | gpt_model.eval() 59 | gpt_model.to(device) 60 | 61 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large') 62 | tokenizer.pad_token = tokenizer.eos_token 63 | 64 | examples = [] 65 | processed_examples = [] 66 | with open(args.input_file, 'r') as fin: 67 | examples = json.load(fin) 68 | 69 | for example in tqdm(examples): 70 | if example['sentences'] != []: 71 | if args.rerank: 72 | sentences = ['<|endoftext|>' + x for x in example['sentences']] 73 | lls = loglikelihood(gpt_model, tokenizer, sentences) 74 | selected = sorted([(a, b) for a, b in zip(example['sentences'], lls)], 75 | key=lambda x: x[1], reverse=True)[0][0] 76 | else: 77 | selected = example['sentences'][0] 78 | 
else: 79 | selected = '' 80 | continue 81 | 82 | processed_examples.append({ 83 | 'concept_set_idx': example['concept_set_idx'], 84 | 'concepts': example['concepts'], 85 | 'sentence': selected, 86 | }) 87 | 88 | with open(args.output_file, 'w') as fout: 89 | json.dump(processed_examples, fout, indent=2) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() -------------------------------------------------------------------------------- /src/eval_metrics/spice.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import threading 5 | import json 6 | import numpy as np 7 | import ast 8 | import tempfile 9 | 10 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 11 | SPICE_JAR = os.getcwd() + '/src/eval_metrics/lib/SPICE-1.0/spice-1.0.jar' 12 | TEMP_DIR = 'tmp' 13 | CACHE_DIR = 'cache' 14 | 15 | class Spice: 16 | """ 17 | Main Class to compute the SPICE metric 18 | """ 19 | 20 | def float_convert(self, obj): 21 | try: 22 | return float(obj) 23 | except: 24 | return np.nan 25 | 26 | def compute_score(self, gts, res): 27 | assert(sorted(gts.keys()) == sorted(res.keys())) 28 | imgIds = sorted(gts.keys()) 29 | 30 | # Prepare temp input file for the SPICE scorer 31 | input_data = [] 32 | for id in imgIds: 33 | hypo = res[id] 34 | ref = gts[id] 35 | 36 | # Sanity check. 37 | assert(type(hypo) is list) 38 | assert(len(hypo) == 1) 39 | assert(type(ref) is list) 40 | assert(len(ref) >= 1) 41 | 42 | input_data.append({ 43 | "image_id" : id, 44 | "test" : hypo[0], 45 | "refs" : ref 46 | }) 47 | 48 | cwd = os.path.dirname(os.path.abspath(__file__)) 49 | temp_dir=os.path.join(cwd, TEMP_DIR) 50 | if not os.path.exists(temp_dir): 51 | os.makedirs(temp_dir) 52 | in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir, mode='w') 53 | json.dump(input_data, in_file, indent=2) 54 | in_file.close() 55 | 56 | # Start job 57 | out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) 58 | out_file.close() 59 | cache_dir=os.path.join(cwd, CACHE_DIR) 60 | if not os.path.exists(cache_dir): 61 | os.makedirs(cache_dir) 62 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, 63 | '-cache', cache_dir, 64 | '-out', out_file.name, 65 | '-subset', 66 | '-silent' 67 | ] 68 | print(" ".join(spice_cmd)) 69 | subprocess.check_call(spice_cmd, 70 | cwd=os.path.dirname(os.path.abspath(__file__))) 71 | 72 | # Read and process results 73 | with open(out_file.name) as data_file: 74 | results = json.load(data_file) 75 | os.remove(in_file.name) 76 | os.remove(out_file.name) 77 | 78 | imgId_to_scores = {} 79 | spice_scores = [] 80 | for item in results: 81 | imgId_to_scores[item['image_id']] = item['scores'] 82 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 83 | average_score = np.mean(np.array(spice_scores)) 84 | scores = [] 85 | for image_id in imgIds: 86 | # Convert none to NaN before saving scores over subcategories 87 | score_set = {} 88 | for category,score_tuple in imgId_to_scores[image_id].items(): 89 | score_set[category] = {k: self.float_convert(v) for k, v in list(score_tuple.items())} 90 | scores.append(score_set) 91 | return average_score, scores 92 | 93 | def method(self): 94 | return "SPICE" 95 | 96 | 97 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | from eval_metrics.bleu import Bleu 2 | from eval_metrics.cider import Cider 3 
| from eval_metrics.spice import Spice
4 |
5 | import spacy
6 | import json
7 | import codecs
8 | import argparse
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--target_file', default="", type=str)
12 | parser.add_argument('--result_file', default="", type=str)
13 | args = parser.parse_args()
14 |
15 | nlp = spacy.load("en_core_web_sm")
16 |
17 | def tokenize(sentence_dict):
18 |     for key in sentence_dict:
19 |         new_sentence_list = []
20 |         for sentence in sentence_dict[key]:
21 |             a = ''
22 |             for token in nlp(sentence):
23 |                 a += token.text
24 |                 a += ' '
25 |             new_sentence_list.append(a.rstrip())
26 |         sentence_dict[key] = new_sentence_list
27 |
28 |     return sentence_dict
29 |
30 |
31 | def evaluator(gts, res):
32 |     eval_results = {}
33 |     # =================================================
34 |     # Tokenization
35 |     # =================================================
36 |     print('tokenization...')
37 |     # tokenize with spaCy
38 |     gts = tokenize(gts)
39 |     res = tokenize(res)
40 |
41 |     # =================================================
42 |     # Set up scorers
43 |     # =================================================
44 |     print('setting up scorers...')
45 |     scorers = [
46 |         (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
47 |         # (Meteor(), "METEOR"),
48 |         # (Rouge(), "ROUGE_L"),
49 |         (Cider(), "CIDEr"),
50 |         (Spice(), "SPICE")
51 |     ]
52 |
53 |     # =================================================
54 |     # Compute scores
55 |     # =================================================
56 |     for scorer, method in scorers:
57 |         print("computing %s score..." % (scorer.method()))
58 |         score, scores = scorer.compute_score(gts, res)
59 |         if type(method) == list:
60 |             for sc, scs, m in zip(score, scores, method):
61 |                 eval_results[m] = sc
62 |                 print("%s: %0.3f" % (m, sc))
63 |         else:
64 |             eval_results[method] = score
65 |             print("%s: %0.3f" % (method, score))
66 |
67 |
68 | def load_targets(dataset_file):
69 |     with open(dataset_file, 'r') as fin:
70 |         examples = json.load(fin)
71 |
72 |     examples_ = {}
73 |     for example in examples:
74 |         idx = example['concept_set_idx']
75 |         if idx in examples_:
76 |             examples_[idx]['sentences'] = examples_[idx]['sentences'] + [example['target']]
77 |         else:
78 |             examples_[idx] = {
79 |                 'concept_set_idx': idx,
80 |                 'concepts': example['concepts'],
81 |                 'sentences': [example['target']],
82 |             }
83 |
84 |     examples = [v for _, v in examples_.items()]
85 |
86 |     return examples
87 |
88 | targets = load_targets(args.target_file)
89 |
90 | with open(args.result_file, 'r') as fin:
91 |     results = json.load(fin)
92 |
93 | targets = sorted(targets, key=lambda x:x['concept_set_idx'])
94 | results = sorted(results, key=lambda x:x['concept_set_idx'])
95 |
96 | results_idx_set = set([example['concept_set_idx'] for example in results])
97 | targets = [example for example in targets if example['concept_set_idx'] in results_idx_set]
98 |
99 | gts = {}
100 | res = {}
101 | for gts_line, res_line in zip(targets, results):
102 |     assert(gts_line['concepts'] == res_line['concepts'])
103 |     key = '#'.join(gts_line['concepts'])
104 |     gts[key] = [x.rstrip('\n') for x in gts_line['sentences']]
105 |
106 |     sentence = res_line['sentence']
107 |     sentence = sentence.replace('.', ' .')
108 |     sentence = sentence.replace(',', ' ,')
109 |     res[key] = [sentence.rstrip('\n')]
110 |
111 | evaluator(gts, res)
112 |
113 | from rouge_score import rouge_scorer
114 | predictions = [x['sentence'] for x in results]
115 | references = [x['sentences'] for x in targets]
116 | scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
117 |
118 | scores = []
119 | for pred, ref in
zip(predictions, references): 120 | rs = [scorer.score(pred, i)['rougeL'].fmeasure for i in ref] 121 | scores.append(sum(rs)/len(rs)) 122 | 123 | print('rougeL score') 124 | print(sum(scores) / len(scores)) 125 | 126 | 127 | -------------------------------------------------------------------------------- /src/gpt_sample_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | from tqdm import tqdm 6 | from transformers import GPT2LMHeadModel 7 | 8 | device = 'cuda' 9 | 10 | def init(): 11 | global device 12 | 13 | arg_parser = argparse.ArgumentParser() 14 | arg_parser.add_argument('--device', default='cuda', type=str) 15 | arg_parser.add_argument('--cuda_core', default='0', type=str) 16 | 17 | arg_parser.add_argument('--model_file', default='', type=str) 18 | 19 | arg_parser.add_argument('--sample_num', default=100, type=int) 20 | arg_parser.add_argument('--max_sample_length', default=20, type=int) 21 | arg_parser.add_argument('--evaluate_ll', action='store_true') 22 | arg_parser.add_argument('--batch_size', default=32, type=int) 23 | 24 | arg_parser.add_argument('--output_file', default='', type=str) 25 | 26 | args = arg_parser.parse_args() 27 | 28 | device = args.device 29 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_core 30 | 31 | return args 32 | 33 | 34 | def sample(model, sample_num, max_sample_length, batch_size, evaluate_ll=False): 35 | model.eval() 36 | 37 | examples = torch.LongTensor() 38 | ll_sum, token_num = 0.0, 0 39 | with torch.no_grad(): 40 | for i in tqdm(range(0, sample_num, batch_size)): 41 | num_return_seq = min(batch_size, sample_num - i) 42 | if evaluate_ll: 43 | res = model.generate(do_sample=True, max_length=max_sample_length+1, 44 | num_return_sequences=num_return_seq, top_k=50257, output_scores=True, 45 | return_dict_in_generate=True) 46 | else: 47 | res = model.generate(do_sample=True, max_length=max_sample_length+1, 48 | num_return_sequences=num_return_seq, top_k=50257, output_scores=False, 49 | return_dict_in_generate=True) 50 | 51 | examples_ = res.sequences.clone().to('cpu') 52 | examples_ = examples_[:, 1:] 53 | 54 | n, d = examples_.shape 55 | if d < max_sample_length: 56 | examples_ = torch.cat((examples_, 57 | torch.LongTensor([[model.config.eos_token_id] * (max_sample_length - d)] * n)), dim=1) 58 | 59 | examples = torch.cat((examples, examples_), dim=0) 60 | 61 | # evaluating avg log likelihood: 62 | if evaluate_ll: 63 | examples_ = res.sequences.clone().to(device) 64 | examples_ = examples_[:, 1:] 65 | scores = torch.stack(res.scores, 1).to(device) 66 | d = examples_.shape[1] 67 | 68 | mask = torch.ones(num_return_seq, d).type(torch.LongTensor).to(device) 69 | for j in range(0, num_return_seq): 70 | for k in range(0, d): 71 | if examples_[j, k] == model.config.eos_token_id: 72 | mask[j, k+1:] = 0 73 | break 74 | 75 | log_probs = torch.log(torch.softmax(scores, -1)) 76 | log_probs = log_probs[ 77 | torch.arange(examples_.shape[0]).unsqueeze(-1), 78 | torch.arange(examples_.shape[1]).unsqueeze(0), 79 | examples_[:,:]] 80 | log_probs[mask[:,:] == 0] = 0.0 81 | 82 | token_num += torch.sum(torch.sum(mask, dim=-1), dim=-1).item() 83 | ll_sum += torch.sum(torch.sum(log_probs, -1), -1).item() 84 | 85 | if evaluate_ll: 86 | ll_per_sample = ll_sum / sample_num 87 | ll_per_token = ll_sum / token_num 88 | print(f'll_per_sample: {ll_per_sample}') 89 | print(f'll_per_token: {ll_per_token}') 90 | 91 | return examples 92 | 93 | 94 | def write(examples, output_file): 95 | 
examples = examples.tolist() 96 | with open(output_file, 'w') as fout: 97 | for example in examples: 98 | fout.write(','.join([str(x) for x in example]) + '\n') 99 | 100 | 101 | def main(): 102 | args = init() 103 | 104 | print(f'loading {args.model_file} ...') 105 | model = GPT2LMHeadModel.from_pretrained(args.model_file) 106 | model.config.pad_token_id = model.config.eos_token_id 107 | model.to(device) 108 | 109 | examples = sample(model, args.sample_num, args.max_sample_length, 110 | args.batch_size, evaluate_ll=args.evaluate_ll) 111 | 112 | write(examples, args.output_file) 113 | 114 | 115 | if __name__ == '__main__': 116 | main() -------------------------------------------------------------------------------- /src/hmm_lvd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import numpy 6 | import faiss 7 | 8 | from tqdm import tqdm 9 | from transformers import GPT2LMHeadModel 10 | 11 | device = 'cuda' 12 | 13 | def init(): 14 | global device 15 | 16 | arg_parser = argparse.ArgumentParser() 17 | arg_parser.add_argument('--device', default='cpu', type=str) 18 | arg_parser.add_argument('--cuda_core', default='0', type=str) 19 | 20 | arg_parser.add_argument('--teacher_model_checkpoint', default='', type=str) 21 | arg_parser.add_argument('--sample_num', default=500000, type=int) 22 | arg_parser.add_argument('--max_sample_length', default=20, type=int) 23 | arg_parser.add_argument('--batch_size', default=32, type=int) 24 | 25 | arg_parser.add_argument('--hidden_states', default=256, type=int) 26 | arg_parser.add_argument('--vocab_size', default=50257, type=int) 27 | arg_parser.add_argument('--kmeans_iterations', default=1000, type=int) 28 | arg_parser.add_argument('--pseudocount', default=0.001, type=float) 29 | 30 | arg_parser.add_argument('--output_file', default='hmm.weight', type=str) 31 | 32 | args = arg_parser.parse_args() 33 | 34 | device = args.device 35 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_core 36 | 37 | return args 38 | 39 | 40 | def sample_examples(teacher_model_checkpoint, sample_num, max_sample_length, batch_size): 41 | teacher_model = GPT2LMHeadModel.from_pretrained(teacher_model_checkpoint).to(device) 42 | eos_token_id = teacher_model.config.eos_token_id 43 | teacher_model.config.pad_token_id = eos_token_id 44 | 45 | inf = 1e10 46 | 47 | suffixes = [] # sequence_offset, token_offset, token 48 | suffix_embeddings = [] 49 | 50 | for batch_idx in tqdm(range(0, sample_num, batch_size)): 51 | num_return_seq = min(batch_size, sample_num - batch_idx) 52 | with torch.no_grad(): 53 | outputs = teacher_model.generate(do_sample=True, min_length=3, max_length=max_sample_length+1, 54 | num_return_sequences=num_return_seq, top_k=50257, 55 | output_hidden_states=True, return_dict_in_generate=True) 56 | 57 | sequences = outputs.sequences[:, 1:].clone().to('cpu') # remove the first eos token 58 | 59 | _, d = sequences.shape 60 | mask = torch.ones(num_return_seq, d).type(torch.LongTensor) 61 | for i in range(1, d): 62 | mask[sequences[:, i] == eos_token_id, i] = 0 63 | 64 | token_hidden_states = torch.stack([x[12].clone().to('cpu') for x in outputs.hidden_states], dim=1).squeeze() 65 | suffix_hidden_states = token_hidden_states * mask.unsqueeze(-1) 66 | 67 | for i in range(0, d): 68 | suffixes.extend([((batch_idx+j, i), # suffix_offset 69 | sequences[j, i].item()) # token 70 | for j in range(0, num_return_seq) if mask[j, i] == 1]) # suffix_offset = (batch_idx+j, i) 71 | 
suffix_embeddings.append(suffix_hidden_states[mask[:, i] == 1, i, :]) 72 | 73 | suffix_embeddings = torch.cat(suffix_embeddings, dim=0).detach().cpu().numpy() 74 | 75 | return suffixes, suffix_embeddings 76 | 77 | 78 | def Kmeans_faiss(vecs, K, max_iterations=1000, nredo=1, verbose=True): 79 | vecs = numpy.unique(vecs, axis=0) # this line is slow 80 | kmeans = faiss.Kmeans(vecs.shape[1], K, 81 | niter=max_iterations, nredo=nredo, verbose=verbose, 82 | max_points_per_centroid=vecs.shape[0] // K, gpu=True) 83 | kmeans.train(vecs) 84 | 85 | return kmeans 86 | 87 | 88 | def update_flows(alpha, beta, gamma, suffixes, idx2cluster, 89 | hidden_states, vocab_size): 90 | 91 | eos_token_id = vocab_size - 1 92 | 93 | offset2index = {} 94 | for idx, suffix in enumerate(suffixes): 95 | offset2index[suffix[0]] = idx 96 | 97 | for idx in tqdm(range(0, len(suffixes))): 98 | suffix = suffixes[idx] 99 | suffix_offset, token = suffix 100 | suffix_offset_next = (suffix_offset[0], suffix_offset[1]+1) 101 | u = idx2cluster[idx] 102 | 103 | v = None 104 | if suffix_offset_next in offset2index: 105 | v = idx2cluster[offset2index[suffix_offset_next]] 106 | else: 107 | v = hidden_states - 1 # the reserved hidden state for token 108 | 109 | alpha[u, v] += 1.0 110 | beta[u, token] += 1.0 111 | if suffix_offset[1] == 0: 112 | gamma[u] += 1.0 113 | 114 | alpha[hidden_states-1, hidden_states-1] = 1.0 115 | beta[hidden_states-1, eos_token_id] = 1.0 116 | 117 | 118 | def write_params(alpha, beta, gamma, pseudocount, 119 | hidden_states, vocab_size, output_file): 120 | 121 | alpha += pseudocount 122 | beta += pseudocount 123 | gamma += pseudocount 124 | 125 | alpha = torch.log(alpha / torch.sum(alpha, dim=-1).unsqueeze(-1)) 126 | beta = torch.log(beta / torch.sum(beta, dim=-1).unsqueeze(-1)) 127 | gamma = torch.log(gamma / torch.sum(gamma, dim=-1)) 128 | 129 | torch.save({'hidden_states': hidden_states, 130 | 'vocab_size': vocab_size, 131 | 'alpha': alpha, 132 | 'beta': beta, 133 | 'gamma': gamma,}, 134 | f'{output_file}.th') 135 | 136 | 137 | def main(): 138 | args = init() 139 | 140 | print(f'sampling {args.sample_num} examples from {args.teacher_model_checkpoint} ...') 141 | suffixes, suffix_embeddings = sample_examples(args.teacher_model_checkpoint, 142 | args.sample_num, args.max_sample_length, args.batch_size) 143 | 144 | print(f'training K-means with {args.hidden_states-1} clusters and {len(suffixes)} suffix embeddings ...') 145 | kmeans = Kmeans_faiss(suffix_embeddings, args.hidden_states - 1, 146 | max_iterations=args.kmeans_iterations) 147 | 148 | print(f'clustering {len(suffixes)} suffix embeddings into {args.hidden_states-1} clusters ...') 149 | _, idx2cluster = kmeans.index.search(suffix_embeddings, 1) 150 | idx2cluster = numpy.squeeze(idx2cluster).tolist() 151 | 152 | alpha = torch.zeros(args.hidden_states, args.hidden_states) 153 | beta = torch.zeros(args.hidden_states, args.vocab_size) 154 | gamma = torch.zeros(args.hidden_states) 155 | 156 | print('computing flows ...') 157 | update_flows(alpha, beta, gamma, suffixes, idx2cluster, 158 | args.hidden_states, args.vocab_size) 159 | 160 | print(f'storing parameters to {args.output_file} ...') 161 | write_params(alpha, beta, gamma, args.pseudocount, 162 | args.hidden_states, args.vocab_size, args.output_file) 163 | 164 | 165 | if __name__ == '__main__': 166 | main() -------------------------------------------------------------------------------- /src/gpt_finetune.py: -------------------------------------------------------------------------------- 1 | import os 
2 | import json 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 10 | 11 | device = 'cuda' 12 | 13 | 14 | class DatasetFromFile(torch.utils.data.Dataset): 15 | def __init__(self, dataset_file, seq2seq): 16 | with open(dataset_file, 'r') as fin: 17 | data = json.load(fin) 18 | 19 | if seq2seq: 20 | texts = ['<|endoftext|>' + ' ' + ' '.join(e['concepts']) + ' = ' + e['target'] + '<|endoftext|>' for e in data] 21 | else: 22 | texts = ['<|endoftext|>' + ' ' + e['target'] + '<|endoftext|>' for e in data] 23 | 24 | self.texts = texts 25 | 26 | def __len__(self): 27 | return len(self.texts) 28 | 29 | def __getitem__(self, index): 30 | return self.texts[index] 31 | 32 | 33 | def init(): 34 | global device 35 | 36 | arg_parser = argparse.ArgumentParser() 37 | arg_parser.add_argument('--device', default='cuda', type=str) 38 | arg_parser.add_argument('--cuda_core', default='0', type=str) 39 | 40 | arg_parser.add_argument('--max_epoch', default=20, type=int) 41 | arg_parser.add_argument('--batch_size', default=8, type=int) 42 | arg_parser.add_argument('--lr', default=0.0001, type=float) 43 | arg_parser.add_argument('--grad_accum_iters', default=1, type=int) 44 | arg_parser.add_argument('--max_sequence_length', default=None, type=int) 45 | 46 | arg_parser.add_argument('--seq2seq', action='store_true') 47 | arg_parser.add_argument('--skip_eval', action='store_true') 48 | 49 | arg_parser.add_argument('--train_data_file', default='common_gen', type=str) 50 | arg_parser.add_argument('--validation_data_file', default='', type=str) 51 | arg_parser.add_argument('--model_path', default='', type=str) 52 | arg_parser.add_argument('--log_file', default='log.txt', type=str) 53 | 54 | args = arg_parser.parse_args() 55 | 56 | device = args.device 57 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_core 58 | 59 | return args 60 | 61 | 62 | def aggregate_loss(model, data_loader): 63 | model.eval() 64 | losses = [] 65 | with torch.no_grad(): 66 | for batch in tqdm(data_loader): 67 | inputs = { 68 | 'input_ids': batch['input_ids'].to(device), 69 | 'attention_mask': batch['attention_mask'].to(device), 70 | 'labels': batch['labels'].to(device) 71 | } 72 | loss = model(**inputs).loss 73 | losses.append(loss.item()) 74 | return torch.mean(torch.Tensor(losses)).item() 75 | 76 | 77 | def main(): 78 | args = init() 79 | 80 | train_data = DatasetFromFile(args.train_data_file, args.seq2seq) 81 | if args.validation_data_file != '': 82 | validation_data = DatasetFromFile(args.validation_data_file, args.seq2seq) 83 | else: 84 | validation_data = None 85 | 86 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large') 87 | tokenizer.pad_token = tokenizer.eos_token 88 | 89 | def collate(batch): 90 | batch_encoding = tokenizer([text for text in batch], padding=True) 91 | 92 | labels = [[(x if y == 1 else -100) for x, y in zip(e, mask)] 93 | for e, mask in zip(batch_encoding['input_ids'], batch_encoding['attention_mask'])] 94 | batch_encoding_tensor = { 95 | 'input_ids': torch.LongTensor(batch_encoding['input_ids']), 96 | 'attention_mask': torch.LongTensor(batch_encoding['attention_mask']), 97 | 'labels': torch.LongTensor(labels) 98 | } 99 | 100 | if (args.max_sequence_length is not None) and \ 101 | (args.max_sequence_length < batch_encoding_tensor['input_ids'].shape[1]): 102 | batch_encoding_tensor = { 103 | 'input_ids': batch_encoding_tensor['input_ids'][:, :args.max_sequence_length], 104 | 'attention_mask': 
batch_encoding_tensor['attention_mask'][:, :args.max_sequence_length], 105 | 'labels': batch_encoding_tensor['labels'][:, :args.max_sequence_length] 106 | } 107 | 108 | return batch_encoding_tensor 109 | 110 | train_loader = DataLoader(train_data, collate_fn=collate, batch_size=args.batch_size, shuffle=True) 111 | if validation_data is not None: 112 | validation_loader = DataLoader(validation_data, collate_fn=collate, batch_size=args.batch_size, shuffle=True) 113 | 114 | model = GPT2LMHeadModel.from_pretrained('gpt2-large', 115 | pad_token_id=tokenizer.eos_token_id) 116 | print('Saving checkpoint-0') 117 | model_save_path = os.path.join(args.model_path, 'checkpoint-0') 118 | model.save_pretrained(model_save_path) 119 | 120 | model.to(device) 121 | 122 | optim = torch.optim.AdamW(model.parameters(), lr=args.lr) 123 | 124 | for epoch in range(1, args.max_epoch+1): 125 | print(f'epoch {epoch}') 126 | 127 | model.train() 128 | optim.zero_grad() 129 | batch_idx = 0 130 | for batch in tqdm(train_loader): 131 | inputs = { 132 | 'input_ids': batch['input_ids'].to(device), 133 | 'attention_mask': batch['attention_mask'].to(device), 134 | 'labels': batch['labels'].to(device) 135 | } 136 | loss = model(**inputs).loss 137 | 138 | loss = loss / args.grad_accum_iters 139 | 140 | loss.backward() 141 | 142 | batch_idx += 1 143 | if (batch_idx % args.grad_accum_iters == 0) or (batch_idx == len(train_loader)): 144 | optim.step() 145 | optim.zero_grad() 146 | 147 | 148 | print(f'Saving checkpoint-{epoch}') 149 | if not os.path.exists(args.model_path): 150 | os.makedirs(args.model_path) 151 | model_save_path = os.path.join(args.model_path, f'checkpoint-{epoch}') 152 | model.save_pretrained(model_save_path) 153 | 154 | if not args.skip_eval: 155 | print(f'Evaluating checkpoint-{epoch}') 156 | train_loss = aggregate_loss(model, train_loader) 157 | if validation_data is not None: 158 | validation_loss = aggregate_loss(model, validation_loader) 159 | else: 160 | validation_loss = -1.0 161 | 162 | msg = f'epoch {epoch}, train_loss: {train_loss}, validation_loss: {validation_loss}' 163 | 164 | print(msg) 165 | with open(args.log_file, 'a+') as fout: 166 | fout.write(msg + '\n') 167 | else: 168 | print('Skipped evaluation') 169 | 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /src/hmm_train.jl: -------------------------------------------------------------------------------- 1 | using ArgParse 2 | using Pickle 3 | using CUDA 4 | using CSV 5 | using DataFrames 6 | using ProbabilisticCircuits 7 | using ProbabilisticCircuits: PlainInputNode, PlainSumNode, CuBitsProbCircuit, 8 | multiply, loglikelihood, full_batch_em, update_parameters 9 | 10 | 11 | function dataset_cpu(dataset_path, sample_length; padding=true) 12 | dataframe = CSV.read(dataset_path, DataFrame; 13 | header=false, types=Union{UInt32, Missing}, strict=true) 14 | if padding 15 | m = map(x -> x==50257 ? UInt32(50256) : x, Tables.matrix(dataframe)) 16 | else 17 | m = map(x -> x==50257 ? 
missing : x, Tables.matrix(dataframe)) 18 | end 19 | return m[:, 1:sample_length] 20 | end 21 | 22 | 23 | function init() 24 | s = ArgParseSettings() 25 | @add_arg_table s begin 26 | "--cuda_id" 27 | help = "CUDA ID" 28 | arg_type = Int64 29 | default = 1 30 | "--model_path" 31 | help = "path for saving/loading checkpoints" 32 | arg_type = String 33 | default = "" 34 | "--checkpoint" 35 | help = "start iterations" 36 | arg_type = Int64 37 | default = 0 38 | "--max_epochs" 39 | help = "max iterations" 40 | arg_type = Int64 41 | default = 1000 42 | "--train_data_file" 43 | help = "train file path" 44 | arg_type = String 45 | "--sample_length" 46 | help = "" 47 | arg_type = Int64 48 | "--hidden_states" 49 | help = "number of clusters used in warmup" 50 | arg_type = Int64 51 | default = 512 52 | "--vocab_size" 53 | help = "" 54 | arg_type = Int64 55 | default = 50257 56 | "--batch_size" 57 | help = "batch_size" 58 | arg_type = Int64 59 | default = 512 60 | "--pseudocount" 61 | help = "pseudocount" 62 | arg_type = Float64 63 | default = 1.0 64 | "--log_file" 65 | help = "log file" 66 | arg_type = String 67 | end 68 | args = parse_args(ARGS, s) 69 | 70 | args 71 | end 72 | 73 | 74 | function load_hmm(checkpoint_file, sample_length) 75 | 76 | x = Pickle.Torch.THload(checkpoint_file) 77 | hidden_states = x["hidden_states"] 78 | vocab_size = x["vocab_size"] 79 | alpha = x["alpha"] 80 | beta = x["beta"] 81 | gamma = x["gamma"] 82 | 83 | input2group = Dict() 84 | sum2group = Dict() 85 | 86 | layer = Any[] 87 | inputs = Any[] 88 | 89 | for suffix_len in 1:sample_length 90 | var = sample_length - suffix_len + 1 91 | 92 | # construct input nodes 93 | inputs = Any[] 94 | for u in 1:hidden_states 95 | weights = beta[u, :] 96 | input = PlainInputNode(var, Categorical(weights)) 97 | push!(inputs, input) 98 | end 99 | 100 | for u in 1:hidden_states 101 | input2group[inputs[u]] = u 102 | end 103 | 104 | # construct linear layer 105 | if suffix_len == 1 106 | layer = inputs 107 | else 108 | layer_new = Any[] 109 | for u in 1:hidden_states 110 | children = [layer[v] for v in 1:hidden_states] 111 | sum_node = PlainSumNode(children, alpha[u, :]) 112 | sum2group[sum_node] = u 113 | push!(layer_new, multiply(inputs[u], sum_node)) 114 | end 115 | layer = layer_new 116 | end 117 | end 118 | 119 | pc = PlainSumNode(layer, gamma) 120 | 121 | pc2hmm = Dict( 122 | "state2sum" => [x.inputs[2] for x in layer], 123 | "state2input" => inputs, 124 | ) 125 | 126 | pc, input2group, sum2group, hidden_states, vocab_size, pc2hmm 127 | end 128 | 129 | 130 | function save_hmm(pc, pc2hmm, hidden_states, vocab_size, 131 | checkpoint_file_path) 132 | 133 | state2sum, state2input = pc2hmm["state2sum"], pc2hmm["state2input"] 134 | 135 | # write alpha 136 | alpha = Array{Float32}(undef, hidden_states, hidden_states) 137 | beta = Array{Float32}(undef, hidden_states, vocab_size) 138 | for u in 1:hidden_states 139 | alpha[u, :] = state2sum[u].params 140 | end 141 | 142 | # write beta 143 | for u in 1:hidden_states 144 | beta[u, :] = state2input[u].dist.logps 145 | end 146 | 147 | # write gamma 148 | gamma = Float32.(pc.params) 149 | 150 | Pickle.Torch.THsave(checkpoint_file_path, Dict( 151 | "hidden_states" => hidden_states, 152 | "vocab_size" => vocab_size, 153 | "alpha" => alpha, 154 | "beta" => beta, 155 | "gamma" => gamma, 156 | )) 157 | end 158 | 159 | 160 | function train_hmm(bpc, node2group, edge2group, 161 | checkpoint, max_epochs, batch_size, pseudocount, 162 | model_path, pc, pc2hmm, hidden_states, vocab_size, 163 | 
train_data_file, sample_length, log_file) 164 | 165 | for epoch in checkpoint+1:max_epochs 166 | load_path = "$train_data_file.$epoch" 167 | println("loading train data $load_path ...") 168 | data_epoch = dataset_cpu(load_path, sample_length) 169 | data_size = size(data_epoch)[1] 170 | 171 | validation_epoch = cu(data_epoch[1:div(data_size, 10), :]) 172 | train_epoch = cu(data_epoch[div(data_size, 10)+1:data_size, :]) 173 | 174 | if epoch == checkpoint+1 175 | ll = loglikelihood(bpc, validation_epoch; batch_size) 176 | println("$(checkpoint)\t0.0\t$(ll)") 177 | open(log_file, "a+") do fout 178 | write(fout, "$(checkpoint)\t0.0\t$(ll)\n") 179 | end 180 | end 181 | 182 | println("Full batch epoch = ", epoch) 183 | @time train_ll = full_batch_em(bpc, train_epoch, 1; 184 | batch_size, pseudocount, node2group, edge2group) 185 | 186 | validation_ll = loglikelihood(bpc, validation_epoch; batch_size) 187 | println("$(epoch)\t$(train_ll[end])\t$(validation_ll)") 188 | open(log_file, "a+") do fout 189 | write(fout, "$(epoch)\t$(train_ll[end])\t$(validation_ll)\n") 190 | end 191 | 192 | println("Free memory") 193 | @time begin 194 | CUDA.unsafe_free!(train_epoch) 195 | CUDA.unsafe_free!(validation_epoch) 196 | end 197 | 198 | # save checkpoint every 5 epoch 199 | if !isnothing(model_path) && model_path!= "" && epoch % 5 == 0 200 | update_parameters(bpc) 201 | checkpoint_path = model_path * "/checkpoint-$(epoch).weight.th" 202 | save_hmm(pc, pc2hmm, hidden_states, vocab_size, 203 | checkpoint_path) 204 | end 205 | end 206 | end 207 | 208 | 209 | function main() 210 | args = init() 211 | 212 | println(args) 213 | device!(args["cuda_id"]) 214 | 215 | open(args["log_file"], "a+") do fout 216 | write(fout, join(ARGS, " ") * "\n") 217 | end 218 | 219 | # load checkpoint 220 | checkpoint_file = args["model_path"] * "/" * "checkpoint-$(args["checkpoint"]).weight.th" 221 | println("loading params from $(checkpoint_file) ...") 222 | @time pc, input2group, sum2group, hidden_states, vocab_size, pc2hmm = load_hmm( 223 | checkpoint_file, args["sample_length"]) 224 | 225 | println("gc ...") 226 | @time GC.gc() 227 | 228 | println("moving circuit to gpu ...") 229 | CUDA.@time bpc, node2group, edge2group = CuBitsProbCircuit(pc, input2group, sum2group) 230 | 231 | println("gc ...") 232 | @time GC.gc() 233 | 234 | @time println("training hmm with $(num_parameters(pc)) params and $(num_nodes(pc)) nodes ...") 235 | 236 | # free memory 237 | println("runing gc to free RAM ...") 238 | input2group, sum2group = nothing, nothing 239 | @time GC.gc() 240 | 241 | println("training hmm ...") 242 | train_hmm(bpc, node2group, edge2group, 243 | args["checkpoint"], args["max_epochs"], args["batch_size"], args["pseudocount"], 244 | args["model_path"], pc, pc2hmm, hidden_states, vocab_size, 245 | args["train_data_file"], args["sample_length"], args["log_file"]) 246 | 247 | println() 248 | end 249 | 250 | 251 | if abspath(PROGRAM_FILE) == @__FILE__ 252 | main() 253 | end -------------------------------------------------------------------------------- /src/eval_metrics/cider.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. 
This is optional: cook_refs and cook_test 15 | can take string arguments as well. 16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in range(1,n+1): 23 | for i in range(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 
112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.items(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].items(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) 193 | 194 | # Filename: cider.py 195 | # 196 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 197 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 198 | # 199 | # Creation Date: Sun Feb 8 14:16:54 2015 200 | # 201 | # Authors: 
Ramakrishna Vedantam and Tsung-Yi Lin 202 | 203 | import pdb 204 | 205 | class Cider: 206 | """ 207 | Main Class to compute the CIDEr metric 208 | 209 | """ 210 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 211 | # set cider to sum over 1 to 4-grams 212 | self._n = n 213 | # set the standard deviation parameter for gaussian penalty 214 | self._sigma = sigma 215 | 216 | def compute_score(self, gts, res): 217 | """ 218 | Main function to compute CIDEr score 219 | :param hypo_for_image (dict) : dictionary with key and value 220 | ref_for_image (dict) : dictionary with key and value 221 | :return: cider (float) : computed CIDEr score for the corpus 222 | """ 223 | 224 | assert(list(gts.keys()) == list(res.keys())) 225 | imgIds = list(gts.keys()) 226 | 227 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 228 | 229 | for id in imgIds: 230 | hypo = res[id] 231 | ref = gts[id] 232 | 233 | # Sanity check. 234 | assert(type(hypo) is list) 235 | assert(len(hypo) == 1) 236 | assert(type(ref) is list) 237 | assert(len(ref) > 0) 238 | 239 | cider_scorer += (hypo[0], ref) 240 | 241 | (score, scores) = cider_scorer.compute_score() 242 | 243 | return score, scores 244 | 245 | def method(self): 246 | return "CIDEr" -------------------------------------------------------------------------------- /src/decode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | from tqdm import tqdm 6 | import torch 7 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 8 | from lemminflect import getAllInflections 9 | 10 | from hmm_model import * 11 | 12 | device = 'cuda' 13 | 14 | class GPTConstraintModel(GPT2LMHeadModel): 15 | def forward(self, **kwargs): 16 | input_ids = kwargs['input_ids'] 17 | 18 | hmm_model = kwargs['hmm_model'] 19 | hmm_cnf = kwargs['hmm_cnf'] 20 | hmm_seq_len = kwargs['hmm_seq_len'] 21 | hmm_prompt_len = kwargs['hmm_prompt_len'] 22 | hmm_seq2seq = kwargs['hmm_seq2seq'] 23 | hmm_w = kwargs['hmm_w'] 24 | gpt_only = kwargs['gpt_only'] 25 | hmm_only = kwargs['hmm_only'] 26 | hmm_fix_order = kwargs['hmm_fix_order'] 27 | 28 | prefixes = [tuple(prefix) for prefix in input_ids[:,1:].tolist()] 29 | 30 | hmm_logits_alpha, hmm_logits = hmm_model.compute_logits( 31 | prefixes, hmm_cnf, hmm_seq_len, hmm_prompt_len, 32 | hmm_seq2seq, fix_order=hmm_fix_order) 33 | 34 | kwargs_ = kwargs.copy() 35 | del kwargs_['hmm_model'] 36 | del kwargs_['hmm_cnf'] 37 | del kwargs_['hmm_seq_len'] 38 | del kwargs_['hmm_prompt_len'] 39 | del kwargs_['hmm_seq2seq'] 40 | del kwargs_['hmm_w'] 41 | del kwargs_['hmm_only'] 42 | del kwargs_['gpt_only'] 43 | del kwargs_['hmm_fix_order'] 44 | 45 | outputs = super().forward(**kwargs_) 46 | 47 | hmm_logits_alpha = torch.log_softmax(hmm_logits_alpha, dim=-1) 48 | hmm_logits = torch.log_softmax(hmm_logits, dim=-1) 49 | gpt_logits = torch.log_softmax(outputs.logits[:,-1,:], dim=-1) 50 | 51 | if hmm_only: 52 | logits_new = hmm_logits_alpha 53 | elif gpt_only: 54 | logits_new = gpt_logits 55 | else: 56 | if hmm_seq2seq: 57 | logits_new = hmm_w * hmm_logits_alpha + (1.0 - hmm_w) * gpt_logits 58 | else: 59 | logits_new = hmm_logits_alpha + gpt_logits - hmm_logits 60 | 61 | logits_new = torch.log_softmax(logits_new, dim=-1) 62 | 63 | outputs.logits[:,-1,:] = logits_new 64 | 65 | return outputs 66 | 67 | 68 | def prepare_inputs_for_generation(self, input_ids, **model_kwargs): 69 | inputs = super().prepare_inputs_for_generation(input_ids, **model_kwargs) 70 | 71 | inputs['hmm_model'] = 
model_kwargs['hmm_model'] 72 | inputs['hmm_cnf'] = model_kwargs['hmm_cnf'] 73 | inputs['hmm_seq_len'] = model_kwargs['hmm_seq_len'] 74 | inputs['hmm_prompt_len'] = model_kwargs['hmm_prompt_len'] 75 | inputs['hmm_seq2seq'] = model_kwargs['hmm_seq2seq'] 76 | inputs['hmm_w'] = model_kwargs['hmm_w'] 77 | inputs['gpt_only'] = model_kwargs['gpt_only'] 78 | inputs['hmm_only'] = model_kwargs['hmm_only'] 79 | inputs['hmm_fix_order'] = model_kwargs['hmm_fix_order'] 80 | 81 | return inputs 82 | 83 | 84 | def init(): 85 | global device 86 | global CUDA_CORE 87 | 88 | arg_parser = argparse.ArgumentParser() 89 | arg_parser.add_argument('--device', default='cuda', type=str) 90 | arg_parser.add_argument('--cuda_core', default='1', type=str) 91 | arg_parser.add_argument('--hmm_batch_size', default=256, type=int) 92 | 93 | arg_parser.add_argument('--seq2seq', default=0, type=int) 94 | # --seq2seq 0: unsupervised setting, use the unsupervised base model together with the 95 | # HMM distilled from the unsupervised base model. 96 | # --seq2seq 1: supervised setting 1, use the supervised base model together with the HMM 97 | # model distilled from the unsuperivised base model 98 | # --seq2seq 2: supervised setting 2, use supervised base model and the HMM distilled from 99 | # the supervised base model 100 | arg_parser.add_argument('--w', default=0.2, type=float) 101 | # weight for geometric mean, only effective with --seq2seq non-zero 102 | arg_parser.add_argument('--hmm_only', action='store_true') 103 | arg_parser.add_argument('--gpt_only', action='store_true') 104 | 105 | arg_parser.add_argument('--min_sample_length', default=5, type=int) 106 | arg_parser.add_argument('--max_sample_length', default=32, type=int) 107 | arg_parser.add_argument('--num_beams', default=2, type=int) 108 | arg_parser.add_argument('--length_penalty', default=0.2, type=float) 109 | arg_parser.add_argument('--fix_order', action='store_true') 110 | arg_parser.add_argument('--no_inflection', action='store_true') 111 | 112 | arg_parser.add_argument('--hmm_model_path', default=None, type=str) 113 | arg_parser.add_argument('--gpt_model_path', default='gpt2', type=str) 114 | arg_parser.add_argument('--dataset_file', default='', type=str) 115 | arg_parser.add_argument('--dataset_start', default=0, type=int) 116 | arg_parser.add_argument('--dataset_end', default=-1, type=int) 117 | arg_parser.add_argument('--output_file', default='pred.json', type=str) 118 | 119 | args = arg_parser.parse_args() 120 | 121 | # device = f'cuda:{args.cuda_core}' # args.device 122 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_core 123 | torch.cuda.set_device(int(args.cuda_core)) 124 | 125 | return args 126 | 127 | 128 | def concepts2cnf(concepts, tokenizer, no_inflection=False): 129 | cnf = [] 130 | concept_set = set([tuple(tokenizer.encode(f' {x}')) for x in concepts]) 131 | 132 | for concept in concepts: 133 | s = tuple(tokenizer.encode(f' {concept}')) 134 | inflections = set([s]) 135 | if not no_inflection: 136 | for k, v in getAllInflections(concept).items(): 137 | for x in v: 138 | t = tuple(tokenizer.encode(f' {x}')) 139 | if len(s) <= len(t) and t[:len(s)] == s: 140 | continue 141 | # when both surf and surfer are required concepts 142 | # avoid the case that surfer is considered an inflection of surf 143 | if t in concept_set: 144 | continue 145 | inflections.add(t) 146 | 147 | clause = tuple(inflections) 148 | cnf.append(clause) 149 | 150 | cnf = tuple(cnf) 151 | 152 | return cnf 153 | 154 | 155 | def load_dataset(dataset_file, dataset_start=0, 
dataset_end=-1): 156 | with open(dataset_file, 'r') as fin: 157 | examples = json.load(fin) 158 | if dataset_end == -1: 159 | dataset_end = len(examples)-1 160 | 161 | examples_ = {} 162 | for example in examples: 163 | idx = example['concept_set_idx'] 164 | if dataset_start <= idx and idx <= dataset_end: 165 | examples_[idx] = { 166 | 'concept_set_idx': idx, 167 | 'concepts': example['concepts'], 168 | 'sentences': [], 169 | } 170 | 171 | examples = [v for _, v in examples_.items()] 172 | 173 | return examples 174 | 175 | 176 | def main(): 177 | args = init() 178 | 179 | print(f'loading gpt2 from {args.gpt_model_path} ...') 180 | gpt_model = GPTConstraintModel.from_pretrained(args.gpt_model_path) 181 | gpt_model.config.pad_token_id = gpt_model.config.eos_token_id 182 | gpt_model.config.use_cache = False 183 | gpt_model.to(device) 184 | 185 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large') 186 | 187 | # pre-define sep_tokens 188 | sep_tokens = [] 189 | for token in range(0, 50257): 190 | char = tokenizer.decode(token)[0] 191 | if char in [' ', '.', ',',]: 192 | sep_tokens.append(token) 193 | 194 | print(f'loading hmm from {args.hmm_model_path} ...') 195 | hmm_model = HMM(args.hmm_model_path, sep_tokens=sep_tokens) 196 | hmm_model.to(device) 197 | 198 | examples = load_dataset(args.dataset_file, args.dataset_start, args.dataset_end) 199 | 200 | print('generating sequences ...') 201 | counter = 0 202 | for example_idx in tqdm(range(0, len(examples))): 203 | example = examples[example_idx] 204 | concepts = example['concepts'] 205 | 206 | cnf = concepts2cnf(concepts, tokenizer, no_inflection=args.no_inflection) 207 | 208 | prompt = '<|endoftext|>' 209 | if args.seq2seq: 210 | prompt += ' ' + ' '.join(concepts) + ' =' 211 | prompt = tuple(tokenizer.encode(prompt)) 212 | 213 | if args.seq2seq == 0 or args.seq2seq == 1: 214 | hmm_seq_len = args.max_sample_length 215 | else: 216 | hmm_seq_len = len(prompt) - 1 + args.max_sample_length 217 | 218 | if args.seq2seq == 0: 219 | hmm_prompt_len = 0 220 | else: 221 | hmm_prompt_len = len(prompt) - 1 222 | 223 | model_kwargs = { 224 | 'hmm_model': hmm_model, 225 | 'hmm_cnf': cnf, 226 | 'hmm_seq_len': hmm_seq_len, 227 | 'hmm_prompt_len': hmm_prompt_len, 228 | 'hmm_seq2seq': args.seq2seq, 229 | 'hmm_w': args.w, 230 | 'gpt_only': args.gpt_only, 231 | 'hmm_only': args.hmm_only, 232 | 'hmm_fix_order': args.fix_order 233 | } 234 | 235 | input_ids = torch.tensor([prompt], device=device) 236 | with torch.no_grad(): 237 | hmm_model.initialize_cache(hmm_seq_len, cnf, 238 | prompt_tokens=prompt[1:], batch_size=args.hmm_batch_size, fix_order=args.fix_order) 239 | 240 | outputs = gpt_model.generate(input_ids=input_ids, do_sample=False, 241 | num_beams=args.num_beams, num_return_sequences=args.num_beams, 242 | min_length=len(prompt)+args.min_sample_length, max_length=len(prompt)+args.max_sample_length, 243 | top_k=50257, length_penalty=args.length_penalty, no_repeat_ngram_size=4, 244 | output_scores=False, return_dict_in_generate=True, **model_kwargs) 245 | 246 | output_ids = outputs.sequences.detach() 247 | 248 | sentences = [x.strip() for x in tokenizer.batch_decode( 249 | output_ids[:,len(prompt):], 250 | skip_special_tokens=True, clean_up_tokenization_spaces=False)] 251 | examples[example_idx]['sentences'] = sentences 252 | 253 | with open(args.output_file, 'w') as fout: 254 | json.dump(examples[:example_idx+1], fout, indent=2) 255 | 256 | 257 | if __name__ == '__main__': 258 | main() 259 | 
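
The heart of decode.py is the per-step score fusion in GPTConstraintModel.forward: with hmm_logits_alpha = log p_HMM(x_t | x_<t, constraint), hmm_logits = log p_HMM(x_t | x_<t) and gpt_logits = log p_GPT(x_t | x_<t), the unsupervised setting scores the next token by hmm_logits_alpha + gpt_logits - hmm_logits, which after renormalization is proportional to p_GPT(x_t | x_<t) * p_HMM(constraint | x_<=t); the supervised (seq2seq) settings instead take a geometric mean weighted by --w. The snippet below is a minimal, self-contained sketch of that fusion rule on dummy tensors; the function name fuse_step_logits and the toy batch/vocabulary sizes are illustrative and not part of the repository.

import torch

def fuse_step_logits(gpt_logits, hmm_logits_alpha, hmm_logits, seq2seq=0, w=0.2):
    """Combine per-step next-token scores, mirroring GPTConstraintModel.forward.

    gpt_logits are GPT-2 next-token scores; hmm_logits_alpha / hmm_logits are the
    HMM scores with / without the lexical constraint. All have shape (batch, vocab)
    and are renormalized with log_softmax, as in forward().
    """
    gpt_logits = torch.log_softmax(gpt_logits, dim=-1)
    hmm_logits_alpha = torch.log_softmax(hmm_logits_alpha, dim=-1)
    hmm_logits = torch.log_softmax(hmm_logits, dim=-1)

    if seq2seq:
        # supervised settings: weighted geometric mean of the two conditionals
        fused = w * hmm_logits_alpha + (1.0 - w) * gpt_logits
    else:
        # unsupervised setting: proportional to p_GPT(x_t | x_<t) * p_HMM(constraint | x_<=t),
        # obtained via the ratio p_HMM(x_t | x_<t, constraint) / p_HMM(x_t | x_<t)
        fused = hmm_logits_alpha + gpt_logits - hmm_logits

    return torch.log_softmax(fused, dim=-1)

if __name__ == '__main__':
    batch, vocab = 2, 50257
    g, ha, h = torch.randn(batch, vocab), torch.randn(batch, vocab), torch.randn(batch, vocab)
    print(fuse_step_logits(g, ha, h, seq2seq=0).shape)  # torch.Size([2, 50257])
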
-------------------------------------------------------------------------------- /src/eval_metrics/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in range(1,n+1): 30 | for i in range(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.items(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, xxx_todo_changeme, eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | (reflen, refmaxcounts) = xxx_todo_changeme 64 | testlen, counts = precook(test, n, True) 65 | 66 | result = {} 67 | 68 | # Calculate effective reference sentence length. 69 | 70 | if eff == "closest": 71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 72 | else: ## i.e., "average" or "shortest" or None 73 | result["reflen"] = reflen 74 | 75 | result["testlen"] = testlen 76 | 77 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 78 | 79 | result['correct'] = [0]*n 80 | for (ngram, count) in counts.items(): 81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 82 | 83 | return result 84 | 85 | class BleuScorer(object): 86 | """Bleu scorer. 87 | """ 88 | 89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 90 | # special_reflen is used in oracle (proportional effective ref len for a node). 
91 | 92 | def copy(self): 93 | ''' copy the refs.''' 94 | new = BleuScorer(n=self.n) 95 | new.ctest = copy.copy(self.ctest) 96 | new.crefs = copy.copy(self.crefs) 97 | new._score = None 98 | return new 99 | 100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 101 | ''' singular instance ''' 102 | 103 | self.n = n 104 | self.crefs = [] 105 | self.ctest = [] 106 | self.cook_append(test, refs) 107 | self.special_reflen = special_reflen 108 | 109 | def cook_append(self, test, refs): 110 | '''called by constructor and __iadd__ to avoid creating new instances.''' 111 | 112 | if refs is not None: 113 | self.crefs.append(cook_refs(refs)) 114 | if test is not None: 115 | cooked_test = cook_test(test, self.crefs[-1]) 116 | self.ctest.append(cooked_test) ## N.B.: -1 117 | else: 118 | self.ctest.append(None) # lens of crefs and ctest have to match 119 | 120 | self._score = None ## need to recompute 121 | 122 | def ratio(self, option=None): 123 | self.compute_score(option=option) 124 | return self._ratio 125 | 126 | def score_ratio(self, option=None): 127 | '''return (bleu, len_ratio) pair''' 128 | return (self.fscore(option=option), self.ratio(option=option)) 129 | 130 | def score_ratio_str(self, option=None): 131 | return "%.4f (%.2f)" % self.score_ratio(option) 132 | 133 | def reflen(self, option=None): 134 | self.compute_score(option=option) 135 | return self._reflen 136 | 137 | def testlen(self, option=None): 138 | self.compute_score(option=option) 139 | return self._testlen 140 | 141 | def retest(self, new_test): 142 | if type(new_test) is str: 143 | new_test = [new_test] 144 | assert len(new_test) == len(self.crefs), new_test 145 | self.ctest = [] 146 | for t, rs in zip(new_test, self.crefs): 147 | self.ctest.append(cook_test(t, rs)) 148 | self._score = None 149 | 150 | return self 151 | 152 | def rescore(self, new_test): 153 | ''' replace test(s) with new test(s), and returns the new score.''' 154 | 155 | return self.retest(new_test).compute_score() 156 | 157 | def size(self): 158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 159 | return len(self.crefs) 160 | 161 | def __iadd__(self, other): 162 | '''add an instance (e.g., from another sentence).''' 163 | 164 | if type(other) is tuple: 165 | ## avoid creating new BleuScorer instances 166 | self.cook_append(other[0], other[1]) 167 | else: 168 | assert self.compatible(other), "incompatible BLEUs." 
169 | self.ctest.extend(other.ctest) 170 | self.crefs.extend(other.crefs) 171 | self._score = None ## need to recompute 172 | 173 | return self 174 | 175 | def compatible(self, other): 176 | return isinstance(other, BleuScorer) and self.n == other.n 177 | 178 | def single_reflen(self, option="average"): 179 | return self._single_reflen(self.crefs[0][0], option) 180 | 181 | def _single_reflen(self, reflens, option=None, testlen=None): 182 | 183 | if option == "shortest": 184 | reflen = min(reflens) 185 | elif option == "average": 186 | reflen = float(sum(reflens))/len(reflens) 187 | elif option == "closest": 188 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 189 | else: 190 | assert False, "unsupported reflen option %s" % option 191 | 192 | return reflen 193 | 194 | def recompute_score(self, option=None, verbose=0): 195 | self._score = None 196 | return self.compute_score(option, verbose) 197 | 198 | def compute_score(self, option=None, verbose=0): 199 | n = self.n 200 | small = 1e-9 201 | tiny = 1e-15 ## so that if guess is 0 still return 0 202 | bleu_list = [[] for _ in range(n)] 203 | 204 | if self._score is not None: 205 | return self._score 206 | 207 | if option is None: 208 | option = "average" if len(self.crefs) == 1 else "closest" 209 | 210 | self._testlen = 0 211 | self._reflen = 0 212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 213 | 214 | # for each sentence 215 | for comps in self.ctest: 216 | testlen = comps['testlen'] 217 | self._testlen += testlen 218 | 219 | if self.special_reflen is None: ## need computation 220 | reflen = self._single_reflen(comps['reflen'], option, testlen) 221 | else: 222 | reflen = self.special_reflen 223 | 224 | self._reflen += reflen 225 | 226 | for key in ['guess','correct']: 227 | for k in range(n): 228 | totalcomps[key][k] += comps[key][k] 229 | 230 | # append per image bleu score 231 | bleu = 1. 232 | for k in range(n): 233 | bleu *= (float(comps['correct'][k]) + tiny) \ 234 | /(float(comps['guess'][k]) + small) 235 | bleu_list[k].append(bleu ** (1./(k+1))) 236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 237 | if ratio < 1: 238 | for k in range(n): 239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 240 | 241 | #if verbose > 1: 242 | #print comps, reflen 243 | 244 | totalcomps['reflen'] = self._reflen 245 | totalcomps['testlen'] = self._testlen 246 | 247 | bleus = [] 248 | bleu = 1. 249 | for k in range(n): 250 | bleu *= float(totalcomps['correct'][k] + tiny) \ 251 | / (totalcomps['guess'][k] + small) 252 | bleus.append(bleu ** (1./(k+1))) 253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 254 | if ratio < 1: 255 | for k in range(n): 256 | bleus[k] *= math.exp(1 - 1/ratio) 257 | 258 | #if verbose > 0: 259 | #print totalcomps 260 | #print "ratio:", ratio 261 | 262 | self._score = bleus 263 | return self._score, bleu_list 264 | 265 | #!/usr/bin/env python 266 | # 267 | # File Name : bleu.py 268 | # 269 | # Description : Wrapper for BLEU scorer. 
270 | # 271 | # Creation Date : 06-01-2015 272 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 273 | # Authors : Hao Fang and Tsung-Yi Lin 274 | 275 | class Bleu: 276 | def __init__(self, n=4): 277 | # default: compute BLEU score up to 4-grams 278 | self._n = n 279 | self._hypo_for_image = {} 280 | self.ref_for_image = {} 281 | 282 | def compute_score(self, gts, res): 283 | 284 | assert(gts.keys() == res.keys()) 285 | imgIds = gts.keys() 286 | 287 | bleu_scorer = BleuScorer(n=self._n) 288 | for id in imgIds: 289 | hypo = res[id] 290 | ref = gts[id] 291 | 292 | # Sanity check. 293 | assert(type(hypo) is list) 294 | assert(len(hypo) == 1) 295 | assert(type(ref) is list) 296 | assert(len(ref) >= 1) 297 | 298 | bleu_scorer += (hypo[0], ref) 299 | 300 | #score, scores = bleu_scorer.compute_score(option='shortest') 301 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 302 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 303 | 304 | # return (bleu, bleu_info) 305 | return score, scores 306 | 307 | def method(self): 308 | return "Bleu" 309 | -------------------------------------------------------------------------------- /src/hmm_model.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | # check whether x ends with y 7 | def end_with(x, y): 8 | if len(y) > len(x): 9 | return False 10 | return x[-len(y):] == y 11 | 12 | 13 | def remove_from_cnf(cnf, keyword, fix_order=False): 14 | if fix_order: 15 | if len(cnf) == 0: 16 | return cnf 17 | if keyword in cnf[0]: 18 | return cnf[1:] 19 | return cnf 20 | else: 21 | for i, clause in enumerate(cnf): 22 | if keyword in clause: 23 | return cnf[:i] + cnf[i+1:] 24 | return cnf 25 | 26 | 27 | def update_cnf(prefix, cnf, sep_tokens_set=None, fix_order=False): 28 | keywords = set([keyword for clause in cnf for keyword in clause]) 29 | for i in range(0, len(prefix)-1): 30 | for j in range(i+1, len(prefix)): 31 | if (prefix[i: j] in keywords) and (prefix[j] in sep_tokens_set): 32 | cnf = remove_from_cnf(cnf, prefix[i:j], fix_order=fix_order) 33 | break 34 | return cnf 35 | 36 | 37 | # checks whether the prefix ends with some partial keywords 38 | # and returns all possible next_seqs and next_cnfs 39 | def case_analysis(prefix, keywords, cnf, hmm_seq_len_left, fix_order=False): 40 | end_with_keyword = False 41 | end_with_partial_keyword = False 42 | 43 | next_seqs, next_cnfs = [], [] 44 | for keyword in keywords: 45 | for j in range(1, len(keyword)): 46 | # append next_seqs that finish partial keywords 47 | if end_with(prefix, keyword[:j]) and len(keyword[j:]) <= hmm_seq_len_left: 48 | next_seqs.append(keyword[j:]) 49 | next_cnfs.append(remove_from_cnf(cnf, keyword, fix_order=fix_order)) 50 | end_with_partial_keyword = True 51 | if end_with(prefix, keyword): 52 | end_with_keyword = True 53 | 54 | if end_with_keyword or (not end_with_partial_keyword): 55 | for i in range(0, len(prefix)): 56 | if prefix[i:] in keywords: 57 | cnf = remove_from_cnf(cnf, prefix[i:], fix_order=fix_order) 58 | # keywords_cnf = set([keyword for clause in cnf for keyword in clause]) 59 | for keyword in keywords: 60 | if len(keyword) <= hmm_seq_len_left: 61 | next_seqs.append(keyword) 62 | next_cnfs.append(remove_from_cnf(cnf, keyword, fix_order=fix_order)) 63 | 64 | return next_seqs, next_cnfs, end_with_keyword, end_with_partial_keyword 65 | 66 | 67 | class HMM(nn.Module): 68 | def __init__(self, weights_file, sep_tokens=[]): 69 | 
super().__init__() 70 | 71 | assert(weights_file[-2:] == 'th') 72 | 73 | d = torch.load(weights_file) 74 | alpha, beta, gamma = d['alpha'], d['beta'], d['gamma'] 75 | 76 | alpha = torch.log_softmax(alpha, dim=1) 77 | beta = torch.log_softmax(beta, dim=1) 78 | gamma = torch.log_softmax(gamma, dim=0) 79 | 80 | self.alpha = nn.Parameter(alpha, requires_grad=False) 81 | self.beta = nn.Parameter(beta, requires_grad=False) 82 | self.gamma = nn.Parameter(gamma, requires_grad=False) 83 | 84 | self.cache = {} 85 | 86 | self.cache['sep_tokens'] = set(sep_tokens) 87 | 88 | 89 | def forward(self, x): 90 | device = x.device 91 | alpha, beta, gamma = self.alpha, self.beta, self.gamma 92 | 93 | batch_size, seq_len = x.shape 94 | hidden_states, vocab_size = beta.shape 95 | 96 | y = torch.zeros(batch_size, hidden_states).to(device) 97 | for t in range(seq_len - 1, -1, -1): 98 | if t != seq_len - 1: 99 | y = torch.logsumexp(alpha.unsqueeze(0) + y.unsqueeze(1), dim=2) 100 | inputs = beta[torch.arange(hidden_states).unsqueeze(0).to(device), 101 | x[:, t].unsqueeze(-1)] # batch_size * hidden_states 102 | y = y + inputs 103 | y = torch.logsumexp(gamma.unsqueeze(0) + y, dim=1) 104 | 105 | return y 106 | 107 | 108 | def initialize_cache(self, hmm_seq_len, cnf0, 109 | prompt_tokens=(), batch_size=256, fix_order=False): 110 | 111 | self.cache = {'sep_tokens': self.cache['sep_tokens']} 112 | 113 | torch.cuda.empty_cache() 114 | 115 | device = self.alpha.device 116 | inf, eos_token_id = 1e10, 50256 117 | hidden_states, vocab_size = self.beta.shape 118 | alpha, beta, gamma = self.alpha, self.beta, self.gamma 119 | 120 | beta = beta.clone() 121 | beta[:, 796:797] = -inf 122 | 123 | # compute start_tokens # beginning-of-keyword tokens 124 | sep_tokens = self.cache['sep_tokens'] 125 | start_tokens = set([keyword[0] for clause in cnf0 for keyword in clause]) 126 | self.cache['start_tokens'] = start_tokens 127 | 128 | sep_non_start_tokens = list(sep_tokens.difference(start_tokens)) 129 | beta_sep_non_start_mars = torch.logsumexp(beta[:, sep_non_start_tokens], dim=1) # hidden_states 130 | 131 | non_sep_tokens = [token for token in range(0, vocab_size) if token not in sep_tokens] 132 | beta_non_sep_mars = torch.logsumexp(beta[:, non_sep_tokens], dim=1) # hidden_states 133 | 134 | non_start_tokens = [token for token in range(0, vocab_size) if token not in start_tokens] 135 | beta_non_start_mars = torch.logsumexp(beta[:, non_start_tokens], dim=1) # hidden_states 136 | 137 | # initialize cache A 138 | A_cache = {(): self.gamma.clone()} 139 | self.cache['A'] = A_cache 140 | for i in range(1, len(prompt_tokens)): 141 | self.update_A([prompt_tokens[:i]]) 142 | 143 | # initialize cache C 144 | C_cache = {} 145 | C = torch.eye(hidden_states, device=device) 146 | 147 | alpha_exp, beta_exp = torch.exp(alpha), torch.exp(beta) 148 | keywords = list(set([keyword for clause in cnf0 for keyword in clause])) 149 | max_keyword_len = max([len(keyword) for keyword in keywords]) 150 | 151 | C = C.unsqueeze(0) # 1 * hidden_states * hidden_states 152 | for suffix_len in range(1, max_keyword_len+1): 153 | input_probs = [beta_exp[:, x[-suffix_len]] if len(x) >= suffix_len 154 | else torch.ones(hidden_states, device=device) for x in keywords] 155 | input_probs = torch.stack(input_probs, dim=0) # len(keywords) * hidden_states 156 | C = input_probs.unsqueeze(-1) * torch.matmul(alpha_exp.unsqueeze(0), C) 157 | for i, keyword in enumerate(keywords): 158 | if len(keyword) >= suffix_len: 159 | C_cache[keyword[-suffix_len:]] = torch.log(C[i]) 160 | 161 | # 
initialize cache B and B_sep 162 | B_cache, B_sep_cache = {}, {} 163 | for subset_size in range(0, len(cnf0)+1): 164 | for subset in itertools.combinations(cnf0, subset_size): 165 | B_cache[tuple(subset)] = -inf * torch.ones(hmm_seq_len+1, hidden_states, device=device) 166 | B_sep_cache[tuple(subset)] = -inf * torch.ones(hmm_seq_len+1, hidden_states, device=device) 167 | B_cache[()][hmm_seq_len, :] = 0.0 168 | B_sep_cache[()][hmm_seq_len, :] = 0.0 169 | 170 | all_subsets = [subset for subset in B_cache] 171 | subset_batch_size = max(batch_size // len(keywords), 1) 172 | for t in range(hmm_seq_len-1, -1, -1): 173 | for subset_batch_idx in range(0, len(all_subsets), subset_batch_size): 174 | subset_batch_size_ = min(subset_batch_size, len(all_subsets) - subset_batch_idx) 175 | subset_batch = all_subsets[subset_batch_idx: subset_batch_idx + subset_batch_size_] 176 | 177 | # case 1: first token is sep and start 178 | C, B = [], [] 179 | for subset in subset_batch: 180 | C_subset, B_subset = [], [] 181 | for keyword in keywords: 182 | if t + len(keyword) <= hmm_seq_len: 183 | next_subset = remove_from_cnf(subset, keyword, fix_order=fix_order) 184 | C_subset.append(C_cache[keyword]) 185 | B_subset.append(B_sep_cache[next_subset][t+len(keyword), :]) 186 | else: # probability 0 187 | C_subset.append(torch.eye(hidden_states, device=device)) 188 | B_subset.append(-inf * torch.ones(hidden_states, device=device)) 189 | C_subset = torch.stack(C_subset, dim=0) # len(keywords) * hidden_states * hidden_states 190 | B_subset = torch.stack(B_subset, dim=0) # len(keywords) * hidden_states 191 | C.append(C_subset) 192 | B.append(B_subset) 193 | 194 | C = torch.stack(C, dim=0) # subset_num * len(keywords) * hidden_states * hidden_states 195 | B = torch.stack(B, dim=0) # subset_num * len(keywords) * hidden_states 196 | C += B.unsqueeze(2) 197 | CB = torch.logsumexp(C, dim=3) # subset_num * len(keywords) * hidden_states 198 | CB = torch.logsumexp(CB, dim=1) # subset_num * hidden_states 199 | 200 | B = torch.stack([B_cache[subset][t+1, :] for subset in subset_batch], dim=0) # subset_num * hidden_states 201 | B = torch.logsumexp(alpha.unsqueeze(0) + B.unsqueeze(1), dim=2) # subset_num * hidden_states 202 | 203 | B1 = beta_sep_non_start_mars.unsqueeze(0) + B # subset_num * hidden_states 204 | B2 = beta_non_sep_mars.unsqueeze(0) + B # subset_num * hidden_states 205 | B_sep = torch.logaddexp(CB, B1) 206 | B = torch.logaddexp(B_sep, B2) 207 | 208 | for i, subset in enumerate(subset_batch): 209 | B_cache[subset][t, :] = B[i] 210 | B_sep_cache[subset][t, :] = B_sep[i] 211 | 212 | self.cache['A'], self.cache['C'] = A_cache, C_cache 213 | self.cache['B'], self.cache['B_sep'] = B_cache, B_sep_cache 214 | 215 | 216 | def update_A(self, prefixes): 217 | A_cache = self.cache['A'] 218 | A = torch.stack([A_cache[prefix[:-1]] for prefix in prefixes], dim=0) # len(prefixes) * hidden_states 219 | log_probs = torch.stack([self.beta[:, prefix[-1]] for prefix in prefixes], dim=0) # len(prefixes) * hidden_states 220 | alpha_t = torch.transpose(self.alpha, 0, 1).unsqueeze(0) # 1 * hidden_states * hidden_states 221 | A = torch.logsumexp(alpha_t + (A + log_probs).unsqueeze(1), dim=2) # len(prefixes) * hidden_states 222 | 223 | for i, prefix in enumerate(prefixes): 224 | A_cache[prefix] = A[i] 225 | 226 | 227 | # compute logits for next_token: 228 | # return Pr(prefix, next_token, cnf), Pr(prefix, next_token) 229 | # here we can assume all prefixes are of the same length 230 | def compute_logits(self, prefixes, cnf0, 231 | seq_len, 
prompt_len, seq2seq, early_stop=False, fix_order=False): 232 | inf = 1e10 233 | 234 | device = self.alpha.device 235 | neginf_cuda = -inf * torch.ones(1, device=device) 236 | eos_token_id = 50256 237 | hidden_states, vocab_size = self.beta.shape 238 | alpha, beta, gamma = self.alpha, self.beta, self.gamma 239 | 240 | # beta = beta.clone() 241 | # beta[:, 796:797] = neginf_cuda 242 | 243 | keywords0 = set([keyword for clause in cnf0 for keyword in clause]) 244 | 245 | sep_tokens, start_tokens = self.cache['sep_tokens'], self.cache['start_tokens'] 246 | 247 | sep_non_start_tokens_set = sep_tokens.difference(start_tokens) 248 | sep_non_start_mask = torch.tensor([(0.0 if token in sep_non_start_tokens_set else -inf) 249 | for token in range(0, vocab_size)], device=device).unsqueeze(0) 250 | beta_sep_non_start = beta + sep_non_start_mask 251 | 252 | non_start_mask = torch.tensor([(-inf if token in start_tokens else 0.0) 253 | for token in range(0, vocab_size)], device=device).unsqueeze(0) 254 | beta_non_start = beta + non_start_mask 255 | aib = torch.zeros(hidden_states, vocab_size, device=device) 256 | 257 | A_cache, C_cache = self.cache['A'], self.cache['C'] 258 | B_cache, B_sep_cache = self.cache['B'], self.cache['B_sep'] 259 | 260 | if seq2seq == 1: 261 | prefixes = [prefix[prompt_len:] for prefix in prefixes] 262 | 263 | prefix_len = len(prefixes[0]) 264 | if prefix_len > 0: 265 | self.update_A(prefixes) 266 | 267 | logits_alpha, logits_unconditioned = [], [] 268 | for prefix in prefixes: 269 | hmm_seq_len_left = seq_len - prefix_len 270 | if seq2seq == 2: 271 | prefix_non_prompt = prefix[prompt_len:] 272 | else: 273 | prefix_non_prompt = prefix 274 | 275 | # remove clauses that are already satisfied by prefix 276 | cnf = update_cnf(prefix_non_prompt, cnf0, 277 | sep_tokens_set=sep_tokens, fix_order=fix_order) 278 | 279 | if early_stop and len(cnf) == 0: 280 | logits_alpha.append(torch.zeros(vocab_size, device=device)) 281 | logits_unconditioned.append(torch.zeros(vocab_size, device=device)) 282 | continue 283 | 284 | # case analysis 285 | next_seqs, next_cnfs, end_with_keyword, end_with_partial_keyword = case_analysis( 286 | prefix_non_prompt, keywords0, cnf, hmm_seq_len_left, fix_order=fix_order) 287 | 288 | logits = -inf * torch.ones(vocab_size, device=device) 289 | 290 | if len(next_seqs) > 0: 291 | A = A_cache[prefix] # hidden_states 292 | C, B = [], [] 293 | for next_seq, next_cnf in zip(next_seqs, next_cnfs): 294 | C.append(C_cache[next_seq]) 295 | B.append(B_sep_cache[next_cnf][prefix_len+len(next_seq), :]) 296 | C = torch.stack(C, dim=0) # len(next_seqs) * hidden_states * hidden_states 297 | B = torch.stack(B, dim=0) # len(next_seqs) * hidden_states 298 | 299 | log_probs = torch.logsumexp(A.unsqueeze(0) + torch.logsumexp(C + B.unsqueeze(1), dim=2), dim=1) # len(next_seqs) 300 | 301 | for i, seq in enumerate(next_seqs): 302 | logits[seq[0]:seq[0]+1] = torch.logaddexp(logits[seq[0]:seq[0]+1], log_probs[i:i+1]) 303 | 304 | if end_with_keyword or (not end_with_partial_keyword): 305 | inputs = beta_sep_non_start if end_with_keyword else beta_non_start # hidden_states * vocab_size 306 | if end_with_keyword: 307 | for i in range(0, len(prefix_non_prompt)): 308 | if prefix_non_prompt[i:] in keywords0: 309 | cnf = remove_from_cnf(cnf, prefix_non_prompt[i:], fix_order=fix_order) 310 | 311 | a = A_cache[prefix] # hidden_states 312 | b = B_cache[cnf][prefix_len+1, :] # hidden_states 313 | aib = a.unsqueeze(-1) + inputs + torch.logsumexp(alpha + b.unsqueeze(0), dim=1).unsqueeze(-1) 314 | aib = 
torch.logsumexp(aib, dim=0) 315 | logits = torch.logaddexp(logits, aib) 316 | 317 | if len(cnf) > 0: 318 | logits[eos_token_id:eos_token_id+1] += -inf * torch.ones(1, device=device) 319 | 320 | logits[796:797] = neginf_cuda 321 | logits_alpha.append(logits) 322 | 323 | logits_ = torch.logsumexp(A_cache[prefix].unsqueeze(-1) + beta, dim=0) 324 | logits_unconditioned.append(logits_) 325 | 326 | logits_alpha = torch.stack(logits_alpha, dim=0) 327 | logits_unconditioned = torch.stack(logits_unconditioned, dim=0) 328 | 329 | return logits_alpha, logits_unconditioned 330 | --------------------------------------------------------------------------------
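
HMM.forward above evaluates log p(x_1..x_T) with a backward recursion carried out entirely in log space: walking from the last token to the first, it alternates a logsumexp over the transition matrix alpha with adding the emission term beta[z, x_t], and finally marginalizes the initial state against gamma. Below is a small self-contained sketch of that recursion, checked against brute-force enumeration over hidden-state paths; the helper names (hmm_log_likelihood, brute_force) and the toy dimensions are illustrative only and not part of the repository.

import itertools
import torch

def hmm_log_likelihood(x, alpha, beta, gamma):
    # x: (batch, seq_len) token ids; alpha: (H, H) log transitions (rows normalized),
    # beta: (H, V) log emissions, gamma: (H,) log initial distribution
    batch_size, seq_len = x.shape
    hidden_states, _ = beta.shape
    y = torch.zeros(batch_size, hidden_states)
    for t in range(seq_len - 1, -1, -1):
        if t != seq_len - 1:
            # y[b, z] <- logsumexp over z' of (alpha[z, z'] + y[b, z'])
            y = torch.logsumexp(alpha.unsqueeze(0) + y.unsqueeze(1), dim=2)
        y = y + beta[:, x[:, t]].T  # add log emission of token x_t in every state
    return torch.logsumexp(gamma.unsqueeze(0) + y, dim=1)

def brute_force(x, alpha, beta, gamma):
    # enumerate every hidden-state path; only feasible for toy sizes
    batch_size, seq_len = x.shape
    H = gamma.shape[0]
    out = []
    for b in range(batch_size):
        terms = []
        for zs in itertools.product(range(H), repeat=seq_len):
            lp = gamma[zs[0]] + sum(beta[zs[t], x[b, t]] for t in range(seq_len))
            lp = lp + sum(alpha[zs[t], zs[t + 1]] for t in range(seq_len - 1))
            terms.append(lp)
        out.append(torch.logsumexp(torch.stack(terms), dim=0))
    return torch.stack(out)

if __name__ == '__main__':
    torch.manual_seed(0)
    H, V, T = 3, 5, 4
    alpha = torch.log_softmax(torch.randn(H, H), dim=1)
    beta = torch.log_softmax(torch.randn(H, V), dim=1)
    gamma = torch.log_softmax(torch.randn(H), dim=0)
    x = torch.randint(0, V, (2, T))
    print(torch.allclose(hmm_log_likelihood(x, alpha, beta, gamma),
                         brute_force(x, alpha, beta, gamma), atol=1e-5))  # expect: True
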