├── .gitignore ├── GPT2base.py ├── LCG ├── .gitignore ├── bleu.py ├── calc_incrate.py ├── data │ ├── ._.DS_Store │ └── dataloader.py ├── eval_kit │ ├── .gitignore │ ├── README.md │ ├── example-inputs │ │ └── README.md │ ├── measure_scores.py │ ├── metrics │ │ ├── __init__.py │ │ └── pymteval.py │ ├── mteval │ │ └── mteval-v13a-sig.pl │ ├── pycocoevalcap │ │ ├── __init__.py │ │ ├── bleu │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── bleu.py │ │ │ └── bleu_scorer.py │ │ ├── cider │ │ │ ├── __init__.py │ │ │ ├── cider.py │ │ │ └── cider_scorer.py │ │ ├── eval.py │ │ ├── meteor │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ └── paraphrase-en.gz │ │ │ ├── meteor-1.5.jar │ │ │ └── meteor.py │ │ ├── rouge │ │ │ ├── __init__.py │ │ │ └── rouge.py │ │ └── tokenizer │ │ │ ├── __init__.py │ │ │ ├── ptbtokenizer.py │ │ │ └── stanford-corenlp-3.4.1.jar │ └── pycocotools │ │ ├── __init__.py │ │ └── coco.py ├── finetune.py ├── generate_testset.py ├── measure_scores.py ├── metrics │ ├── __init__.py │ └── pymteval.py ├── models │ ├── ._.DS_Store │ └── decoder │ │ └── modeling_cdgpt.py ├── mteval │ └── mteval-v13a-sig.pl ├── pretrain.py ├── pycocoevalcap │ ├── __init__.py │ ├── bleu │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── eval.py │ ├── meteor │ │ ├── __init__.py │ │ ├── data │ │ │ └── paraphrase-en.gz │ │ ├── meteor-1.5.jar │ │ └── meteor.py │ ├── rouge │ │ ├── __init__.py │ │ └── rouge.py │ └── tokenizer │ │ ├── __init__.py │ │ ├── ptbtokenizer.py │ │ ├── stanford-corenlp-3.4.1.jar │ │ ├── tmp6fqi5rw8 │ │ └── tmpelkdsh54 ├── pycocotools │ ├── __init__.py │ └── coco.py └── readme.md ├── LICENSE ├── README.md ├── bleu.py ├── check_diff.py ├── constant.py ├── constraints.py ├── dump └── MT │ ├── fisher_10k_8-4-0.80.log │ └── fisher_base.log ├── fudge_related.py ├── marianMT.py ├── neural_constr.py ├── requirements.txt ├── train_MT.py └── train_rc.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.txt 3 | *.json 4 | *.pyc 5 | .DS_Store -------------------------------------------------------------------------------- /GPT2base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Model, GPT2Config 3 | from transformers import OpenAIGPTTokenizer, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTPreTrainedModel, OpenAIGPTConfig 4 | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, CausalLMOutput 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from packaging import version 11 | from torch import nn 12 | from torch.nn.functional import softmax 13 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Sigmoid, LogSigmoid 14 | 15 | import numpy as np 16 | 17 | class ConstrainedLM(GPT2LMHeadModel): 18 | def __init__(self, config): 19 | super().__init__(config) 20 | self.model_rc = nn.Linear(config.n_embd, config.vocab_size, bias=True) 21 | self.model_rc_transformer = GPT2Model.from_pretrained("gpt2") 22 | self.constraint_factor = 1.0 23 | self.temperature = 1.0 24 | self.use_rc_transformer = False 25 | self.log_sigmoid_fct = LogSigmoid() 26 | 27 | def set_model_rc_transformer(self, model_rc_transformer): 28 | self.model_rc_transformer = model_rc_transformer 29 | 30 | def set_constraint_factor(self, factor): 31 | self.constraint_factor = 
factor 32 | 33 | def set_use_rc_transformer(self, use_rc_transformer): 34 | self.use_rc_transformer = use_rc_transformer 35 | 36 | def set_temperature(self, temperature): 37 | self.temperature = temperature 38 | 39 | def get_model_rc(self): 40 | return self.model_rc 41 | 42 | def set_model_rc(self, new_model_rc): 43 | self.model_rc = new_model_rc 44 | 45 | def forward( 46 | self, 47 | input_ids=None, 48 | past_key_values=None, 49 | attention_mask=None, 50 | token_type_ids=None, 51 | position_ids=None, 52 | head_mask=None, 53 | inputs_embeds=None, 54 | encoder_hidden_states=None, 55 | encoder_attention_mask=None, 56 | labels=None, 57 | rc_labels=None, 58 | rc_weights=None, 59 | use_cache=None, 60 | use_temperature=False, 61 | output_attentions=None, 62 | output_hidden_states=None, 63 | return_dict=None, 64 | ): 65 | 66 | #return_dict is always None 67 | #copy from huggingface docs 68 | 69 | transformer_outputs = self.transformer( 70 | input_ids, 71 | past_key_values=past_key_values, 72 | attention_mask=attention_mask, 73 | token_type_ids=token_type_ids, 74 | position_ids=position_ids, 75 | head_mask=head_mask, 76 | inputs_embeds=inputs_embeds, 77 | encoder_hidden_states=encoder_hidden_states, 78 | encoder_attention_mask=encoder_attention_mask, 79 | use_cache=use_cache, 80 | output_attentions=output_attentions, 81 | output_hidden_states=output_hidden_states, 82 | return_dict=return_dict, 83 | ) 84 | hidden_states = transformer_outputs.last_hidden_state 85 | 86 | # if past_key_values is not None and head_mask is not None: 87 | # print ("Shape Check:") 88 | # print (input_ids.shape) 89 | # print (attention_mask.shape) 90 | # print (len(past_key_values)) 91 | # print (len(past_key_values[0])) 92 | # print (past_key_values[0][0].shape) 93 | # print (head_mask.shape) 94 | # print (head_mask.sum()) 95 | 96 | 97 | if self.model_parallel: 98 | torch.cuda.set_device(self.transformer.first_device) 99 | hidden_states = hidden_states.to(self.lm_head.weight.device) 100 | 101 | lm_logits = self.lm_head(hidden_states) 102 | if use_temperature: 103 | lm_logits = lm_logits * self.temperature 104 | 105 | if self.constraint_factor == 0.0: 106 | pred_logits = lm_logits 107 | else: 108 | if self.use_rc_transformer: 109 | rc_hidden_states = self.model_rc_transformer( 110 | input_ids, 111 | past_key_values=past_key_values, 112 | attention_mask=attention_mask, 113 | token_type_ids=token_type_ids, 114 | position_ids=position_ids, 115 | head_mask=head_mask, 116 | inputs_embeds=inputs_embeds, 117 | encoder_hidden_states=encoder_hidden_states, 118 | encoder_attention_mask=encoder_attention_mask, 119 | use_cache=use_cache, 120 | output_attentions=output_attentions, 121 | output_hidden_states=output_hidden_states, 122 | return_dict=return_dict, 123 | ).last_hidden_state 124 | else: 125 | rc_hidden_states = hidden_states 126 | constr_logits = self.log_sigmoid_fct(self.model_rc(rc_hidden_states)) * self.constraint_factor 127 | pred_logits = lm_logits + constr_logits 128 | 129 | loss = None 130 | if labels is not None: 131 | # Shift so that tokens < n predict n 132 | shift_logits = lm_logits[..., :-1, :].contiguous() 133 | shift_labels = labels[..., 1:].contiguous() 134 | loss_fct = CrossEntropyLoss() 135 | # Flatten the tokens 136 | if rc_labels is not None: 137 | rc_labels = rc_labels.long() 138 | shift_logits = shift_logits.index_select(0, torch.where(rc_labels == 1)[0]) 139 | shift_labels = shift_labels.index_select(0, torch.where(rc_labels == 1)[0]) 140 | 141 | loss = loss_fct(shift_logits.view(-1, 
shift_logits.size(-1)), shift_labels.view(-1)) 142 | 143 | elif rc_labels is not None: 144 | length = input_ids.shape[1] 145 | if len(rc_labels.shape) == 1: 146 | rc_labels = rc_labels.unsqueeze(1) 147 | for i in range(length - 1): 148 | pred_logits = torch.gather(constr_logits[:, i, :], 1, input_ids[:, i + 1].unsqueeze(1)) 149 | 150 | if rc_weights is not None: 151 | weights = softmax(rc_weights * (1.0 - self.temperature), dim=0) 152 | loss_fct = BCEWithLogitsLoss(weight= 153 | (attention_mask[:, i + 1] * weights).unsqueeze(1), reduction='sum') 154 | else: 155 | loss_fct = BCEWithLogitsLoss(weight=attention_mask[:, i + 1].unsqueeze(1)) 156 | 157 | cur_loss = loss_fct(pred_logits, rc_labels) 158 | if loss is not None: 159 | loss += cur_loss 160 | else: 161 | loss = cur_loss 162 | 163 | return CausalLMOutputWithCrossAttentions( 164 | loss=loss, 165 | logits=pred_logits, 166 | past_key_values=transformer_outputs.past_key_values, 167 | hidden_states=transformer_outputs.hidden_states, 168 | attentions=transformer_outputs.attentions, 169 | cross_attentions=transformer_outputs.cross_attentions, 170 | ) 171 | 172 | if __name__ == "__main__": 173 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 174 | model = ConstrainedLM.from_pretrained("gpt2") 175 | 176 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 177 | print(sum([np.prod(p.size()) for p in model_parameters])) 178 | 179 | sentence_prefix = "I play piano every" 180 | 181 | input_ids = tokenizer.encode( 182 | sentence_prefix, 183 | add_special_tokens=False, 184 | return_tensors="pt", 185 | ) 186 | 187 | #model(input_ids=input_ids) 188 | 189 | for i in range(5): 190 | output_ids = model.generate( 191 | input_ids=input_ids, 192 | do_sample=True, 193 | max_length=30, # desired output sentence length 194 | pad_token_id=model.config.eos_token_id, 195 | )[0].tolist() 196 | 197 | generated_text = tokenizer.decode( 198 | output_ids, 199 | clean_up_tokenization_spaces=True) 200 | 201 | print(generated_text) 202 | 203 | -------------------------------------------------------------------------------- /LCG/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.txt 3 | *.json 4 | *.pyc 5 | .DS_Store -------------------------------------------------------------------------------- /LCG/bleu.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pprint 4 | import math 5 | import re 6 | from transformers import BasicTokenizer 7 | c=0 8 | r=0 9 | 10 | split = BasicTokenizer(never_split=["", "", ""]) 11 | 12 | def getCAndR(candidateSentence,referenceSentences): 13 | global c 14 | global r 15 | candidateSentence = list(candidateSentence) 16 | referenceSentences = [list(item) for item in referenceSentences] 17 | referenceCount=[] 18 | referenceLength=[] 19 | c+=len(candidateSentence) 20 | for index3 in range(0,len(referenceSentences)): 21 | referenceCount.append(abs(len(referenceSentences[index3])-len(candidateSentence))) 22 | referenceLength.append(len(referenceSentences[index3])) 23 | r+=referenceLength[referenceCount.index(min(referenceCount))] 24 | 25 | def getBP(): 26 | if c>=r: 27 | return 1 28 | else: 29 | return math.exp(1-r/float(c)) 30 | 31 | 32 | 33 | def getFiles(candidatePath,referencePath): 34 | candidatefile=candidatePath 35 | referencefiles=[] 36 | if os.path.isfile(referencePath): 37 | referencefiles.append(referencePath) 38 | else: 39 | referencefiles=os.listdir(referencePath) 40 | for i in 
range(0,len(referencefiles)): 41 | referencefiles[i]=referencePath+"/"+referencefiles[i] 42 | return candidatefile,referencefiles 43 | 44 | def readFiles(candidatefile,referencefiles): 45 | candidateData=[] 46 | referencesData=[] 47 | idx = 0 48 | exclude = set() 49 | with open(candidatefile) as fp: 50 | for line in fp: 51 | splitted = split.tokenize(line) 52 | candidateData.append(splitted) 53 | idx += 1 54 | for i in range(0,len(referencefiles)): 55 | temp=[] 56 | idx = 0 57 | with open(referencefiles[i]) as fp: 58 | for line in fp: 59 | if idx in exclude: 60 | idx += 1 61 | continue 62 | temp.append(split.tokenize(line)) 63 | idx += 1 64 | referencesData.append(temp) 65 | return candidateData,referencesData 66 | 67 | def uniGramDictionary(sentence): 68 | dictionary={} 69 | #sentence = sentence[0] 70 | sentence = list(sentence) 71 | i = 0 72 | while i < len(sentence): 73 | unigram=sentence[i] 74 | #print "unigram:", unigram 75 | if unigram in dictionary: 76 | dictionary[unigram]+=1 77 | else: 78 | dictionary[unigram]=1 79 | i += 1 80 | return dictionary 81 | def biGramDictionary(sentence): 82 | dictionary={} 83 | #sentence = sentence[0] 84 | i = 0 85 | sentence = list(sentence) 86 | while i < len(sentence): 87 | if i+1 >= len(sentence): 88 | break 89 | bigram="".join(sentence[i])+" "+"".join(sentence[i+1]) 90 | #print "bigram:", bigram 91 | if bigram in dictionary: 92 | dictionary[bigram]+=1 93 | else: 94 | dictionary[bigram]=1 95 | i += 1 96 | return dictionary 97 | def triGramDictionary(sentence): 98 | dictionary={} 99 | #sentence = sentence[0] 100 | i = 0 101 | sentence = list(sentence) 102 | while i < len(sentence): 103 | if i+2 >= len(sentence): 104 | break 105 | trigram="".join(sentence[i])+" "+"".join(sentence[i+1])+" "+"".join(sentence[i+2]) 106 | #print "trigram:", trigram 107 | if trigram in dictionary: 108 | dictionary[trigram]+=1 109 | else: 110 | dictionary[trigram]=1 111 | i += 1 112 | return dictionary 113 | def quadrupleGramDictionary(sentence): 114 | dictionary={} 115 | #sentence = sentence[0] 116 | i = 0 117 | sentence = list(sentence) 118 | while i < len(sentence): 119 | if i+3 >= len(sentence): 120 | break 121 | quadruplegram="".join(sentence[i])+" "+"".join(sentence[i+1])+" "+"".join(sentence[i+2])+" "+"".join(sentence[i+3]) 122 | #print "quadruplegram:", quadruplegram 123 | if quadruplegram in dictionary: 124 | dictionary[quadruplegram]+=1 125 | else: 126 | dictionary[quadruplegram]=1 127 | i += 1 128 | 129 | return dictionary 130 | def uniGram(candidateSentence,referenceSentences): 131 | referenceDict=[] 132 | reference=[] 133 | #candidateSentence=candidateSentence.lower().split() 134 | candidateSentence=list(filter(None,candidateSentence)) 135 | candidateDict = uniGramDictionary(candidateSentence) 136 | count=0 137 | for line in referenceSentences: 138 | #line=line.lower().split() 139 | line=list(filter(None,line)) 140 | reference.append(line) 141 | referenceDict.append(uniGramDictionary(line)) 142 | getCAndR(candidateSentence,reference) 143 | for word in candidateDict: 144 | #print "word in candidateDict:", word 145 | maxRefIndex=0 146 | for index2 in range(0,len(referenceDict)): 147 | if word in referenceDict[index2]: 148 | maxRefIndex=max(maxRefIndex,referenceDict[index2][word]) 149 | 150 | count+=min(candidateDict[word],maxRefIndex) 151 | #print count 152 | sumngram=0 153 | for values in candidateDict.values(): 154 | sumngram+=values 155 | return count,sumngram 156 | 157 | def biGram(candidateSentence,referenceSentences): 158 | referenceDict=[] 159 | 
#candidateSentence=candidateSentence.lower().split() 160 | candidateSentence=filter(None,candidateSentence) 161 | candidateDict = biGramDictionary(candidateSentence) 162 | count=0 163 | for line in referenceSentences: 164 | #line=line.lower().split() 165 | line=filter(None,line) 166 | referenceDict.append(biGramDictionary(line)) 167 | for word in candidateDict: 168 | maxRefIndex=0 169 | for index2 in range(0,len(referenceDict)): 170 | if word in referenceDict[index2]: 171 | maxRefIndex=max(maxRefIndex,referenceDict[index2][word]) 172 | count+=min(candidateDict[word],maxRefIndex) 173 | sumngram=0 174 | for values in candidateDict.values(): 175 | sumngram+=values 176 | return count,sumngram 177 | 178 | def triGram(candidateSentence,referenceSentences): 179 | referenceDict=[] 180 | #candidateSentence=candidateSentence.lower().split() 181 | candidateSentence=filter(None,candidateSentence) 182 | candidateDict = triGramDictionary(candidateSentence) 183 | count=0 184 | for line in referenceSentences: 185 | #line=line.lower().split() 186 | line=filter(None,line) 187 | referenceDict.append(triGramDictionary(line)) 188 | for word in candidateDict: 189 | maxRefIndex=0 190 | for index2 in range(0,len(referenceDict)): 191 | if word in referenceDict[index2]: 192 | maxRefIndex=max(maxRefIndex,referenceDict[index2][word]) 193 | 194 | count+=min(candidateDict[word],maxRefIndex) 195 | sumngram=0 196 | for values in candidateDict.values(): 197 | sumngram+=values 198 | return count,sumngram 199 | 200 | def quadrupleGram(candidateSentence,referenceSentences): 201 | referenceDict=[] 202 | #candidateSentence=candidateSentence.lower().split() 203 | candidateSentence=filter(None,candidateSentence) 204 | candidateDict = quadrupleGramDictionary(candidateSentence) 205 | count=0 206 | for line in referenceSentences: 207 | #line=line.lower().split() 208 | line=filter(None,line) 209 | referenceDict.append(quadrupleGramDictionary(line)) 210 | for word in candidateDict: 211 | maxRefIndex=0 212 | for index2 in range(0,len(referenceDict)): 213 | if word in referenceDict[index2]: 214 | maxRefIndex=max(maxRefIndex,referenceDict[index2][word]) 215 | count+=min(candidateDict[word],maxRefIndex) 216 | sumngram=0 217 | for values in candidateDict.values(): 218 | sumngram+=values 219 | return count,sumngram 220 | 221 | def getModifiedPrecision(candidateData,referencesData): 222 | global c 223 | global r 224 | uniNum=0 225 | uniDen=0 226 | biNum=0 227 | biDen=0 228 | triNum=0 229 | triDen=0 230 | quadrupleNum=0 231 | quadrupleDen=0 232 | for index in range(0,len(candidateData)): 233 | referenceSentences=[] 234 | candidateSentence=candidateData[index] 235 | for index1 in range(0,len(referencesData)): 236 | referenceSentences.append(referencesData[index1][index]) 237 | #print candidateSentence 238 | #print referenceSentences[0] 239 | uniClipCount,uniCount=uniGram(candidateSentence,referenceSentences) 240 | uniNum+=uniClipCount 241 | uniDen+=uniCount 242 | biClipCount,biCount=biGram(candidateSentence,referenceSentences) 243 | biNum+=biClipCount 244 | biDen+=biCount 245 | triClipCount,triCount=triGram(candidateSentence,referenceSentences) 246 | triNum+=triClipCount 247 | triDen+=triCount 248 | quadrupleClipCount,quadrupleCount=quadrupleGram(candidateSentence,referenceSentences) 249 | quadrupleNum+=quadrupleClipCount 250 | quadrupleDen+=quadrupleCount 251 | print (uniNum,uniDen) 252 | print (biNum,biDen) 253 | print (triNum,triDen) 254 | print (quadrupleNum,quadrupleDen) 255 | if uniDen > 0: 256 | unigram1=uniNum/float(uniDen) 257 | 
else: 258 | unigram1 = 0 259 | if biDen > 0: 260 | bigram1=biNum/float(biDen) 261 | else: 262 | bigram1 = 0 263 | if triDen > 0 : 264 | trigram1=triNum/float(triDen) 265 | else: 266 | trigram1 = 0 267 | if quadrupleDen > 0: 268 | quadruplegram1=quadrupleNum/float(quadrupleDen) 269 | else: 270 | quadruplegram1 = 0 271 | 272 | print (unigram1,bigram1,trigram1,quadruplegram1) 273 | bleu1 = 0 274 | bleu2 = 0 275 | bleu3 = 0 276 | bleu4 = 0 277 | if unigram1+bigram1+trigram1+quadruplegram1 == 0: 278 | bleu1 = 0 279 | bleu2 = 0 280 | bleu3 = 0 281 | bleu4 = 0 282 | else: 283 | if unigram1 > 0: 284 | print ("in 1",getBP()) 285 | bleu1=getBP()*math.exp(math.log(unigram1)) 286 | if unigram1 > 0 and bigram1 > 0 : 287 | print( "in 2",getBP()) 288 | bleu2=getBP()*math.exp(0.5*math.log(unigram1)+0.5*math.log(bigram1)) 289 | if unigram1 > 0 and bigram1 > 0 and trigram1 > 0 : 290 | print( "in 3",getBP()) 291 | bleu3=getBP()*math.exp((1/3.0)*math.log(unigram1)+(1/3.0)*math.log(bigram1)+(1/3.0)*math.log(trigram1)) 292 | if unigram1 >0 and bigram1 >0 and trigram1 > 0 and quadruplegram1 > 0: 293 | print( "in 4",getBP()) 294 | bleu4=getBP()*math.exp(0.25*math.log(unigram1)+0.25*math.log(bigram1)+0.25*math.log(trigram1)+0.25*math.log(quadruplegram1)) 295 | 296 | print (bleu1, bleu2, bleu3, bleu4) 297 | fp=open('bleu_out.txt','a') 298 | fp.write("%s has blue score: %f\t%f\t%f\t%f\n" %(sys.argv[1], bleu1, bleu2, bleu3, bleu4)) 299 | fp.close() 300 | 301 | if __name__ == "__main__": 302 | candidatefile,referencefiles = getFiles(sys.argv[1],sys.argv[2]) 303 | candidateData,referencesData=readFiles(candidatefile,referencefiles) 304 | """ 305 | for item in candidateData: 306 | for word in item: 307 | print word, 308 | print 309 | """ 310 | getModifiedPrecision(candidateData,referencesData) 311 | print(c, " ", r) -------------------------------------------------------------------------------- /LCG/calc_incrate.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import tqdm 3 | import nltk 4 | from nltk.stem import WordNetLemmatizer 5 | from nltk.corpus import wordnet 6 | 7 | lemmatizer = WordNetLemmatizer() 8 | 9 | 10 | def nltk_pos_tagger(nltk_tag): 11 | if nltk_tag.startswith('J'): 12 | return wordnet.ADJ 13 | elif nltk_tag.startswith('V'): 14 | return wordnet.VERB 15 | elif nltk_tag.startswith('N'): 16 | return wordnet.NOUN 17 | elif nltk_tag.startswith('R'): 18 | return wordnet.ADV 19 | else: 20 | return None 21 | 22 | 23 | def lemmatize_sentence(sentence): 24 | nltk_tagged = nltk.pos_tag(nltk.word_tokenize(sentence)) 25 | wordnet_tagged = map(lambda x: (x[0], nltk_pos_tagger(x[1])), nltk_tagged) 26 | lemmatized_sentence = [] 27 | 28 | for word, tag in wordnet_tagged: 29 | if tag is None: 30 | lemmatized_sentence.append(word) 31 | else: 32 | lemmatized_sentence.append(lemmatizer.lemmatize(word, tag)) 33 | return " ".join(lemmatized_sentence) 34 | 35 | N = 0 36 | M = 0 37 | with open("./data/common_gen/common_gen-generated-baseline.jsonl", "w") as fout: 38 | with open("./data/common_gen/common_gen-keys.txt", "r") as fkeys: 39 | with open("./data/common_gen/common_gen-generated-baseline.txt", "r") as ftest: # 40 | A, B = fkeys.readlines(), ftest.readlines() 41 | if len(A) == len(B): 42 | for (line_key, line_seq) in tqdm.tqdm(zip(A, B)): 43 | keys = line_key.strip().split() 44 | seq = lemmatize_sentence(line_seq.strip()) 45 | print("{\"concept_set\": \"%s\", \"pred_scene\": [\"%s\"]}" % ("#".join(keys), line_seq.strip()), file=fout) 46 | # M += 1 47 | M += 
len(keys) 48 | included_keys = [key for key in keys if seq.count(key) > 0] 49 | if len(keys) == len(included_keys): 50 | # N += 1 51 | N += len(included_keys) 52 | else: 53 | N += len(included_keys) 54 | # print("Actual keys:", keys) 55 | # print("Included keys:", included_keys) 56 | # print("Generated samples:", line_seq) 57 | else: 58 | i = 0 59 | j = 0 60 | for i in range(len(A)): 61 | cand_N = 0 62 | keys = A[i].strip().split() 63 | M += 1 64 | while B[j].strip() != "": 65 | seq = lemmatize_sentence(B[j].strip()) 66 | included_keys = [key for key in keys if seq.count(key) > 0] 67 | cand_N = max(cand_N, len(included_keys)) 68 | j += 1 69 | if len(keys) == cand_N: 70 | N += 1 71 | j += 1 72 | 73 | print("baseline stats:") 74 | print(N) 75 | print(M) 76 | print(N / M) 77 | 78 | N = 0 79 | M = 0 80 | with open("./data/common_gen/common_gen-generated-finetune.jsonl", "w") as fout: 81 | with open("./data/common_gen/common_gen-keys.txt", "r") as fkeys: 82 | with open("./data/common_gen/common_gen-generated-finetune.txt", "r") as ftest: # 83 | A, B = fkeys.readlines(), ftest.readlines() 84 | if len(A) == len(B): 85 | for (line_key, line_seq) in tqdm.tqdm(zip(A, B)): 86 | keys = line_key.strip().split() 87 | seq = lemmatize_sentence(line_seq.strip()) 88 | print("{\"concept_set\": \"%s\", \"pred_scene\": [\"%s\"]}" % ("#".join(keys), line_seq.strip()), file=fout) 89 | # M += 1 90 | M += len(keys) 91 | included_keys = [key for key in keys if seq.count(key) > 0] 92 | if len(keys) == len(included_keys): 93 | # N += 1 94 | N += len(included_keys) 95 | else: 96 | N += len(included_keys) 97 | # print("Actual keys:", keys) 98 | # print("Included keys:", included_keys) 99 | # print("Generated samples:", line_seq) 100 | else: 101 | i = 0 102 | j = 0 103 | for i in range(len(A)): 104 | cand_N = 0 105 | keys = A[i].strip().split() 106 | M += 1 107 | while B[j].strip() != "": 108 | seq = lemmatize_sentence(B[j].strip()) 109 | included_keys = [key for key in keys if seq.count(key) > 0] 110 | cand_N = max(cand_N, len(included_keys)) 111 | j += 1 112 | if len(keys) == cand_N: 113 | N += 1 114 | j += 1 115 | 116 | 117 | print("finetuned stats:") 118 | print(N) 119 | print(M) 120 | print(N / M) -------------------------------------------------------------------------------- /LCG/data/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/data/._.DS_Store -------------------------------------------------------------------------------- /LCG/eval_kit/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .*.swp 3 | -------------------------------------------------------------------------------- /LCG/eval_kit/README.md: -------------------------------------------------------------------------------- 1 | E2E NLG Challenge Evaluation metrics 2 | ==================================== 3 | 4 | The metrics used for the challenge include: 5 | * BLEU + NIST from [MT-Eval](#mt-eval), 6 | * METEOR, ROUGE-L, CIDEr from the [MS-COCO Caption evaluation scripts](#microsoft-coco-caption-evaluation). 
7 | 8 | Running the evaluation 9 | ---------------------- 10 | 11 | ### Requirements/Installation ### 12 | 13 | The metrics script requires the following dependencies: 14 | - Java 1.8 15 | - Python **3.6+** with [matplotlib](https://pypi.python.org/pypi/matplotlib) and [scikit-image](https://pypi.python.org/pypi/scikit-image) packages 16 | - Perl 5.8.8 or higher with the [XML::Twig](http://search.cpan.org/~mirod/XML-Twig-3.49/Twig.pm) CPAN module 17 | 18 | 19 | To install the required Python packages, run (assuming root access or [virtualenv](https://virtualenv.pypa.io/en/stable/)): 20 | ``` 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | To install the required Perl module, run (assuming root access or [perlbrew](https://perlbrew.pl/)/[plenv](https://github.com/tokuhirom/plenv)): 25 | ``` 26 | curl -L https://cpanmin.us | perl - App::cpanminus # install cpanm 27 | cpanm XML::Twig 28 | ``` 29 | 30 | 31 | ### Usage ### 32 | 33 | The main entry point is [measure_scores.py](measure_scores.py). To get a listing of all available options, 34 | run: 35 | ``` 36 | ./measure_scores.py -h 37 | ``` 38 | 39 | The system outputs and human references can either be in a TSV/CSV format, or in plain text. This is 40 | distinguished by the file extension (plain text assumed, unless it's `.tsv` or `.csv`). 41 | 42 | For TSV/CSV, the script assumes that the first column contains source MRs/texts and the second column 43 | contains system outputs or references. Multiple references for the same source MRs/texts are grouped automatically 44 | (either by the same source as in the system output file, if it's also a TSV/CSV, or by consecutive identical 45 | sources). 46 | If there are headers in the TSV/CSV file with reasonably identifiable labels (e.g. “MR”, “source”, 47 | “system output”, “reference” etc., there's some guessing involved), the columns should be identified automatically. 48 | In that case, the file doesn't need to have just two columns in the exact order. 49 | 50 | For plain text files, the script assumes one instance 51 | per line for your system outputs and one entry per line or multiple references for the same instance 52 | separated by empty lines for the references (see 53 | [TGen data conversion](https://github.com/UFAL-DSG/tgen/blob/master/e2e-challenge/README.md)). 54 | 55 | Example human reference and system output files are provided in the [example-inputs](example-inputs/) 56 | subdirectory -- you can try the script on them using this command: 57 | ``` 58 | ./measure_scores.py example-inputs/devel-conc.txt example-inputs/baseline-output.txt 59 | ``` 60 | 61 | Source metrics scripts 62 | ---------------------- 63 | 64 | ### MT-Eval ### 65 | 66 | We used the NIST MT-Eval v13a script adapted for significance tests, from 67 | . 68 | We adapted the script to allow a variable number of references. 69 | 70 | 71 | ### Microsoft COCO Caption Evaluation ### 72 | 73 | These provide a different variant of BLEU (which is not used for evaluation in the E2E challenge), 74 | METEOR, ROUGE-L, CIDER. We used the [Github code for these metrics](https://github.com/tylin/coco-caption). 75 | The metrics are unchanged, apart from removing support for images and some of the dependencies. 
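For a quick programmatic sanity check outside of `measure_scores.py`, the bundled scorers can also be imported directly. The snippet below is only an illustrative sketch (it assumes it is run from this `eval_kit` directory with the requirements installed, so the local `pycocoevalcap` package is importable; the ids and sentences are made-up placeholders). Both scorers take whitespace-tokenized strings, with one system output and one or more references per instance id:

```python
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider

# res[id] must be a single-element list (the system output);
# gts[id] is a list of one or more reference sentences.
gts = {"0": ["there is a cheap coffee shop near the river",
             "a cheap coffee shop can be found near the river"],
       "1": ["Blue Spice serves Italian food in the city centre"]}
res = {"0": ["near the river there is a cheap coffee shop"],
       "1": ["Blue Spice is an Italian restaurant in the city centre"]}

bleu, _ = Bleu(n=4).compute_score(gts, res)   # [BLEU-1, BLEU-2, BLEU-3, BLEU-4]
cider, _ = Cider().compute_score(gts, res)    # corpus-level CIDEr
print(bleu, cider)
```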
76 | 77 | 78 | References 79 | ---------- 80 | 81 | - [Microsoft COCO Captions: Data Collection and Evaluation Server](http://arxiv.org/abs/1504.00325) 82 | - PTBTokenizer: We use the [Stanford Tokenizer](http://nlp.stanford.edu/software/tokenizer.shtml) which is included in [Stanford CoreNLP 3.4.1](http://nlp.stanford.edu/software/corenlp.shtml). 83 | - BLEU: [BLEU: a Method for Automatic Evaluation of Machine Translation](http://www.aclweb.org/anthology/P02-1040.pdf) 84 | - NIST: [Automatic Evaluation of Machine Translation Quality Using N-gram Co-Occurrence Statistics](http://www.mt-archive.info/HLT-2002-Doddington.pdf) 85 | - Meteor: [Project page](http://www.cs.cmu.edu/~alavie/METEOR/) with related publications. We use the latest version (1.5) of the [Code](https://github.com/mjdenkowski/meteor). Changes have been made to the source code to properly aggreate the statistics for the entire corpus. 86 | - Rouge-L: [ROUGE: A Package for Automatic Evaluation of Summaries](http://anthology.aclweb.org/W/W04/W04-1013.pdf) 87 | - CIDEr: [CIDEr: Consensus-based Image Description Evaluation](http://arxiv.org/pdf/1411.5726.pdf) 88 | 89 | Acknowledgements 90 | ---------------- 91 | Original developers of the MSCOCO evaluation scripts: 92 | 93 | Xinlei Chen, Hao Fang, Tsung-Yi Lin, Ramakrishna Vedantam, David Chiang, Michael Denkowski, Alexander Rush 94 | -------------------------------------------------------------------------------- /LCG/eval_kit/example-inputs/README.md: -------------------------------------------------------------------------------- 1 | Example data files for the E2E metrics script 2 | ============================================= 3 | 4 | The files in this subdirectory show the data format expected by the E2E NLG 5 | challenge metrics script. 6 | 7 | * [devel-conc.txt](devel-conc.txt) -- human references for the first 10 instances 8 | of the E2E NLG challenge development set. One reference per line, different 9 | instances are separated by empty lines. 10 | * [baseline-output.txt](baseline-output.txt) -- output of the baseline system on 11 | the first 10 development instances. One instance per line. 12 | 13 | -------------------------------------------------------------------------------- /LCG/eval_kit/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/eval_kit/metrics/__init__.py -------------------------------------------------------------------------------- /LCG/eval_kit/metrics/pymteval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | BLEU & NIST measurements -- should be compatible with mteval-v13a.pl (basic tokenization). 6 | Also provides BLEU +1 smoothing (if set to work like that). 7 | 8 | TODO: International tokenization 9 | TODO: NIST with variable number of references is not the same as the edited mteval-v13a.pl, 10 | but this should be the proper way to compute it. Should be fixed there. 
11 | """ 12 | 13 | from __future__ import unicode_literals 14 | from __future__ import division 15 | from builtins import zip 16 | from builtins import range 17 | from past.utils import old_div 18 | from builtins import object 19 | from collections import defaultdict 20 | import math 21 | import re 22 | 23 | 24 | class NGramScore(object): 25 | """Base class for BLEU & NIST, providing tokenization and some basic n-gram matching 26 | functions.""" 27 | 28 | def __init__(self, max_ngram, case_sensitive): 29 | """Create the scoring object. 30 | @param max_ngram: the n-gram level to compute the score for 31 | @param case_sensitive: use case-sensitive matching? 32 | """ 33 | self.max_ngram = max_ngram 34 | self.case_sensitive = case_sensitive 35 | 36 | def reset(self): 37 | """Reset the object, zero all counters.""" 38 | raise NotImplementedError() 39 | 40 | def append(self, pred_sent, ref_sents): 41 | """Add a sentence to the statistics. 42 | @param pred_sent: system output / predicted sentence 43 | @param ref_sents: reference sentences 44 | """ 45 | raise NotImplementedError() 46 | 47 | def score(self): 48 | """Compute the current score based on sentences added so far.""" 49 | raise NotImplementedError() 50 | 51 | def ngrams(self, n, sent): 52 | """Given a sentence, return n-grams of nodes for the given N. Lowercases 53 | everything if the measure should not be case-sensitive. 54 | 55 | @param n: n-gram 'N' (1 for unigrams, 2 for bigrams etc.) 56 | @param sent: the sent in question 57 | @return: n-grams of nodes, as tuples of tuples (t-lemma & formeme) 58 | """ 59 | if not self.case_sensitive: 60 | return list(zip(*[[tok.lower() for tok in sent[i:]] for i in range(n)])) 61 | return list(zip(*[sent[i:] for i in range(n)])) 62 | 63 | def check_tokenized(self, pred_sent, ref_sents): 64 | """Tokenize the predicted sentence and reference sentences, if they are not tokenized. 65 | @param pred_sent: system output / predicted sentence 66 | @param ref_sent: a list of corresponding reference sentences 67 | @return: a tuple of (pred_sent, ref_sent) where everything is tokenized 68 | """ 69 | # tokenize if needed 70 | pred_sent = pred_sent if isinstance(pred_sent, list) else self.tokenize(pred_sent) 71 | ref_sents = [ref_sent if isinstance(ref_sent, list) else self.tokenize(ref_sent) 72 | for ref_sent in ref_sents] 73 | return pred_sent, ref_sents 74 | 75 | def get_ngram_counts(self, n, sents): 76 | """Returns a dictionary with counts of all n-grams in the given sentences. 77 | @param n: the "n" in n-grams (how long the n-grams should be) 78 | @param sents: list of sentences for n-gram counting 79 | @return: a dictionary (ngram: count) listing counts of n-grams attested in any of the sentences 80 | """ 81 | merged_ngrams = {} 82 | 83 | for sent in sents: 84 | ngrams = defaultdict(int) 85 | 86 | for ngram in self.ngrams(n, sent): 87 | ngrams[ngram] += 1 88 | for ngram, cnt in ngrams.items(): 89 | merged_ngrams[ngram] = max((merged_ngrams.get(ngram, 0), cnt)) 90 | return merged_ngrams 91 | 92 | def tokenize(self, sent): 93 | """This tries to mimic multi-bleu-detok from Moses, and by extension mteval-v13b. 
94 | Code taken directly from there and attempted rewrite into Python.""" 95 | # language-independent part: 96 | sent = re.sub(r'', r'', sent) # strip "skipped" tags 97 | sent = re.sub(r'-\n', r'', sent) # strip end-of-line hyphenation and join lines 98 | sent = re.sub(r'\n', r' ', sent) # join lines 99 | sent = re.sub(r'"', r'"', sent) # convert SGML tag for quote to " 100 | sent = re.sub(r'&', r'&', sent) # convert SGML tag for ampersand to & 101 | sent = re.sub(r'<', r'<', sent) # convert SGML tag for less-than to > 102 | sent = re.sub(r'>', r'>', sent) # convert SGML tag for greater-than to < 103 | 104 | # language-dependent part (assuming Western languages): 105 | sent = " " + sent + " " # pad with spaces 106 | sent = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sent) # tokenize punctuation 107 | sent = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sent) # tokenize period and comma unless preceded by a digit 108 | sent = re.sub(r'([\.,])([^0-9])', r' \1 \2', sent) # tokenize period and comma unless followed by a digit 109 | sent = re.sub(r'([0-9])(-)', r'\1 \2 ', sent) # tokenize dash when preceded by a digit 110 | sent = re.sub(r'\s+', r' ', sent) # one space only between words 111 | sent = sent.strip() # remove padding 112 | 113 | return sent.split(' ') 114 | 115 | 116 | class BLEUScore(NGramScore): 117 | """An accumulator object capable of computing BLEU score using multiple references. 118 | 119 | The BLEU score is always smoothed a bit so that it's never undefined. For sentence-level 120 | measurements, proper smoothing should be used via the smoothing parameter (set to 1.0 for 121 | the same behavior as default Moses's MERT sentence BLEU). 122 | """ 123 | 124 | TINY = 1e-15 125 | SMALL = 1e-9 126 | 127 | def __init__(self, max_ngram=4, case_sensitive=False, smoothing=0.0): 128 | """Create the scoring object. 129 | @param max_ngram: the n-gram level to compute the score for (default: 4) 130 | @param case_sensitive: use case-sensitive matching (default: no) 131 | @param smoothing: constant to add for smoothing (defaults to 0.0, sentBLEU uses 1.0) 132 | """ 133 | super(BLEUScore, self).__init__(max_ngram, case_sensitive) 134 | self.smoothing = smoothing 135 | self.reset() 136 | 137 | def reset(self): 138 | """Reset the object, zero all counters.""" 139 | self.ref_len = 0 140 | self.cand_lens = [0] * self.max_ngram 141 | self.hits = [0] * self.max_ngram 142 | 143 | def append(self, pred_sent, ref_sents): 144 | """Append a sentence for measurements, increase counters. 
145 | 146 | @param pred_sent: the system output sentence (string/list of tokens) 147 | @param ref_sents: the corresponding reference sentences (list of strings/lists of tokens) 148 | """ 149 | pred_sent, ref_sents = self.check_tokenized(pred_sent, ref_sents) 150 | 151 | # compute n-gram matches 152 | for i in range(self.max_ngram): 153 | self.hits[i] += self.compute_hits(i + 1, pred_sent, ref_sents) 154 | self.cand_lens[i] += len(pred_sent) - i 155 | 156 | # take the reference that is closest in length to the candidate 157 | # (if there are two of the same distance, take the shorter one) 158 | closest_ref = min(ref_sents, key=lambda ref_sent: (abs(len(ref_sent) - len(pred_sent)), len(ref_sent))) 159 | self.ref_len += len(closest_ref) 160 | 161 | def score(self): 162 | """Return the current BLEU score, according to the accumulated counts.""" 163 | return self.bleu() 164 | 165 | def compute_hits(self, n, pred_sent, ref_sents): 166 | """Compute clipped n-gram hits for the given sentences and the given N 167 | 168 | @param n: n-gram 'N' (1 for unigrams, 2 for bigrams etc.) 169 | @param pred_sent: the system output sentence (tree/tokens) 170 | @param ref_sents: the corresponding reference sentences (list/tuple of trees/tokens) 171 | """ 172 | merged_ref_ngrams = self.get_ngram_counts(n, ref_sents) 173 | pred_ngrams = self.get_ngram_counts(n, [pred_sent]) 174 | 175 | hits = 0 176 | for ngram, cnt in pred_ngrams.items(): 177 | hits += min(merged_ref_ngrams.get(ngram, 0), cnt) 178 | 179 | return hits 180 | 181 | def bleu(self): 182 | """Return the current BLEU score, according to the accumulated counts.""" 183 | # brevity penalty (smoothed a bit: if candidate length is 0, we change it to 1e-5 184 | # to avoid division by zero) 185 | bp = 1.0 186 | if (self.cand_lens[0] <= self.ref_len): 187 | bp = math.exp(1.0 - old_div(self.ref_len, 188 | (float(self.cand_lens[0]) if self.cand_lens[0] else 1e-5))) 189 | 190 | return bp * self.ngram_precision() 191 | 192 | def ngram_precision(self): 193 | """Return the current n-gram precision (harmonic mean of n-gram precisions up to max_ngram) 194 | according to the accumulated counts.""" 195 | prec_log_sum = 0.0 196 | for n_hits, n_len in zip(self.hits, self.cand_lens): 197 | n_hits += self.smoothing # pre-set smoothing 198 | n_len += self.smoothing 199 | n_hits = max(n_hits, self.TINY) # forced smoothing just a litle to make BLEU defined 200 | n_len = max(n_len, self.SMALL) # only applied for zeros 201 | prec_log_sum += math.log(old_div(n_hits, n_len)) 202 | 203 | return math.exp((1.0 / self.max_ngram) * prec_log_sum) 204 | 205 | 206 | class NISTScore(NGramScore): 207 | """An accumulator object capable of computing NIST score using multiple references.""" 208 | 209 | # NIST beta parameter setting (copied from mteval-13a.pl) 210 | BETA = old_div(- math.log(0.5), math.log(1.5) ** 2) 211 | 212 | def __init__(self, max_ngram=5, case_sensitive=False): 213 | """Create the scoring object. 
214 | @param max_ngram: the n-gram level to compute the score for (default: 5) 215 | @param case_sensitive: use case-sensitive matching (default: no) 216 | """ 217 | super(NISTScore, self).__init__(max_ngram, case_sensitive) 218 | self.reset() 219 | 220 | def reset(self): 221 | """Reset the object, zero all counters.""" 222 | self.ref_ngrams = [defaultdict(int) for _ in range(self.max_ngram + 1)] # has 0-grams 223 | # these two don't have 0-grams 224 | self.hit_ngrams = [[] for _ in range(self.max_ngram)] 225 | self.cand_lens = [[] for _ in range(self.max_ngram)] 226 | self.avg_ref_len = 0.0 227 | 228 | def append(self, pred_sent, ref_sents): 229 | """Append a sentence for measurements, increase counters. 230 | 231 | @param pred_sent: the system output sentence (string/list of tokens) 232 | @param ref_sents: the corresponding reference sentences (list of strings/lists of tokens) 233 | """ 234 | pred_sent, ref_sents = self.check_tokenized(pred_sent, ref_sents) 235 | # collect ngram matches 236 | for n in range(self.max_ngram): 237 | self.cand_lens[n].append(len(pred_sent) - n) # keep track of output length 238 | merged_ref_ngrams = self.get_ngram_counts(n + 1, ref_sents) 239 | pred_ngrams = self.get_ngram_counts(n + 1, [pred_sent]) 240 | # collect ngram matches 241 | hit_ngrams = {} 242 | for ngram in pred_ngrams: 243 | hits = min(pred_ngrams[ngram], merged_ref_ngrams.get(ngram, 0)) 244 | if hits: 245 | hit_ngrams[ngram] = hits 246 | self.hit_ngrams[n].append(hit_ngrams) 247 | # collect total reference ngram counts 248 | for ref_sent in ref_sents: 249 | for ngram in self.ngrams(n + 1, ref_sent): 250 | self.ref_ngrams[n + 1][ngram] += 1 251 | # ref_ngrams: use 0-grams for information value as well 252 | ref_len_sum = sum(len(ref_sent) for ref_sent in ref_sents) 253 | self.ref_ngrams[0][()] += ref_len_sum 254 | # collect average reference length 255 | self.avg_ref_len += ref_len_sum / float(len(ref_sents)) 256 | 257 | def score(self): 258 | """Return the current NIST score, according to the accumulated counts.""" 259 | return self.nist() 260 | 261 | def info(self, ngram): 262 | """Return the NIST informativeness of an n-gram.""" 263 | if ngram not in self.ref_ngrams[len(ngram)]: 264 | return 0.0 265 | return math.log(self.ref_ngrams[len(ngram) - 1][ngram[:-1]] / 266 | float(self.ref_ngrams[len(ngram)][ngram]), 2) 267 | 268 | def nist_length_penalty(self, lsys, avg_lref): 269 | """Compute the NIST length penalty, based on system output length & average reference length. 
270 | @param lsys: total system output length 271 | @param avg_lref: total average reference length 272 | @return: NIST length penalty term 273 | """ 274 | ratio = lsys / float(avg_lref) 275 | if ratio >= 1: 276 | return 1 277 | if ratio <= 0: 278 | return 0 279 | return math.exp(-self.BETA * math.log(ratio) ** 2) 280 | 281 | def nist(self): 282 | """Return the current NIST score, according to the accumulated counts.""" 283 | # 1st NIST term 284 | hit_infos = [0.0 for _ in range(self.max_ngram)] 285 | for n in range(self.max_ngram): 286 | for hit_ngrams in self.hit_ngrams[n]: 287 | hit_infos[n] += sum(self.info(ngram) * hits for ngram, hits in hit_ngrams.items()) 288 | total_lens = [sum(self.cand_lens[n]) for n in range(self.max_ngram)] 289 | nist_sum = sum(old_div(hit_info, total_len) for hit_info, total_len in zip(hit_infos, total_lens)) 290 | # length penalty term 291 | bp = self.nist_length_penalty(sum(self.cand_lens[0]), self.avg_ref_len) 292 | return bp * nist_sum 293 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 
6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from __future__ import absolute_import 12 | from builtins import object 13 | from .bleu_scorer import BleuScorer 14 | 15 | 16 | class Bleu(object): 17 | def __init__(self, n=4): 18 | # default compute Blue score up to 4 19 | self._n = n 20 | self._hypo_for_image = {} 21 | self.ref_for_image = {} 22 | 23 | def compute_score(self, gts, res): 24 | 25 | assert(list(gts.keys()) == list(res.keys())) 26 | imgIds = list(gts.keys()) 27 | 28 | bleu_scorer = BleuScorer(n=self._n) 29 | for id in imgIds: 30 | hypo = res[id] 31 | ref = gts[id] 32 | 33 | # Sanity check. 34 | assert(type(hypo) is list) 35 | assert(len(hypo) == 1) 36 | assert(type(ref) is list) 37 | assert(len(ref) >= 1) 38 | 39 | bleu_scorer += (hypo[0], ref) 40 | 41 | #score, scores = bleu_scorer.compute_score(option='shortest') 42 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 43 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 44 | 45 | # return (bleu, bleu_info) 46 | return score, scores 47 | 48 | def method(self): 49 | return "Bleu" 50 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from builtins import zip 22 | from builtins import range 23 | from builtins import object 24 | from past.utils import old_div 25 | import copy 26 | import sys, math, re 27 | from collections import defaultdict 28 | 29 | def precook(s, n=4, out=False): 30 | """Takes a string as input and returns an object that can be given to 31 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 32 | can take string arguments as well.""" 33 | words = s.split() 34 | counts = defaultdict(int) 35 | for k in range(1,n+1): 36 | for i in range(len(words)-k+1): 37 | ngram = tuple(words[i:i+k]) 38 | counts[ngram] += 1 39 | return (len(words), counts) 40 | 41 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 42 | '''Takes a list of reference sentences for a single segment 43 | and returns an object that encapsulates everything that BLEU 44 | needs to know about them.''' 45 | 46 | reflen = [] 47 | maxcounts = {} 48 | for ref in refs: 49 | rl, counts = precook(ref, n) 50 | reflen.append(rl) 51 | for (ngram,count) in counts.items(): 52 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 53 | 54 | # Calculate effective reference sentence length. 55 | if eff == "shortest": 56 | reflen = min(reflen) 57 | elif eff == "average": 58 | reflen = float(sum(reflen))/len(reflen) 59 | 60 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 
61 | 62 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 63 | 64 | return (reflen, maxcounts) 65 | 66 | def cook_test(test, xxx_todo_changeme, eff=None, n=4): 67 | '''Takes a test sentence and returns an object that 68 | encapsulates everything that BLEU needs to know about it.''' 69 | (reflen, refmaxcounts) = xxx_todo_changeme 70 | testlen, counts = precook(test, n, True) 71 | 72 | result = {} 73 | 74 | # Calculate effective reference sentence length. 75 | 76 | if eff == "closest": 77 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 78 | else: ## i.e., "average" or "shortest" or None 79 | result["reflen"] = reflen 80 | 81 | result["testlen"] = testlen 82 | 83 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 84 | 85 | result['correct'] = [0]*n 86 | for (ngram, count) in counts.items(): 87 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 88 | 89 | return result 90 | 91 | class BleuScorer(object): 92 | """Bleu scorer. 93 | """ 94 | 95 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | '''return (bleu, len_ratio) pair''' 134 | return (self.fscore(option=option), self.ratio(option=option)) 135 | 136 | def score_ratio_str(self, option=None): 137 | return "%.4f (%.2f)" % self.score_ratio(option) 138 | 139 | def reflen(self, option=None): 140 | self.compute_score(option=option) 141 | return self._reflen 142 | 143 | def testlen(self, option=None): 144 | self.compute_score(option=option) 145 | return self._testlen 146 | 147 | def retest(self, new_test): 148 | if type(new_test) is str: 149 | new_test = [new_test] 150 | assert len(new_test) == len(self.crefs), new_test 151 | self.ctest = [] 152 | for t, rs in zip(new_test, self.crefs): 153 | self.ctest.append(cook_test(t, rs)) 154 | self._score = None 155 | 156 | return self 157 | 158 | def rescore(self, new_test): 159 | ''' replace test(s) with new test(s), and returns the new score.''' 160 | 161 | return self.retest(new_test).compute_score() 162 | 163 | def size(self): 164 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) 165 | return len(self.crefs) 166 | 167 | def __iadd__(self, other): 168 | '''add an instance (e.g., from another sentence).''' 169 | 170 | if type(other) is tuple: 171 | ## avoid creating new BleuScorer instances 172 | self.cook_append(other[0], other[1]) 173 | else: 174 | assert self.compatible(other), "incompatible BLEUs." 175 | self.ctest.extend(other.ctest) 176 | self.crefs.extend(other.crefs) 177 | self._score = None ## need to recompute 178 | 179 | return self 180 | 181 | def compatible(self, other): 182 | return isinstance(other, BleuScorer) and self.n == other.n 183 | 184 | def single_reflen(self, option="average"): 185 | return self._single_reflen(self.crefs[0][0], option) 186 | 187 | def _single_reflen(self, reflens, option=None, testlen=None): 188 | 189 | if option == "shortest": 190 | reflen = min(reflens) 191 | elif option == "average": 192 | reflen = float(sum(reflens))/len(reflens) 193 | elif option == "closest": 194 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 195 | else: 196 | assert False, "unsupported reflen option %s" % option 197 | 198 | return reflen 199 | 200 | def recompute_score(self, option=None, verbose=0): 201 | self._score = None 202 | return self.compute_score(option, verbose) 203 | 204 | def compute_score(self, option=None, verbose=0): 205 | n = self.n 206 | small = 1e-9 207 | tiny = 1e-15 ## so that if guess is 0 still return 0 208 | bleu_list = [[] for _ in range(n)] 209 | 210 | if self._score is not None: 211 | return self._score 212 | 213 | if option is None: 214 | option = "average" if len(self.crefs) == 1 else "closest" 215 | 216 | self._testlen = 0 217 | self._reflen = 0 218 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 219 | 220 | # for each sentence 221 | for comps in self.ctest: 222 | testlen = comps['testlen'] 223 | self._testlen += testlen 224 | 225 | if self.special_reflen is None: ## need computation 226 | reflen = self._single_reflen(comps['reflen'], option, testlen) 227 | else: 228 | reflen = self.special_reflen 229 | 230 | self._reflen += reflen 231 | 232 | for key in ['guess','correct']: 233 | for k in range(n): 234 | totalcomps[key][k] += comps[key][k] 235 | 236 | # append per image bleu score 237 | bleu = 1. 238 | for k in range(n): 239 | bleu *= old_div((float(comps['correct'][k]) + tiny),(float(comps['guess'][k]) + small)) 240 | bleu_list[k].append(bleu ** (1./(k+1))) 241 | ratio = old_div((testlen + tiny), (reflen + small)) ## N.B.: avoid zero division 242 | if ratio < 1: 243 | for k in range(n): 244 | bleu_list[k][-1] *= math.exp(1 - old_div(1,ratio)) 245 | 246 | if verbose > 1: 247 | print(comps, reflen) 248 | 249 | totalcomps['reflen'] = self._reflen 250 | totalcomps['testlen'] = self._testlen 251 | 252 | bleus = [] 253 | bleu = 1. 
254 | for k in range(n): 255 | bleu *= float(totalcomps['correct'][k] + tiny) \ 256 | / (totalcomps['guess'][k] + small) 257 | bleus.append(bleu ** (1./(k+1))) 258 | ratio = old_div((self._testlen + tiny), (self._reflen + small)) ## N.B.: avoid zero division 259 | if ratio < 1: 260 | for k in range(n): 261 | bleus[k] *= math.exp(1 - old_div(1,ratio)) 262 | 263 | if verbose > 0: 264 | print(totalcomps) 265 | print("ratio:", ratio) 266 | 267 | self._score = bleus 268 | return self._score, bleu_list 269 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | # Filename: cider.py 3 | # 4 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 5 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 6 | # 7 | # Creation Date: Sun Feb 8 14:16:54 2015 8 | # 9 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 10 | 11 | from builtins import object 12 | from .cider_scorer import CiderScorer 13 | import pdb 14 | 15 | class Cider(object): 16 | """ 17 | Main Class to compute the CIDEr metric 18 | 19 | """ 20 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 21 | # set cider to sum over 1 to 4-grams 22 | self._n = n 23 | # set the standard deviation parameter for gaussian penalty 24 | self._sigma = sigma 25 | 26 | def compute_score(self, gts, res): 27 | """ 28 | Main function to compute CIDEr score 29 | :param hypo_for_image (dict) : dictionary with key and value 30 | ref_for_image (dict) : dictionary with key and value 31 | :return: cider (float) : computed CIDEr score for the corpus 32 | """ 33 | 34 | assert(list(gts.keys()) == list(res.keys())) 35 | imgIds = list(gts.keys()) 36 | 37 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 38 | 39 | for id in imgIds: 40 | hypo = res[id] 41 | ref = gts[id] 42 | 43 | # Sanity check. 44 | assert(type(hypo) is list) 45 | assert(len(hypo) == 1) 46 | assert(type(ref) is list) 47 | assert(len(ref) > 0) 48 | 49 | cider_scorer += (hypo[0], ref) 50 | 51 | (score, scores) = cider_scorer.compute_score() 52 | 53 | return score, scores 54 | 55 | def method(self): 56 | return "CIDEr" -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | from __future__ import division 6 | from builtins import zip 7 | from builtins import range 8 | from builtins import object 9 | from past.utils import old_div 10 | import copy 11 | from collections import defaultdict 12 | import numpy as np 13 | import pdb 14 | import math 15 | 16 | def precook(s, n=4, out=False): 17 | """ 18 | Takes a string as input and returns an object that can be given to 19 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 20 | can take string arguments as well. 
21 | :param s: string : sentence to be converted into ngrams 22 | :param n: int : number of ngrams for which representation is calculated 23 | :return: term frequency vector for occuring ngrams 24 | """ 25 | words = s.split() 26 | counts = defaultdict(int) 27 | for k in range(1,n+1): 28 | for i in range(len(words)-k+1): 29 | ngram = tuple(words[i:i+k]) 30 | counts[ngram] += 1 31 | return counts 32 | 33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 34 | '''Takes a list of reference sentences for a single segment 35 | and returns an object that encapsulates everything that BLEU 36 | needs to know about them. 37 | :param refs: list of string : reference sentences for some image 38 | :param n: int : number of ngrams for which (ngram) representation is calculated 39 | :return: result (list of dict) 40 | ''' 41 | return [precook(ref, n) for ref in refs] 42 | 43 | def cook_test(test, n=4): 44 | '''Takes a test sentence and returns an object that 45 | encapsulates everything that BLEU needs to know about it. 46 | :param test: list of string : hypothesis sentence for some image 47 | :param n: int : number of ngrams for which (ngram) representation is calculated 48 | :return: result (dict) 49 | ''' 50 | return precook(test, n, True) 51 | 52 | class CiderScorer(object): 53 | """CIDEr scorer. 54 | """ 55 | 56 | def copy(self): 57 | ''' copy the refs.''' 58 | new = CiderScorer(n=self.n) 59 | new.ctest = copy.copy(self.ctest) 60 | new.crefs = copy.copy(self.crefs) 61 | return new 62 | 63 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 64 | ''' singular instance ''' 65 | self.n = n 66 | self.sigma = sigma 67 | self.crefs = [] 68 | self.ctest = [] 69 | self.document_frequency = defaultdict(float) 70 | self.cook_append(test, refs) 71 | self.ref_len = None 72 | 73 | def cook_append(self, test, refs): 74 | '''called by constructor and __iadd__ to avoid creating new instances.''' 75 | 76 | if refs is not None: 77 | self.crefs.append(cook_refs(refs)) 78 | if test is not None: 79 | self.ctest.append(cook_test(test)) ## N.B.: -1 80 | else: 81 | self.ctest.append(None) # lens of crefs and ctest have to match 82 | 83 | def size(self): 84 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 85 | return len(self.crefs) 86 | 87 | def __iadd__(self, other): 88 | '''add an instance (e.g., from another sentence).''' 89 | 90 | if type(other) is tuple: 91 | ## avoid creating new CiderScorer instances 92 | self.cook_append(other[0], other[1]) 93 | else: 94 | self.ctest.extend(other.ctest) 95 | self.crefs.extend(other.crefs) 96 | 97 | return self 98 | def compute_doc_freq(self): 99 | ''' 100 | Compute term frequency for reference data. 101 | This will be used to compute idf (inverse document frequency later) 102 | The term frequency is stored in the object 103 | :return: None 104 | ''' 105 | for refs in self.crefs: 106 | # refs, k ref captions of one image 107 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 108 | self.document_frequency[ngram] += 1 109 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 110 | 111 | def compute_cider(self): 112 | def counts2vec(cnts): 113 | """ 114 | Function maps counts of ngram to vector of tfidf weights. 115 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 116 | The n-th entry of array denotes length of n-grams. 
117 | :param cnts: 118 | :return: vec (array of dict), norm (array of float), length (int) 119 | """ 120 | vec = [defaultdict(float) for _ in range(self.n)] 121 | length = 0 122 | norm = [0.0 for _ in range(self.n)] 123 | for (ngram,term_freq) in cnts.items(): 124 | # give word count 1 if it doesn't appear in reference corpus 125 | df = np.log(max(1.0, self.document_frequency[ngram])) 126 | # ngram index 127 | n = len(ngram)-1 128 | # tf (term_freq) * idf (precomputed idf) for n-grams 129 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 130 | # compute norm for the vector. the norm will be used for computing similarity 131 | norm[n] += pow(vec[n][ngram], 2) 132 | 133 | if n == 1: 134 | length += term_freq 135 | norm = [np.sqrt(n) for n in norm] 136 | return vec, norm, length 137 | 138 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 139 | ''' 140 | Compute the cosine similarity of two vectors. 141 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 142 | :param vec_ref: array of dictionary for vector corresponding to reference 143 | :param norm_hyp: array of float for vector corresponding to hypothesis 144 | :param norm_ref: array of float for vector corresponding to reference 145 | :param length_hyp: int containing length of hypothesis 146 | :param length_ref: int containing length of reference 147 | :return: array of score for each n-grams cosine similarity 148 | ''' 149 | delta = float(length_hyp - length_ref) 150 | # measure consine similarity 151 | val = np.array([0.0 for _ in range(self.n)]) 152 | for n in range(self.n): 153 | # ngram 154 | for (ngram,count) in vec_hyp[n].items(): 155 | # vrama91 : added clipping 156 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 157 | 158 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 159 | val[n] /= (norm_hyp[n]*norm_ref[n]) 160 | 161 | assert(not math.isnan(val[n])) 162 | # vrama91: added a length based gaussian penalty 163 | val[n] *= np.e**(old_div(-(delta**2),(2*self.sigma**2))) 164 | return val 165 | 166 | # compute log reference length 167 | self.ref_len = np.log(float(len(self.crefs))) 168 | 169 | scores = [] 170 | for test, refs in zip(self.ctest, self.crefs): 171 | # compute vector for test captions 172 | vec, norm, length = counts2vec(test) 173 | # compute vector for ref captions 174 | score = np.array([0.0 for _ in range(self.n)]) 175 | for ref in refs: 176 | vec_ref, norm_ref, length_ref = counts2vec(ref) 177 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 178 | # change by vrama91 - mean of ngram scores, instead of sum 179 | score_avg = np.mean(score) 180 | # divide by number of references 181 | score_avg /= len(refs) 182 | # multiply score by 10 183 | score_avg *= 10.0 184 | # append score of an image to the score list 185 | scores.append(score_avg) 186 | return scores 187 | 188 | def compute_score(self, option=None, verbose=0): 189 | # compute idf 190 | self.compute_doc_freq() 191 | # assert to check document frequency 192 | assert(len(self.ctest) >= max(self.document_frequency.values())) 193 | # compute cider score 194 | score = self.compute_cider() 195 | # debug 196 | # print score 197 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from builtins 
import zip 4 | from builtins import object 5 | __author__ = 'tylin' 6 | from .tokenizer.ptbtokenizer import PTBTokenizer 7 | from .bleu.bleu import Bleu 8 | from .meteor.meteor import Meteor 9 | from .rouge.rouge import Rouge 10 | from .cider.cider import Cider 11 | import sys 12 | 13 | class COCOEvalCap(object): 14 | def __init__(self, coco, cocoRes): 15 | self.evalImgs = [] 16 | self.eval = {} 17 | self.imgToEval = {} 18 | self.coco = coco 19 | self.cocoRes = cocoRes 20 | self.params = {'image_id': coco.getImgIds()} 21 | 22 | def evaluate(self): 23 | imgIds = self.params['image_id'] 24 | # imgIds = self.coco.getImgIds() 25 | gts = {} 26 | res = {} 27 | for imgId in imgIds: 28 | gts[imgId] = self.coco.imgToAnns[imgId] 29 | res[imgId] = self.cocoRes.imgToAnns[imgId] 30 | 31 | # ================================================= 32 | # Set up scorers 33 | # ================================================= 34 | print('tokenization...', file=sys.stderr) 35 | tokenizer = PTBTokenizer() 36 | gts = tokenizer.tokenize(gts) 37 | res = tokenizer.tokenize(res) 38 | 39 | # ================================================= 40 | # Set up scorers 41 | # ================================================= 42 | print('setting up scorers...', file=sys.stderr) 43 | scorers = [ 44 | (Meteor(),"METEOR"), 45 | (Rouge(), "ROUGE_L"), 46 | (Cider(), "CIDEr") 47 | ] 48 | 49 | # ================================================= 50 | # Compute scores 51 | # ================================================= 52 | for scorer, method in scorers: 53 | print('computing %s score...'%(scorer.method()), file=sys.stderr) 54 | score, scores = scorer.compute_score(gts, res) 55 | if type(method) == list: 56 | for sc, scs, m in zip(score, scores, method): 57 | self.setEval(sc, m) 58 | self.setImgToEvalImgs(scs, list(gts.keys()), m) 59 | print("%s: %0.3f"%(m, sc), file=sys.stderr) 60 | else: 61 | self.setEval(score, method) 62 | self.setImgToEvalImgs(scores, list(gts.keys()), method) 63 | print("%s: %0.3f"%(method, score), file=sys.stderr) 64 | self.setEvalImgs() 65 | 66 | def setEval(self, score, method): 67 | self.eval[method] = score 68 | 69 | def setImgToEvalImgs(self, scores, imgIds, method): 70 | for imgId, score in zip(imgIds, scores): 71 | if not imgId in self.imgToEval: 72 | self.imgToEval[imgId] = {} 73 | self.imgToEval[imgId]["image_id"] = imgId 74 | self.imgToEval[imgId][method] = score 75 | 76 | def setEvalImgs(self): 77 | self.evalImgs = [eval for imgId, eval in list(self.imgToEval.items())] 78 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/eval_kit/pycocoevalcap/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/eval_kit/pycocoevalcap/meteor/meteor-1.5.jar 
-------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | from builtins import range 7 | from builtins import object 8 | import os 9 | import sys 10 | import subprocess 11 | import threading 12 | 13 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 14 | METEOR_JAR = 'meteor-1.5.jar' 15 | # print METEOR_JAR 16 | 17 | class Meteor(object): 18 | 19 | def __init__(self): 20 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 21 | '-', '-', '-stdio', '-l', 'en', '-norm'] 22 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 23 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 24 | stdin=subprocess.PIPE, \ 25 | stdout=subprocess.PIPE, \ 26 | stderr=subprocess.PIPE) 27 | # Used to guarantee thread safety 28 | self.lock = threading.Lock() 29 | 30 | def compute_score(self, gts, res): 31 | assert(list(gts.keys()) == list(res.keys())) 32 | imgIds = list(gts.keys()) 33 | scores = [] 34 | 35 | eval_line = 'EVAL' 36 | self.lock.acquire() 37 | for i in imgIds: 38 | assert(len(res[i]) == 1) 39 | stat = self._stat(res[i][0], gts[i]) 40 | eval_line += ' ||| {}'.format(stat) 41 | 42 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode('UTF-8')) 43 | self.meteor_p.stdin.flush() 44 | for i in range(0,len(imgIds)): 45 | scores.append(float(self.meteor_p.stdout.readline().decode('UTF-8').strip())) 46 | score = float(self.meteor_p.stdout.readline().strip()) 47 | self.lock.release() 48 | 49 | return score, scores 50 | 51 | def method(self): 52 | return "METEOR" 53 | 54 | def _stat(self, hypothesis_str, reference_list): 55 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 56 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 57 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 58 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode('UTF-8')) 59 | self.meteor_p.stdin.flush() 60 | res = self.meteor_p.stdout.readline().decode('UTF-8').strip() 61 | return res 62 | 63 | def _score(self, hypothesis_str, reference_list): 64 | self.lock.acquire() 65 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 66 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 67 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 68 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode('UTF-8')) 69 | self.meteor_p.stdin.flush() 70 | stats = self.meteor_p.stdout.readline().strip() 71 | eval_line = 'EVAL ||| {}'.format(stats) 72 | # EVAL ||| stats 73 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode('UTF-8')) 74 | self.meteor_p.stdin.flush() 75 | score = float(self.meteor_p.stdout.readline().decode('UTF-8').strip()) 76 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 77 | # thanks for Andrej for pointing this out 78 | score = float(self.meteor_p.stdout.readline().decode('UTF-8').strip()) 79 | self.lock.release() 80 | return score 81 | 82 | def __del__(self): 83 | self.lock.acquire() 84 | self.meteor_p.stdin.close() 85 | self.meteor_p.kill() 86 | self.meteor_p.wait() 87 | self.lock.release() 88 | 
-------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | from builtins import range 11 | from builtins import object 12 | import numpy as np 13 | import pdb 14 | 15 | def my_lcs(string, sub): 16 | """ 17 | Calculates longest common subsequence for a pair of tokenized strings 18 | :param string : list of str : tokens from a string split using whitespace 19 | :param sub : list of str : shorter string, also split using whitespace 20 | :returns: length (list of int): length of the longest common subsequence between the two strings 21 | 22 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 23 | """ 24 | if(len(string)< len(sub)): 25 | sub, string = string, sub 26 | 27 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 28 | 29 | for j in range(1,len(sub)+1): 30 | for i in range(1,len(string)+1): 31 | if(string[i-1] == sub[j-1]): 32 | lengths[i][j] = lengths[i-1][j-1] + 1 33 | else: 34 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 35 | 36 | return lengths[len(string)][len(sub)] 37 | 38 | class Rouge(object): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | def __init__(self): 44 | # vrama91: updated the value below based on discussion with Hovey 45 | self.beta = 1.2 46 | 47 | def calc_score(self, candidate, refs): 48 | """ 49 | Compute ROUGE-L score given one candidate and references for an image 50 | :param candidate: str : candidate sentence to be evaluated 51 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 52 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 53 | """ 54 | assert(len(candidate)==1) 55 | assert(len(refs)>0) 56 | prec = [] 57 | rec = [] 58 | 59 | # split into tokens 60 | token_c = candidate[0].split(" ") 61 | 62 | for reference in refs: 63 | # split into tokens 64 | token_r = reference.split(" ") 65 | # compute the longest common subsequence 66 | lcs = my_lcs(token_r, token_c) 67 | prec.append(lcs/float(len(token_c))) 68 | rec.append(lcs/float(len(token_r))) 69 | 70 | prec_max = max(prec) 71 | rec_max = max(rec) 72 | 73 | if(prec_max!=0 and rec_max !=0): 74 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 75 | else: 76 | score = 0.0 77 | return score 78 | 79 | def compute_score(self, gts, res): 80 | """ 81 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 82 | Invoked by evaluate_captions.py 83 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 84 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 85 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 86 | """ 87 | 
assert(list(gts.keys()) == list(res.keys())) 88 | imgIds = list(gts.keys()) 89 | 90 | score = [] 91 | for id in imgIds: 92 | hypo = res[id] 93 | ref = gts[id] 94 | 95 | score.append(self.calc_score(hypo, ref)) 96 | 97 | # Sanity check. 98 | assert(type(hypo) is list) 99 | assert(len(hypo) == 1) 100 | assert(type(ref) is list) 101 | assert(len(ref) > 0) 102 | 103 | average_score = np.mean(np.array(score)) 104 | return average_score, np.array(score) 105 | 106 | def method(self): 107 | return "Rouge" 108 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from builtins import zip 12 | from builtins import range 13 | from builtins import object 14 | import os 15 | import sys 16 | import subprocess 17 | import tempfile 18 | import itertools 19 | 20 | # path to the stanford corenlp jar 21 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 22 | 23 | # punctuations to be removed from the sentences 24 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 25 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 26 | 27 | class PTBTokenizer(object): 28 | """Python wrapper of Stanford PTBTokenizer""" 29 | 30 | def tokenize(self, captions_for_image): 31 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 32 | 'edu.stanford.nlp.process.PTBTokenizer', \ 33 | '-preserveLines', '-lowerCase'] 34 | 35 | # ====================================================== 36 | # prepare data for PTB Tokenizer 37 | # ====================================================== 38 | final_tokenized_captions_for_image = {} 39 | image_id = [k for k, v in list(captions_for_image.items()) for _ in range(len(v))] 40 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in list(captions_for_image.items()) for c in v]) 41 | 42 | # ====================================================== 43 | # save sentences to temporary file 44 | # ====================================================== 45 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 46 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 47 | tmp_file.write(sentences.encode('UTF-8')) 48 | tmp_file.close() 49 | 50 | # ====================================================== 51 | # tokenize sentence 52 | # ====================================================== 53 | cmd.append(os.path.basename(tmp_file.name)) 54 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 55 | stdout=subprocess.PIPE, encoding='UTF-8') 56 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 57 | lines = token_lines.split('\n') 58 | # remove temp file 59 | os.remove(tmp_file.name) 60 | 61 | # ====================================================== 62 | # create dictionary for tokenized captions 63 | # ====================================================== 64 | for k, line in zip(image_id, lines): 65 | if not k in 
final_tokenized_captions_for_image: 66 | final_tokenized_captions_for_image[k] = [] 67 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 68 | if w not in PUNCTUATIONS]) 69 | final_tokenized_captions_for_image[k].append(tokenized_caption) 70 | 71 | return final_tokenized_captions_for_image 72 | -------------------------------------------------------------------------------- /LCG/eval_kit/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/eval_kit/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /LCG/eval_kit/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plan-And-Write Style Autoregressive Model Baseline for Lexically-Constrained Text Generation 3 | """ 4 | import torch 5 | import pickle 6 | import torch.utils.data as datautils 7 | from IPython import embed 8 | import argparse 9 | import os 10 | import numpy as np 11 | import tqdm 12 | from data.dataloader import NaiveTokenizer, RandomKeywordSequentialDataset, GivenKeywordSequentialDataset, DomainAdaptationSequentialDataset 13 | from transformers.optimization import AdamW 14 | from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer 15 | from models.decoder.modeling_cdgpt import CDGPT2LMHeadModel 16 | from data.dataloader import LexicalCheckingDataset, ExactFormLexicalCheckingDataset 17 | 18 | 19 | def str2bool(v): 20 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 21 | return True 22 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "--continue_training", 31 | default=False, 32 | type=str2bool, 33 | required=False, 34 | help="Continue the training or start from scratch.", 35 | ) 36 | parser.add_argument( 37 | "--dataset", 38 | default="yelp_review", 39 | type=str, 40 | required=False, 41 | help="Target dataset. Default: wmt16roen." 
42 | ) 43 | parser.add_argument( 44 | "--eval_mode", 45 | default=False, 46 | type=bool, 47 | required=False, 48 | help="load the latest checkpoint and run the eval pipeline.", 49 | ) 50 | parser.add_argument( 51 | "--batch_size", 52 | default=256, 53 | type=int, 54 | required=False, 55 | help="effective batchsize", 56 | ) 57 | parser.add_argument( 58 | "--iter_per", 59 | default=8, 60 | type=int, 61 | required=False, 62 | help="cumulative gradient iteration cycle", 63 | ) 64 | parser.add_argument( 65 | "--bellman_reg", 66 | default="1.00", 67 | type=str, 68 | required=False, 69 | help="strength of the bellman regularization term", 70 | ) 71 | 72 | 73 | args = parser.parse_args() 74 | dataset_name = args.dataset 75 | keyword_num = 7 if dataset_name == "yelp_review" else 4 76 | eval_mode = args.eval_mode 77 | batch_size = args.batch_size 78 | iter_per = args.iter_per 79 | directory_identifier_raw = "basemodel_%s_%s" % (dataset_name, "wlex") 80 | directory_identifier = "baseline_randkey_%s" % dataset_name 81 | continue_training = args.continue_training 82 | if not os.path.exists("./checkpoints"): 83 | os.mkdir("checkpoints") 84 | if not os.path.exists("./checkpoints/%s" % directory_identifier): 85 | os.mkdir("checkpoints/%s" % directory_identifier) 86 | try: 87 | tokenizer, dataset, val_dataset = pickle.load( 88 | open("checkpoints/baselinedataset-%s.pyc" % (dataset_name), "rb")) 89 | except: 90 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") 91 | tokenizer.bos_token = "$" 92 | tokenizer.sep_token = "#" 93 | if dataset_name[0:11] == "yelp_review" or dataset_name[0:4] == "news": 94 | dataset = RandomKeywordSequentialDataset(tokenizer=tokenizer, max_len=384, keyword_num=keyword_num) 95 | dataset.add(dataset_name) 96 | else: 97 | import datasets 98 | dataset_raw = datasets.load_dataset(dataset_name, split="train") 99 | dataset = GivenKeywordSequentialDataset(tokenizer=tokenizer, max_len=384) 100 | dataset.add(dataset_raw, field_keywords="concepts", field_sequence="target") 101 | val_dataset = None 102 | pickle.dump((tokenizer, dataset, val_dataset), open("checkpoints/baselinedataset-%s.pyc" % (dataset_name), "wb")) 103 | 104 | dataset.__getitem__(0) 105 | base_model = GPT2LMHeadModel.from_pretrained("gpt2-large") 106 | base_model.eval() 107 | base_model.cuda() 108 | config = GPT2Config.from_pretrained("gpt2", # much smaller model 109 | n_layer=4, 110 | ) 111 | generator = CDGPT2LMHeadModel(config=config, base_model=base_model) 112 | generator.cuda() 113 | generator.train() 114 | base_model.load_state_dict(torch.load("checkpoints/%s/pretrained" % directory_identifier_raw)) 115 | # generator.load_state_dict(torch.load("checkpoints/%s/model" % directory_identifier_raw)) 116 | 117 | dataloader = datautils.DataLoader( 118 | dataset, batch_size=batch_size // iter_per, shuffle=True, drop_last=False, pin_memory=True, 119 | num_workers=8 120 | ) 121 | 122 | dataset.produce_keys("./data/%s/%s-keys-generated.txt" % (dataset_name, dataset_name)) 123 | opt = AdamW(lr=2e-5, weight_decay=0.02, 124 | eps=1e-8, params=generator.parameters()) 125 | 126 | if dataset_name in ["yelp_review", "news"]: 127 | generator.load_state_dict(torch.load("checkpoints/%s/warmup" % directory_identifier)) 128 | 129 | if continue_training: 130 | generator.load_state_dict(torch.load("checkpoints/%s/model-finetune" % directory_identifier)) 131 | opt_ = torch.load("checkpoints/%s/opt" % directory_identifier) 132 | opt.load_state_dict(opt_) 133 | epoch_idx, iter_count = torch.load("checkpoints/%s/EpochIdx" % 
directory_identifier) 134 | else: 135 | epoch_idx, iter_count = 0, 0 136 | if type(dataloader) is list: 137 | for curriculum_step_i in range(len(dataloader)): 138 | flog_train = open("checkpoints/%s/log-%d.txt" % (directory_identifier, curriculum_step_i), "w") 139 | flog_train.close() 140 | flog_eval = open("checkpoints/%s/log-eval.txt" % (directory_identifier), "w") 141 | flog_eval.close() 142 | 143 | if eval_mode: 144 | selection_epoch = 2 145 | print(selection_epoch) 146 | generator.load_state_dict(torch.load("checkpoints/%s/model-finetune-%d-lambda-%s" % (directory_identifier, selection_epoch, args.bellman_reg))) 147 | generator.eval() 148 | if type(tokenizer) is NaiveTokenizer: 149 | tokenizer.close_vocab() 150 | fout = open("./data/%s/%s-generated-finetune.txt" % (dataset_name, dataset_name), "w") 151 | with open("./data/%s/%s-keys.txt" % (dataset_name, dataset_name), "r") as fin: 152 | for line in tqdm.tqdm(fin.readlines()): 153 | input_ids = torch.tensor([tokenizer.bos_token_id] + tokenizer.encode(line.strip()) + [tokenizer.sep_token_id]) 154 | generated = generator.generate(input_ids=input_ids.cuda().unsqueeze(dim=0), max_length=300, 155 | num_beams=20, pad_token_id=tokenizer.eos_token_id) 156 | generated_str = tokenizer.decode(generated[0][len(input_ids):-1]) 157 | print(generated_str, file=fout) 158 | fout.close() 159 | exit() 160 | bsz = args.batch_size // iter_per 161 | 162 | 163 | if dataset_name in {"common_gen"}: 164 | try: 165 | virtual_dataset = pickle.load( 166 | open("./data/%s/%s-sampled" % (dataset_name, directory_identifier_raw), "rb")) 167 | # raise NotImplementedError() 168 | except: 169 | dataset.produce_keys("./data/%s/%s-keys-generated.txt" % (dataset_name, dataset_name)) 170 | fin = open("./data/%s/%s-keys-generated.txt" % (dataset_name, dataset_name), "r") 171 | virtual_dataset = LexicalCheckingDataset(tokenizer, expansion_num=48) 172 | virtual_dataset.add(fin, generator) 173 | fin.close() 174 | 175 | pickle.dump(virtual_dataset, open("./data/%s/%s-sampled" % (dataset_name, directory_identifier_raw), "wb")) 176 | exit() 177 | virtual_dataloader = datautils.DataLoader( 178 | virtual_dataset, batch_size=batch_size // iter_per, shuffle=True, drop_last=False, pin_memory=True, 179 | num_workers=8 180 | ) 181 | stren_reg = args.bellman_reg 182 | for epoch_idx in range(10): 183 | iterator = tqdm.tqdm(virtual_dataloader) 184 | NLL_REDUCED = [] 185 | DISCREPANCY_REDUCED = [] 186 | for generated, mask, label, length in iterator: 187 | max_len = length.max().item() + 1 188 | if max_len > 300: 189 | 190 | continue 191 | generated = generated.cuda()[:, 0:max_len] 192 | mask = mask.cuda()[:, 0:max_len-1] 193 | label = label.cuda() 194 | neufact_output = generator(generated, labels=torch.ones_like(generated).to(torch.float) * label.reshape(-1, 1)) 195 | logits = neufact_output.loss 196 | reg_loss = neufact_output.reg_loss 197 | 198 | nll_reduced = -(logits * mask).sum(dim=-1) 199 | reg_loss_reduced = (reg_loss[:, :-1] * mask[:, 1:]).sum(dim=-1) 200 | NLL_REDUCED.append(nll_reduced.mean(dim=0).cpu().item()) 201 | DISCREPANCY_REDUCED.append(reg_loss_reduced.mean(dim=0).cpu().item()) 202 | if iter_count % iter_per == 0: 203 | opt.zero_grad() 204 | loss = nll_reduced + float(stren_reg) * reg_loss_reduced 205 | (loss / iter_per).mean(dim=0).backward() 206 | if iter_count % iter_per == iter_per - 1: 207 | opt.step() 208 | if (iter_count // iter_per) % 10 == 0: 209 | iterator.write("Iteration %d-%d, Loss %f, Reg Loss %f" % ( 210 | epoch_idx, iter_count // iter_per, NLL_REDUCED[-1], 
DISCREPANCY_REDUCED[-1])) 211 | # embed(); exit() 212 | iter_count += 1 213 | print("Now avg. nll_red=", np.mean(NLL_REDUCED)) 214 | print("Now avg. reg_red=", np.mean(DISCREPANCY_REDUCED)) 215 | epoch_idx += 1 216 | torch.save(generator.state_dict(), "checkpoints/%s/model-finetune-%d-lambda-%s" % (directory_identifier, epoch_idx, stren_reg)) 217 | torch.save(generator.state_dict(), 218 | "checkpoints/%s/model-finetune" % (directory_identifier)) 219 | torch.save(opt.state_dict(), "checkpoints/%s/opt" % directory_identifier) 220 | torch.save((epoch_idx, iter_count), "checkpoints/%s/EpochIdx" % directory_identifier) 221 | elif dataset_name in {"yelp_review", "news"}: 222 | try: 223 | virtual_dataset = pickle.load( 224 | open("./data/%s/%s-sampled" % (dataset_name, directory_identifier_raw), "rb")) 225 | # raise NotImplementedError() 226 | except: 227 | # from torch.multiprocessing import Pool, Process, set_start_method 228 | # try: 229 | # set_start_method('spawn') 230 | # except RuntimeError: 231 | # pass 232 | dataset.produce_keys("./data/%s/%s-keys-generated.txt" % (dataset_name, dataset_name)) 233 | with open("./data/%s/%s-keys-generated.txt" % (dataset_name, dataset_name), "r") as fin: 234 | virtual_dataset = ExactFormLexicalCheckingDataset(tokenizer, expansion_num=32) 235 | virtual_dataset.add(fin, generator) 236 | 237 | pickle.dump(virtual_dataset, open("./data/%s/%s-sampled" % (dataset_name, directory_identifier_raw), "wb")) 238 | exit() 239 | virtual_dataloader = datautils.DataLoader( 240 | virtual_dataset, batch_size=batch_size // iter_per, shuffle=True, drop_last=False, pin_memory=True, 241 | num_workers=8 242 | ) 243 | 244 | if __name__ == "__main__": 245 | main() -------------------------------------------------------------------------------- /LCG/generate_testset.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import tqdm 3 | # a_dev = datasets.load_dataset("common_gen", split="validation") 4 | a_dev = datasets.load_dataset("common_gen", split="test") 5 | test_base = dict() 6 | for instance in tqdm.tqdm(a_dev): 7 | keyword_seq = " ".join(instance["concepts"]) 8 | ref_seq = instance["target"] 9 | if keyword_seq in test_base: 10 | test_base[keyword_seq].append(ref_seq) 11 | else: 12 | test_base[keyword_seq] = [ref_seq] 13 | 14 | with open("./data/common_gen/common_gen-keys.txt", "w") as fkeys: 15 | with open("./data/common_gen/common_gen-test.txt", "w") as ftest: 16 | for key in tqdm.tqdm(test_base): 17 | print(key, file=fkeys) 18 | for ref in test_base[key]: 19 | print(ref, file=ftest) 20 | print(file=ftest) 21 | -------------------------------------------------------------------------------- /LCG/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/metrics/__init__.py -------------------------------------------------------------------------------- /LCG/metrics/pymteval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | BLEU & NIST measurements -- should be compatible with mteval-v13a.pl (basic tokenization). 6 | Also provides BLEU +1 smoothing (if set to work like that). 7 | 8 | TODO: International tokenization 9 | TODO: NIST with variable number of references is not the same as the edited mteval-v13a.pl, 10 | but this should be the proper way to compute it. 
Should be fixed there. 11 | """ 12 | 13 | from __future__ import unicode_literals 14 | from __future__ import division 15 | from builtins import zip 16 | from builtins import range 17 | from past.utils import old_div 18 | from builtins import object 19 | from collections import defaultdict 20 | import math 21 | import re 22 | 23 | 24 | class NGramScore(object): 25 | """Base class for BLEU & NIST, providing tokenization and some basic n-gram matching 26 | functions.""" 27 | 28 | def __init__(self, max_ngram, case_sensitive): 29 | """Create the scoring object. 30 | @param max_ngram: the n-gram level to compute the score for 31 | @param case_sensitive: use case-sensitive matching? 32 | """ 33 | self.max_ngram = max_ngram 34 | self.case_sensitive = case_sensitive 35 | 36 | def reset(self): 37 | """Reset the object, zero all counters.""" 38 | raise NotImplementedError() 39 | 40 | def append(self, pred_sent, ref_sents): 41 | """Add a sentence to the statistics. 42 | @param pred_sent: system output / predicted sentence 43 | @param ref_sents: reference sentences 44 | """ 45 | raise NotImplementedError() 46 | 47 | def score(self): 48 | """Compute the current score based on sentences added so far.""" 49 | raise NotImplementedError() 50 | 51 | def ngrams(self, n, sent): 52 | """Given a sentence, return n-grams of nodes for the given N. Lowercases 53 | everything if the measure should not be case-sensitive. 54 | 55 | @param n: n-gram 'N' (1 for unigrams, 2 for bigrams etc.) 56 | @param sent: the sent in question 57 | @return: n-grams of nodes, as tuples of tuples (t-lemma & formeme) 58 | """ 59 | if not self.case_sensitive: 60 | return list(zip(*[[tok.lower() for tok in sent[i:]] for i in range(n)])) 61 | return list(zip(*[sent[i:] for i in range(n)])) 62 | 63 | def check_tokenized(self, pred_sent, ref_sents): 64 | """Tokenize the predicted sentence and reference sentences, if they are not tokenized. 65 | @param pred_sent: system output / predicted sentence 66 | @param ref_sent: a list of corresponding reference sentences 67 | @return: a tuple of (pred_sent, ref_sent) where everything is tokenized 68 | """ 69 | # tokenize if needed 70 | pred_sent = pred_sent if isinstance(pred_sent, list) else self.tokenize(pred_sent) 71 | ref_sents = [ref_sent if isinstance(ref_sent, list) else self.tokenize(ref_sent) 72 | for ref_sent in ref_sents] 73 | return pred_sent, ref_sents 74 | 75 | def get_ngram_counts(self, n, sents): 76 | """Returns a dictionary with counts of all n-grams in the given sentences. 77 | @param n: the "n" in n-grams (how long the n-grams should be) 78 | @param sents: list of sentences for n-gram counting 79 | @return: a dictionary (ngram: count) listing counts of n-grams attested in any of the sentences 80 | """ 81 | merged_ngrams = {} 82 | 83 | for sent in sents: 84 | ngrams = defaultdict(int) 85 | 86 | for ngram in self.ngrams(n, sent): 87 | ngrams[ngram] += 1 88 | for ngram, cnt in ngrams.items(): 89 | merged_ngrams[ngram] = max((merged_ngrams.get(ngram, 0), cnt)) 90 | return merged_ngrams 91 | 92 | def tokenize(self, sent): 93 | """This tries to mimic multi-bleu-detok from Moses, and by extension mteval-v13b. 
94 | Code taken directly from there and attempted rewrite into Python.""" 95 | # language-independent part: 96 | sent = re.sub(r'<skipped>', r'', sent) # strip "skipped" tags 97 | sent = re.sub(r'-\n', r'', sent) # strip end-of-line hyphenation and join lines 98 | sent = re.sub(r'\n', r' ', sent) # join lines 99 | sent = re.sub(r'&quot;', r'"', sent) # convert SGML tag for quote to " 100 | sent = re.sub(r'&amp;', r'&', sent) # convert SGML tag for ampersand to & 101 | sent = re.sub(r'&lt;', r'<', sent) # convert SGML tag for less-than to < 102 | sent = re.sub(r'&gt;', r'>', sent) # convert SGML tag for greater-than to > 103 | 104 | # language-dependent part (assuming Western languages): 105 | sent = " " + sent + " " # pad with spaces 106 | sent = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sent) # tokenize punctuation 107 | sent = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sent) # tokenize period and comma unless preceded by a digit 108 | sent = re.sub(r'([\.,])([^0-9])', r' \1 \2', sent) # tokenize period and comma unless followed by a digit 109 | sent = re.sub(r'([0-9])(-)', r'\1 \2 ', sent) # tokenize dash when preceded by a digit 110 | sent = re.sub(r'\s+', r' ', sent) # one space only between words 111 | sent = sent.strip() # remove padding 112 | 113 | return sent.split(' ') 114 | 115 | 116 | class BLEUScore(NGramScore): 117 | """An accumulator object capable of computing BLEU score using multiple references. 118 | 119 | The BLEU score is always smoothed a bit so that it's never undefined. For sentence-level 120 | measurements, proper smoothing should be used via the smoothing parameter (set to 1.0 for 121 | the same behavior as default Moses's MERT sentence BLEU). 122 | """ 123 | 124 | TINY = 1e-15 125 | SMALL = 1e-9 126 | 127 | def __init__(self, max_ngram=4, case_sensitive=False, smoothing=0.0): 128 | """Create the scoring object. 129 | @param max_ngram: the n-gram level to compute the score for (default: 4) 130 | @param case_sensitive: use case-sensitive matching (default: no) 131 | @param smoothing: constant to add for smoothing (defaults to 0.0, sentBLEU uses 1.0) 132 | """ 133 | super(BLEUScore, self).__init__(max_ngram, case_sensitive) 134 | self.smoothing = smoothing 135 | self.reset() 136 | 137 | def reset(self): 138 | """Reset the object, zero all counters.""" 139 | self.ref_len = 0 140 | self.cand_lens = [0] * self.max_ngram 141 | self.hits = [0] * self.max_ngram 142 | 143 | def append(self, pred_sent, ref_sents): 144 | """Append a sentence for measurements, increase counters. 
145 | 146 | @param pred_sent: the system output sentence (string/list of tokens) 147 | @param ref_sents: the corresponding reference sentences (list of strings/lists of tokens) 148 | """ 149 | pred_sent, ref_sents = self.check_tokenized(pred_sent, ref_sents) 150 | 151 | # compute n-gram matches 152 | for i in range(self.max_ngram): 153 | self.hits[i] += self.compute_hits(i + 1, pred_sent, ref_sents) 154 | self.cand_lens[i] += len(pred_sent) - i 155 | 156 | # take the reference that is closest in length to the candidate 157 | # (if there are two of the same distance, take the shorter one) 158 | closest_ref = min(ref_sents, key=lambda ref_sent: (abs(len(ref_sent) - len(pred_sent)), len(ref_sent))) 159 | self.ref_len += len(closest_ref) 160 | 161 | def score(self): 162 | """Return the current BLEU score, according to the accumulated counts.""" 163 | return self.bleu() 164 | 165 | def compute_hits(self, n, pred_sent, ref_sents): 166 | """Compute clipped n-gram hits for the given sentences and the given N 167 | 168 | @param n: n-gram 'N' (1 for unigrams, 2 for bigrams etc.) 169 | @param pred_sent: the system output sentence (tree/tokens) 170 | @param ref_sents: the corresponding reference sentences (list/tuple of trees/tokens) 171 | """ 172 | merged_ref_ngrams = self.get_ngram_counts(n, ref_sents) 173 | pred_ngrams = self.get_ngram_counts(n, [pred_sent]) 174 | 175 | hits = 0 176 | for ngram, cnt in pred_ngrams.items(): 177 | hits += min(merged_ref_ngrams.get(ngram, 0), cnt) 178 | 179 | return hits 180 | 181 | def bleu(self): 182 | """Return the current BLEU score, according to the accumulated counts.""" 183 | # brevity penalty (smoothed a bit: if candidate length is 0, we change it to 1e-5 184 | # to avoid division by zero) 185 | bp = 1.0 186 | if (self.cand_lens[0] <= self.ref_len): 187 | bp = math.exp(1.0 - old_div(self.ref_len, 188 | (float(self.cand_lens[0]) if self.cand_lens[0] else 1e-5))) 189 | 190 | return bp * self.ngram_precision() 191 | 192 | def ngram_precision(self): 193 | """Return the current n-gram precision (harmonic mean of n-gram precisions up to max_ngram) 194 | according to the accumulated counts.""" 195 | prec_log_sum = 0.0 196 | for n_hits, n_len in zip(self.hits, self.cand_lens): 197 | n_hits += self.smoothing # pre-set smoothing 198 | n_len += self.smoothing 199 | n_hits = max(n_hits, self.TINY) # forced smoothing just a litle to make BLEU defined 200 | n_len = max(n_len, self.SMALL) # only applied for zeros 201 | prec_log_sum += math.log(old_div(n_hits, n_len)) 202 | 203 | return math.exp((1.0 / self.max_ngram) * prec_log_sum) 204 | 205 | 206 | class NISTScore(NGramScore): 207 | """An accumulator object capable of computing NIST score using multiple references.""" 208 | 209 | # NIST beta parameter setting (copied from mteval-13a.pl) 210 | BETA = old_div(- math.log(0.5), math.log(1.5) ** 2) 211 | 212 | def __init__(self, max_ngram=5, case_sensitive=False): 213 | """Create the scoring object. 
214 | @param max_ngram: the n-gram level to compute the score for (default: 5) 215 | @param case_sensitive: use case-sensitive matching (default: no) 216 | """ 217 | super(NISTScore, self).__init__(max_ngram, case_sensitive) 218 | self.reset() 219 | 220 | def reset(self): 221 | """Reset the object, zero all counters.""" 222 | self.ref_ngrams = [defaultdict(int) for _ in range(self.max_ngram + 1)] # has 0-grams 223 | # these two don't have 0-grams 224 | self.hit_ngrams = [[] for _ in range(self.max_ngram)] 225 | self.cand_lens = [[] for _ in range(self.max_ngram)] 226 | self.avg_ref_len = 0.0 227 | 228 | def append(self, pred_sent, ref_sents): 229 | """Append a sentence for measurements, increase counters. 230 | 231 | @param pred_sent: the system output sentence (string/list of tokens) 232 | @param ref_sents: the corresponding reference sentences (list of strings/lists of tokens) 233 | """ 234 | pred_sent, ref_sents = self.check_tokenized(pred_sent, ref_sents) 235 | # collect ngram matches 236 | for n in range(self.max_ngram): 237 | self.cand_lens[n].append(len(pred_sent) - n) # keep track of output length 238 | merged_ref_ngrams = self.get_ngram_counts(n + 1, ref_sents) 239 | pred_ngrams = self.get_ngram_counts(n + 1, [pred_sent]) 240 | # collect ngram matches 241 | hit_ngrams = {} 242 | for ngram in pred_ngrams: 243 | hits = min(pred_ngrams[ngram], merged_ref_ngrams.get(ngram, 0)) 244 | if hits: 245 | hit_ngrams[ngram] = hits 246 | self.hit_ngrams[n].append(hit_ngrams) 247 | # collect total reference ngram counts 248 | for ref_sent in ref_sents: 249 | for ngram in self.ngrams(n + 1, ref_sent): 250 | self.ref_ngrams[n + 1][ngram] += 1 251 | # ref_ngrams: use 0-grams for information value as well 252 | ref_len_sum = sum(len(ref_sent) for ref_sent in ref_sents) 253 | self.ref_ngrams[0][()] += ref_len_sum 254 | # collect average reference length 255 | self.avg_ref_len += ref_len_sum / float(len(ref_sents)) 256 | 257 | def score(self): 258 | """Return the current NIST score, according to the accumulated counts.""" 259 | return self.nist() 260 | 261 | def info(self, ngram): 262 | """Return the NIST informativeness of an n-gram.""" 263 | if ngram not in self.ref_ngrams[len(ngram)]: 264 | return 0.0 265 | return math.log(self.ref_ngrams[len(ngram) - 1][ngram[:-1]] / 266 | float(self.ref_ngrams[len(ngram)][ngram]), 2) 267 | 268 | def nist_length_penalty(self, lsys, avg_lref): 269 | """Compute the NIST length penalty, based on system output length & average reference length. 
270 | @param lsys: total system output length 271 | @param avg_lref: total average reference length 272 | @return: NIST length penalty term 273 | """ 274 | ratio = lsys / float(avg_lref) 275 | if ratio >= 1: 276 | return 1 277 | if ratio <= 0: 278 | return 0 279 | return math.exp(-self.BETA * math.log(ratio) ** 2) 280 | 281 | def nist(self): 282 | """Return the current NIST score, according to the accumulated counts.""" 283 | # 1st NIST term 284 | hit_infos = [0.0 for _ in range(self.max_ngram)] 285 | for n in range(self.max_ngram): 286 | for hit_ngrams in self.hit_ngrams[n]: 287 | hit_infos[n] += sum(self.info(ngram) * hits for ngram, hits in hit_ngrams.items()) 288 | total_lens = [sum(self.cand_lens[n]) for n in range(self.max_ngram)] 289 | nist_sum = sum(old_div(hit_info, total_len) for hit_info, total_len in zip(hit_infos, total_lens)) 290 | # length penalty term 291 | bp = self.nist_length_penalty(sum(self.cand_lens[0]), self.avg_ref_len) 292 | return bp * nist_sum 293 | -------------------------------------------------------------------------------- /LCG/models/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/models/._.DS_Store -------------------------------------------------------------------------------- /LCG/pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plan-And-Write Style Autoregressive Model Baseline for Lexically-Constrained Text Generation 3 | """ 4 | import torch 5 | import pickle 6 | import torch.utils.data as datautils 7 | from IPython import embed 8 | import argparse 9 | import os 10 | import tqdm 11 | from data.dataloader import NaiveTokenizer, RandomKeywordSequentialDataset, GivenKeywordSequentialDataset, DomainAdaptationSequentialDataset 12 | from transformers.optimization import AdamW 13 | from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer 14 | from models.decoder.modeling_cdgpt import CDGPT2LMHeadModel 15 | 16 | 17 | def str2bool(v): 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--continue_training", 29 | default=False, 30 | type=str2bool, 31 | required=False, 32 | help="Continue the training or start from scratch.", 33 | ) 34 | parser.add_argument( 35 | "--dataset", 36 | default="yelp_review", 37 | type=str, 38 | required=False, 39 | help="Target dataset. Default: wmt16roen." 
40 | ) 41 | parser.add_argument( 42 | "--eval_mode", 43 | default=False, 44 | type=str2bool, 45 | required=False, 46 | help="load the latest checkpoint and run the eval pipeline.", 47 | ) 48 | parser.add_argument( 49 | "--batch_size", 50 | default=128, 51 | type=int, 52 | required=False, 53 | help="effective batchsize", 54 | ) 55 | parser.add_argument( 56 | "--iter_per", 57 | default=8, 58 | type=int, 59 | required=False, 60 | help="cumulative gradient iteration cycle", 61 | ) 62 | 63 | args = parser.parse_args() 64 | dataset_name = args.dataset 65 | 66 | if dataset_name[0:11] == "yelp_review": 67 | keyword_num = 7 68 | elif dataset_name[0:4] == "news": 69 | keyword_num = 4 70 | else: 71 | keyword_num = None 72 | 73 | eval_mode = args.eval_mode 74 | batch_size = args.batch_size 75 | iter_per = args.iter_per 76 | directory_identifier = "basemodel_%s_%s" % (dataset_name, "wlex") 77 | continue_training = args.continue_training 78 | if not os.path.exists("./checkpoints"): 79 | os.mkdir("checkpoints") 80 | if not os.path.exists("./checkpoints/%s" % directory_identifier): 81 | os.mkdir("checkpoints/%s" % directory_identifier) 82 | ckpt_name = "gpt2" 83 | 84 | try: 85 | tokenizer, dataset, val_dataset = pickle.load( 86 | open("checkpoints/baselinedataset-%s-%s.pyc" % (dataset_name, "wlex" ), "rb")) 87 | # raise NotImplementedError() 88 | except: 89 | tokenizer = GPT2Tokenizer.from_pretrained(ckpt_name) 90 | tokenizer.bos_token = "$" 91 | tokenizer.sep_token = "#" 92 | if dataset_name[0:11] == "yelp_review" or dataset_name[0:4] == "news": 93 | dataset = RandomKeywordSequentialDataset(tokenizer=tokenizer, max_len=384, keyword_num=keyword_num) 94 | dataset.add(dataset_name) 95 | else: 96 | import datasets 97 | dataset_raw = datasets.load_dataset(dataset_name, split="train") 98 | dataset = GivenKeywordSequentialDataset(tokenizer=tokenizer, max_len=384) 99 | dataset.add(dataset_raw, field_keywords="concepts", field_sequence="target") 100 | val_dataset = None 101 | pickle.dump((tokenizer, dataset, val_dataset), 102 | open("checkpoints/baselinedataset-%s-%s.pyc" % (dataset_name, "wlex"), "wb")) 103 | 104 | dataset.__getitem__(0) 105 | base_model = GPT2LMHeadModel.from_pretrained(ckpt_name) 106 | base_model.train() 107 | base_model.cuda() 108 | generator = base_model 109 | 110 | dataloader = datautils.DataLoader( 111 | dataset, batch_size=batch_size // iter_per, shuffle=True, drop_last=False, pin_memory=True, 112 | num_workers=8 113 | ) 114 | 115 | 116 | from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup, get_constant_schedule 117 | opt = AdamW(lr=1e-5, weight_decay=0.02, 118 | eps=1e-8, params=base_model.parameters()) 119 | # lr_scheduler = get_constant_schedule(opt) 120 | lr_scheduler = get_linear_schedule_with_warmup(opt, num_training_steps=5000000, 121 | num_warmup_steps=400) 122 | if continue_training: 123 | generator.load_state_dict(torch.load("checkpoints/%s/pretrained" % directory_identifier)) 124 | opt_ = torch.load("checkpoints/%s/opt" % directory_identifier) 125 | opt.load_state_dict(opt_) 126 | epoch_idx, iter_count = torch.load("checkpoints/%s/EpochIdx" % directory_identifier) 127 | else: 128 | epoch_idx, iter_count = 0, 0 129 | if type(dataloader) is list: 130 | for curriculum_step_i in range(len(dataloader)): 131 | flog_train = open("checkpoints/%s/log-%d.txt" % (directory_identifier, curriculum_step_i), "w") 132 | flog_train.close() 133 | flog_eval = open("checkpoints/%s/log-eval.txt" % (directory_identifier), "w") 134 | flog_eval.close() 135 | if 
eval_mode: 136 | epoch_id_spec = 4 137 | generator.load_state_dict(torch.load("checkpoints/%s/pretrained-%d" % (directory_identifier, epoch_id_spec))) 138 | torch.save(base_model.state_dict(), "checkpoints/%s/pretrained" % (directory_identifier)) 139 | generator.eval() 140 | def get_top_k_and_prob(line, always_tracing=None, k=3): 141 | with torch.no_grad(): 142 | input_ids = torch.tensor( 143 | [tokenizer.bos_token_id] + tokenizer.encode("kid room dance") + [tokenizer.sep_token_id] + tokenizer.encode(line)) 144 | logits = base_model(input_ids.cuda().unsqueeze(dim=0)).logits.log_softmax(dim=-1) 145 | logits = logits[0][-1] 146 | argmax_logits = logits.topk(k=k) 147 | print("Top %d token: %s" % (k, " ".join([tokenizer.decode(idx) for idx in argmax_logits.indices]))) 148 | print("Top %d token: %s" % (k, " ".join(["%d" % idx for idx in argmax_logits.indices]))) 149 | print("Top %d prob: %s"% (k, " ".join(["%f" % probs.exp() for probs in argmax_logits.values]))) 150 | argmax_logits = (-logits).topk(k=k) 151 | print("Lowest %d token: %s" % (k, " ".join([tokenizer.decode(idx) for idx in argmax_logits.indices]))) 152 | print("Lowest %d prob: %s"% (k, " ".join(["%f" % -probs for probs in argmax_logits.values]))) 153 | if always_tracing is not None: 154 | always_tracing = torch.tensor(tokenizer.encode(always_tracing, add_special_tokens=False)) 155 | print("Pinned token: %s" % (" ".join([tokenizer.decode(idx) for idx in always_tracing]))) 156 | print("Pinned token: %s" % (" ".join(["%d" % idx for idx in always_tracing]))) 157 | print("Pinned token prob: %s"% (" ".join(["%f" % probs.exp() for probs in logits[always_tracing]]))) 158 | if type(tokenizer) is NaiveTokenizer: 159 | tokenizer.close_vocab() 160 | fout = open("./data/%s/%s-generated-baseline.txt" % (dataset_name, dataset_name), "w") 161 | with open("./data/%s/%s-keys.txt" % (dataset_name, dataset_name), "r") as fin: 162 | for line in tqdm.tqdm(fin.readlines()): 163 | input_ids = torch.tensor([tokenizer.bos_token_id] + tokenizer.encode(line.strip()) + [tokenizer.sep_token_id]) 164 | # generated = generator.generate(input_ids=input_ids.cuda().unsqueeze(dim=0), max_length=300, do_sample=True, top_p=0.4, top_k=5) 165 | generated = generator.generate(input_ids=input_ids.cuda().unsqueeze(dim=0), max_length=300, 166 | num_beams=20, pad_token_id=tokenizer.eos_token_id) 167 | generated_str = tokenizer.decode(generated[0][len(input_ids):-1]) 168 | print(generated_str, file=fout) 169 | fout.close() 170 | exit() 171 | truncate_num = 0 172 | for epoch_id in range(500): 173 | iterator = tqdm.tqdm(dataloader) 174 | for input_ids, mask, effective_len in iterator: 175 | max_len = effective_len.max() 176 | input_ids = input_ids.cuda()[:, truncate_num:max_len] 177 | mask = mask.cuda()[:, truncate_num:max_len - 1] 178 | 179 | logits = base_model(input_ids).logits.log_softmax(dim=-1)[:, :-1, :] 180 | nll_all = - (logits.gather(dim=-1, index=input_ids[:, 1:].unsqueeze(dim=-1)).reshape_as(mask) * mask) 181 | nll_reduced = nll_all.sum(dim=-1) 182 | 183 | if iter_count % iter_per == 0: 184 | opt.zero_grad() 185 | (nll_reduced / iter_per).mean(dim=0).backward() 186 | if iter_count % iter_per == iter_per - 1: 187 | opt.step() 188 | lr_scheduler.step() 189 | if (iter_count // iter_per) % 10 == 0: 190 | iterator.write("Iteration %d-%d, Loss %f" % ( 191 | epoch_idx, iter_count // iter_per, nll_reduced.mean(dim=0).cpu().item())) 192 | # embed(); exit() 193 | iter_count += 1 194 | epoch_idx += 1 195 | torch.save(base_model.state_dict(), "checkpoints/%s/pretrained-%d" % 
(directory_identifier, epoch_id)) 196 | torch.save((opt.state_dict(), lr_scheduler.state_dict()), "checkpoints/%s/opt" % directory_identifier) 197 | torch.save((epoch_idx, iter_count), "checkpoints/%s/EpochIdx" % directory_identifier) 198 | 199 | if __name__ == "__main__": 200 | main() 201 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from __future__ import absolute_import 12 | from builtins import object 13 | from .bleu_scorer import BleuScorer 14 | 15 | 16 | class Bleu(object): 17 | def __init__(self, n=4): 18 | # default compute Blue score up to 4 19 | self._n = n 20 | self._hypo_for_image = {} 21 | self.ref_for_image = {} 22 | 23 | def compute_score(self, gts, res): 24 | 25 | assert(list(gts.keys()) == list(res.keys())) 26 | imgIds = list(gts.keys()) 27 | 28 | bleu_scorer = BleuScorer(n=self._n) 29 | for id in imgIds: 30 | hypo = res[id] 31 | ref = gts[id] 32 | 33 | # Sanity check. 
34 | assert(type(hypo) is list) 35 | assert(len(hypo) == 1) 36 | assert(type(ref) is list) 37 | assert(len(ref) >= 1) 38 | 39 | bleu_scorer += (hypo[0], ref) 40 | 41 | #score, scores = bleu_scorer.compute_score(option='shortest') 42 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 43 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 44 | 45 | # return (bleu, bleu_info) 46 | return score, scores 47 | 48 | def method(self): 49 | return "Bleu" 50 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from builtins import zip 22 | from builtins import range 23 | from builtins import object 24 | from past.utils import old_div 25 | import copy 26 | import sys, math, re 27 | from collections import defaultdict 28 | 29 | def precook(s, n=4, out=False): 30 | """Takes a string as input and returns an object that can be given to 31 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 32 | can take string arguments as well.""" 33 | words = s.split() 34 | counts = defaultdict(int) 35 | for k in range(1,n+1): 36 | for i in range(len(words)-k+1): 37 | ngram = tuple(words[i:i+k]) 38 | counts[ngram] += 1 39 | return (len(words), counts) 40 | 41 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 42 | '''Takes a list of reference sentences for a single segment 43 | and returns an object that encapsulates everything that BLEU 44 | needs to know about them.''' 45 | 46 | reflen = [] 47 | maxcounts = {} 48 | for ref in refs: 49 | rl, counts = precook(ref, n) 50 | reflen.append(rl) 51 | for (ngram,count) in counts.items(): 52 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 53 | 54 | # Calculate effective reference sentence length. 55 | if eff == "shortest": 56 | reflen = min(reflen) 57 | elif eff == "average": 58 | reflen = float(sum(reflen))/len(reflen) 59 | 60 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 61 | 62 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 63 | 64 | return (reflen, maxcounts) 65 | 66 | def cook_test(test, xxx_todo_changeme, eff=None, n=4): 67 | '''Takes a test sentence and returns an object that 68 | encapsulates everything that BLEU needs to know about it.''' 69 | (reflen, refmaxcounts) = xxx_todo_changeme 70 | testlen, counts = precook(test, n, True) 71 | 72 | result = {} 73 | 74 | # Calculate effective reference sentence length. 
75 | 76 | if eff == "closest": 77 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 78 | else: ## i.e., "average" or "shortest" or None 79 | result["reflen"] = reflen 80 | 81 | result["testlen"] = testlen 82 | 83 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 84 | 85 | result['correct'] = [0]*n 86 | for (ngram, count) in counts.items(): 87 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 88 | 89 | return result 90 | 91 | class BleuScorer(object): 92 | """Bleu scorer. 93 | """ 94 | 95 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | '''return (bleu, len_ratio) pair''' 134 | return (self.fscore(option=option), self.ratio(option=option)) 135 | 136 | def score_ratio_str(self, option=None): 137 | return "%.4f (%.2f)" % self.score_ratio(option) 138 | 139 | def reflen(self, option=None): 140 | self.compute_score(option=option) 141 | return self._reflen 142 | 143 | def testlen(self, option=None): 144 | self.compute_score(option=option) 145 | return self._testlen 146 | 147 | def retest(self, new_test): 148 | if type(new_test) is str: 149 | new_test = [new_test] 150 | assert len(new_test) == len(self.crefs), new_test 151 | self.ctest = [] 152 | for t, rs in zip(new_test, self.crefs): 153 | self.ctest.append(cook_test(t, rs)) 154 | self._score = None 155 | 156 | return self 157 | 158 | def rescore(self, new_test): 159 | ''' replace test(s) with new test(s), and returns the new score.''' 160 | 161 | return self.retest(new_test).compute_score() 162 | 163 | def size(self): 164 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 165 | return len(self.crefs) 166 | 167 | def __iadd__(self, other): 168 | '''add an instance (e.g., from another sentence).''' 169 | 170 | if type(other) is tuple: 171 | ## avoid creating new BleuScorer instances 172 | self.cook_append(other[0], other[1]) 173 | else: 174 | assert self.compatible(other), "incompatible BLEUs." 
175 | self.ctest.extend(other.ctest) 176 | self.crefs.extend(other.crefs) 177 | self._score = None ## need to recompute 178 | 179 | return self 180 | 181 | def compatible(self, other): 182 | return isinstance(other, BleuScorer) and self.n == other.n 183 | 184 | def single_reflen(self, option="average"): 185 | return self._single_reflen(self.crefs[0][0], option) 186 | 187 | def _single_reflen(self, reflens, option=None, testlen=None): 188 | 189 | if option == "shortest": 190 | reflen = min(reflens) 191 | elif option == "average": 192 | reflen = float(sum(reflens))/len(reflens) 193 | elif option == "closest": 194 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 195 | else: 196 | assert False, "unsupported reflen option %s" % option 197 | 198 | return reflen 199 | 200 | def recompute_score(self, option=None, verbose=0): 201 | self._score = None 202 | return self.compute_score(option, verbose) 203 | 204 | def compute_score(self, option=None, verbose=0): 205 | n = self.n 206 | small = 1e-9 207 | tiny = 1e-15 ## so that if guess is 0 still return 0 208 | bleu_list = [[] for _ in range(n)] 209 | 210 | if self._score is not None: 211 | return self._score 212 | 213 | if option is None: 214 | option = "average" if len(self.crefs) == 1 else "closest" 215 | 216 | self._testlen = 0 217 | self._reflen = 0 218 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 219 | 220 | # for each sentence 221 | for comps in self.ctest: 222 | testlen = comps['testlen'] 223 | self._testlen += testlen 224 | 225 | if self.special_reflen is None: ## need computation 226 | reflen = self._single_reflen(comps['reflen'], option, testlen) 227 | else: 228 | reflen = self.special_reflen 229 | 230 | self._reflen += reflen 231 | 232 | for key in ['guess','correct']: 233 | for k in range(n): 234 | totalcomps[key][k] += comps[key][k] 235 | 236 | # append per image bleu score 237 | bleu = 1. 238 | for k in range(n): 239 | bleu *= old_div((float(comps['correct'][k]) + tiny),(float(comps['guess'][k]) + small)) 240 | bleu_list[k].append(bleu ** (1./(k+1))) 241 | ratio = old_div((testlen + tiny), (reflen + small)) ## N.B.: avoid zero division 242 | if ratio < 1: 243 | for k in range(n): 244 | bleu_list[k][-1] *= math.exp(1 - old_div(1,ratio)) 245 | 246 | if verbose > 1: 247 | print(comps, reflen) 248 | 249 | totalcomps['reflen'] = self._reflen 250 | totalcomps['testlen'] = self._testlen 251 | 252 | bleus = [] 253 | bleu = 1. 
254 | for k in range(n): 255 | bleu *= float(totalcomps['correct'][k] + tiny) \ 256 | / (totalcomps['guess'][k] + small) 257 | bleus.append(bleu ** (1./(k+1))) 258 | ratio = old_div((self._testlen + tiny), (self._reflen + small)) ## N.B.: avoid zero division 259 | if ratio < 1: 260 | for k in range(n): 261 | bleus[k] *= math.exp(1 - old_div(1,ratio)) 262 | 263 | if verbose > 0: 264 | print(totalcomps) 265 | print("ratio:", ratio) 266 | 267 | self._score = bleus 268 | return self._score, bleu_list 269 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/cider/cider.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | # Filename: cider.py 3 | # 4 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 5 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 6 | # 7 | # Creation Date: Sun Feb 8 14:16:54 2015 8 | # 9 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 10 | 11 | from builtins import object 12 | from .cider_scorer import CiderScorer 13 | import pdb 14 | 15 | class Cider(object): 16 | """ 17 | Main Class to compute the CIDEr metric 18 | 19 | """ 20 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 21 | # set cider to sum over 1 to 4-grams 22 | self._n = n 23 | # set the standard deviation parameter for gaussian penalty 24 | self._sigma = sigma 25 | 26 | def compute_score(self, gts, res): 27 | """ 28 | Main function to compute CIDEr score 29 | :param hypo_for_image (dict) : dictionary with key and value 30 | ref_for_image (dict) : dictionary with key and value 31 | :return: cider (float) : computed CIDEr score for the corpus 32 | """ 33 | 34 | assert(list(gts.keys()) == list(res.keys())) 35 | imgIds = list(gts.keys()) 36 | 37 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 38 | 39 | for id in imgIds: 40 | hypo = res[id] 41 | ref = gts[id] 42 | 43 | # Sanity check. 44 | assert(type(hypo) is list) 45 | assert(len(hypo) == 1) 46 | assert(type(ref) is list) 47 | assert(len(ref) > 0) 48 | 49 | cider_scorer += (hypo[0], ref) 50 | 51 | (score, scores) = cider_scorer.compute_score() 52 | 53 | return score, scores 54 | 55 | def method(self): 56 | return "CIDEr" -------------------------------------------------------------------------------- /LCG/pycocoevalcap/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | from __future__ import division 6 | from builtins import zip 7 | from builtins import range 8 | from builtins import object 9 | from past.utils import old_div 10 | import copy 11 | from collections import defaultdict 12 | import numpy as np 13 | import pdb 14 | import math 15 | 16 | def precook(s, n=4, out=False): 17 | """ 18 | Takes a string as input and returns an object that can be given to 19 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 20 | can take string arguments as well. 
21 | :param s: string : sentence to be converted into ngrams 22 | :param n: int : number of ngrams for which representation is calculated 23 | :return: term frequency vector for occuring ngrams 24 | """ 25 | words = s.split() 26 | counts = defaultdict(int) 27 | for k in range(1,n+1): 28 | for i in range(len(words)-k+1): 29 | ngram = tuple(words[i:i+k]) 30 | counts[ngram] += 1 31 | return counts 32 | 33 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 34 | '''Takes a list of reference sentences for a single segment 35 | and returns an object that encapsulates everything that BLEU 36 | needs to know about them. 37 | :param refs: list of string : reference sentences for some image 38 | :param n: int : number of ngrams for which (ngram) representation is calculated 39 | :return: result (list of dict) 40 | ''' 41 | return [precook(ref, n) for ref in refs] 42 | 43 | def cook_test(test, n=4): 44 | '''Takes a test sentence and returns an object that 45 | encapsulates everything that BLEU needs to know about it. 46 | :param test: list of string : hypothesis sentence for some image 47 | :param n: int : number of ngrams for which (ngram) representation is calculated 48 | :return: result (dict) 49 | ''' 50 | return precook(test, n, True) 51 | 52 | class CiderScorer(object): 53 | """CIDEr scorer. 54 | """ 55 | 56 | def copy(self): 57 | ''' copy the refs.''' 58 | new = CiderScorer(n=self.n) 59 | new.ctest = copy.copy(self.ctest) 60 | new.crefs = copy.copy(self.crefs) 61 | return new 62 | 63 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 64 | ''' singular instance ''' 65 | self.n = n 66 | self.sigma = sigma 67 | self.crefs = [] 68 | self.ctest = [] 69 | self.document_frequency = defaultdict(float) 70 | self.cook_append(test, refs) 71 | self.ref_len = None 72 | 73 | def cook_append(self, test, refs): 74 | '''called by constructor and __iadd__ to avoid creating new instances.''' 75 | 76 | if refs is not None: 77 | self.crefs.append(cook_refs(refs)) 78 | if test is not None: 79 | self.ctest.append(cook_test(test)) ## N.B.: -1 80 | else: 81 | self.ctest.append(None) # lens of crefs and ctest have to match 82 | 83 | def size(self): 84 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 85 | return len(self.crefs) 86 | 87 | def __iadd__(self, other): 88 | '''add an instance (e.g., from another sentence).''' 89 | 90 | if type(other) is tuple: 91 | ## avoid creating new CiderScorer instances 92 | self.cook_append(other[0], other[1]) 93 | else: 94 | self.ctest.extend(other.ctest) 95 | self.crefs.extend(other.crefs) 96 | 97 | return self 98 | def compute_doc_freq(self): 99 | ''' 100 | Compute term frequency for reference data. 101 | This will be used to compute idf (inverse document frequency later) 102 | The term frequency is stored in the object 103 | :return: None 104 | ''' 105 | for refs in self.crefs: 106 | # refs, k ref captions of one image 107 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 108 | self.document_frequency[ngram] += 1 109 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 110 | 111 | def compute_cider(self): 112 | def counts2vec(cnts): 113 | """ 114 | Function maps counts of ngram to vector of tfidf weights. 115 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 116 | The n-th entry of array denotes length of n-grams. 
117 | :param cnts: 118 | :return: vec (array of dict), norm (array of float), length (int) 119 | """ 120 | vec = [defaultdict(float) for _ in range(self.n)] 121 | length = 0 122 | norm = [0.0 for _ in range(self.n)] 123 | for (ngram,term_freq) in cnts.items(): 124 | # give word count 1 if it doesn't appear in reference corpus 125 | df = np.log(max(1.0, self.document_frequency[ngram])) 126 | # ngram index 127 | n = len(ngram)-1 128 | # tf (term_freq) * idf (precomputed idf) for n-grams 129 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 130 | # compute norm for the vector. the norm will be used for computing similarity 131 | norm[n] += pow(vec[n][ngram], 2) 132 | 133 | if n == 1: 134 | length += term_freq 135 | norm = [np.sqrt(n) for n in norm] 136 | return vec, norm, length 137 | 138 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 139 | ''' 140 | Compute the cosine similarity of two vectors. 141 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 142 | :param vec_ref: array of dictionary for vector corresponding to reference 143 | :param norm_hyp: array of float for vector corresponding to hypothesis 144 | :param norm_ref: array of float for vector corresponding to reference 145 | :param length_hyp: int containing length of hypothesis 146 | :param length_ref: int containing length of reference 147 | :return: array of score for each n-grams cosine similarity 148 | ''' 149 | delta = float(length_hyp - length_ref) 150 | # measure consine similarity 151 | val = np.array([0.0 for _ in range(self.n)]) 152 | for n in range(self.n): 153 | # ngram 154 | for (ngram,count) in vec_hyp[n].items(): 155 | # vrama91 : added clipping 156 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 157 | 158 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 159 | val[n] /= (norm_hyp[n]*norm_ref[n]) 160 | 161 | assert(not math.isnan(val[n])) 162 | # vrama91: added a length based gaussian penalty 163 | val[n] *= np.e**(old_div(-(delta**2),(2*self.sigma**2))) 164 | return val 165 | 166 | # compute log reference length 167 | self.ref_len = np.log(float(len(self.crefs))) 168 | 169 | scores = [] 170 | for test, refs in zip(self.ctest, self.crefs): 171 | # compute vector for test captions 172 | vec, norm, length = counts2vec(test) 173 | # compute vector for ref captions 174 | score = np.array([0.0 for _ in range(self.n)]) 175 | for ref in refs: 176 | vec_ref, norm_ref, length_ref = counts2vec(ref) 177 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 178 | # change by vrama91 - mean of ngram scores, instead of sum 179 | score_avg = np.mean(score) 180 | # divide by number of references 181 | score_avg /= len(refs) 182 | # multiply score by 10 183 | score_avg *= 10.0 184 | # append score of an image to the score list 185 | scores.append(score_avg) 186 | return scores 187 | 188 | def compute_score(self, option=None, verbose=0): 189 | # compute idf 190 | self.compute_doc_freq() 191 | # assert to check document frequency 192 | assert(len(self.ctest) >= max(self.document_frequency.values())) 193 | # compute cider score 194 | score = self.compute_cider() 195 | # debug 196 | # print score 197 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /LCG/pycocoevalcap/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from builtins import zip 
4 | from builtins import object 5 | __author__ = 'tylin' 6 | from .tokenizer.ptbtokenizer import PTBTokenizer 7 | from .bleu.bleu import Bleu 8 | from .meteor.meteor import Meteor 9 | from .rouge.rouge import Rouge 10 | from .cider.cider import Cider 11 | import sys 12 | 13 | class COCOEvalCap(object): 14 | def __init__(self, coco, cocoRes): 15 | self.evalImgs = [] 16 | self.eval = {} 17 | self.imgToEval = {} 18 | self.coco = coco 19 | self.cocoRes = cocoRes 20 | self.params = {'image_id': coco.getImgIds()} 21 | 22 | def evaluate(self): 23 | imgIds = self.params['image_id'] 24 | # imgIds = self.coco.getImgIds() 25 | gts = {} 26 | res = {} 27 | for imgId in imgIds: 28 | gts[imgId] = self.coco.imgToAnns[imgId] 29 | res[imgId] = self.cocoRes.imgToAnns[imgId] 30 | 31 | # ================================================= 32 | # Set up scorers 33 | # ================================================= 34 | print('tokenization...', file=sys.stderr) 35 | tokenizer = PTBTokenizer() 36 | gts = tokenizer.tokenize(gts) 37 | res = tokenizer.tokenize(res) 38 | 39 | # ================================================= 40 | # Set up scorers 41 | # ================================================= 42 | print('setting up scorers...', file=sys.stderr) 43 | scorers = [ 44 | (Meteor(),"METEOR"), 45 | (Rouge(), "ROUGE_L"), 46 | (Cider(), "CIDEr") 47 | ] 48 | 49 | # ================================================= 50 | # Compute scores 51 | # ================================================= 52 | for scorer, method in scorers: 53 | print('computing %s score...'%(scorer.method()), file=sys.stderr) 54 | score, scores = scorer.compute_score(gts, res) 55 | if type(method) == list: 56 | for sc, scs, m in zip(score, scores, method): 57 | self.setEval(sc, m) 58 | self.setImgToEvalImgs(scs, list(gts.keys()), m) 59 | print("%s: %0.3f"%(m, sc), file=sys.stderr) 60 | else: 61 | self.setEval(score, method) 62 | self.setImgToEvalImgs(scores, list(gts.keys()), method) 63 | print("%s: %0.3f"%(method, score), file=sys.stderr) 64 | self.setEvalImgs() 65 | 66 | def setEval(self, score, method): 67 | self.eval[method] = score 68 | 69 | def setImgToEvalImgs(self, scores, imgIds, method): 70 | for imgId, score in zip(imgIds, scores): 71 | if not imgId in self.imgToEval: 72 | self.imgToEval[imgId] = {} 73 | self.imgToEval[imgId]["image_id"] = imgId 74 | self.imgToEval[imgId][method] = score 75 | 76 | def setEvalImgs(self): 77 | self.evalImgs = [eval for imgId, eval in list(self.imgToEval.items())] 78 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/pycocoevalcap/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /LCG/pycocoevalcap/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/pycocoevalcap/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /LCG/pycocoevalcap/meteor/meteor.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | from builtins import range 7 | from builtins import object 8 | import os 9 | import sys 10 | import subprocess 11 | import threading 12 | 13 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 14 | METEOR_JAR = 'meteor-1.5.jar' 15 | # print METEOR_JAR 16 | 17 | class Meteor(object): 18 | 19 | def __init__(self): 20 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 21 | '-', '-', '-stdio', '-l', 'en', '-norm'] 22 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 23 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 24 | stdin=subprocess.PIPE, \ 25 | stdout=subprocess.PIPE, \ 26 | stderr=subprocess.PIPE) 27 | # Used to guarantee thread safety 28 | self.lock = threading.Lock() 29 | 30 | def compute_score(self, gts, res): 31 | assert(list(gts.keys()) == list(res.keys())) 32 | imgIds = list(gts.keys()) 33 | scores = [] 34 | 35 | eval_line = 'EVAL' 36 | self.lock.acquire() 37 | for i in imgIds: 38 | assert(len(res[i]) == 1) 39 | stat = self._stat(res[i][0], gts[i]) 40 | eval_line += ' ||| {}'.format(stat) 41 | 42 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode('UTF-8')) 43 | self.meteor_p.stdin.flush() 44 | for i in range(0,len(imgIds)): 45 | scores.append(float(self.meteor_p.stdout.readline().decode('UTF-8').strip())) 46 | score = float(self.meteor_p.stdout.readline().strip()) 47 | self.lock.release() 48 | 49 | return score, scores 50 | 51 | def method(self): 52 | return "METEOR" 53 | 54 | def _stat(self, hypothesis_str, reference_list): 55 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 56 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 57 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 58 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode('UTF-8')) 59 | self.meteor_p.stdin.flush() 60 | res = self.meteor_p.stdout.readline().decode('UTF-8').strip() 61 | return res 62 | 63 | def _score(self, hypothesis_str, reference_list): 64 | self.lock.acquire() 65 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 66 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 67 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 68 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode('UTF-8')) 69 | self.meteor_p.stdin.flush() 70 | stats = self.meteor_p.stdout.readline().strip() 71 | eval_line = 'EVAL ||| {}'.format(stats) 72 | # EVAL ||| stats 73 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode('UTF-8')) 74 | self.meteor_p.stdin.flush() 75 | score = float(self.meteor_p.stdout.readline().decode('UTF-8').strip()) 76 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 77 | # thanks for Andrej for pointing this out 78 | score = float(self.meteor_p.stdout.readline().decode('UTF-8').strip()) 79 | self.lock.release() 80 | return score 81 | 82 | def __del__(self): 83 | self.lock.acquire() 84 | self.meteor_p.stdin.close() 85 | self.meteor_p.kill() 86 | self.meteor_p.wait() 87 | self.lock.release() 88 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/rouge/__init__.py: 
-------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | from builtins import range 11 | from builtins import object 12 | import numpy as np 13 | import pdb 14 | 15 | def my_lcs(string, sub): 16 | """ 17 | Calculates longest common subsequence for a pair of tokenized strings 18 | :param string : list of str : tokens from a string split using whitespace 19 | :param sub : list of str : shorter string, also split using whitespace 20 | :returns: length (list of int): length of the longest common subsequence between the two strings 21 | 22 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 23 | """ 24 | if(len(string)< len(sub)): 25 | sub, string = string, sub 26 | 27 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 28 | 29 | for j in range(1,len(sub)+1): 30 | for i in range(1,len(string)+1): 31 | if(string[i-1] == sub[j-1]): 32 | lengths[i][j] = lengths[i-1][j-1] + 1 33 | else: 34 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 35 | 36 | return lengths[len(string)][len(sub)] 37 | 38 | class Rouge(object): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | def __init__(self): 44 | # vrama91: updated the value below based on discussion with Hovey 45 | self.beta = 1.2 46 | 47 | def calc_score(self, candidate, refs): 48 | """ 49 | Compute ROUGE-L score given one candidate and references for an image 50 | :param candidate: str : candidate sentence to be evaluated 51 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 52 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 53 | """ 54 | assert(len(candidate)==1) 55 | assert(len(refs)>0) 56 | prec = [] 57 | rec = [] 58 | 59 | # split into tokens 60 | token_c = candidate[0].split(" ") 61 | 62 | for reference in refs: 63 | # split into tokens 64 | token_r = reference.split(" ") 65 | # compute the longest common subsequence 66 | lcs = my_lcs(token_r, token_c) 67 | prec.append(lcs/float(len(token_c))) 68 | rec.append(lcs/float(len(token_r))) 69 | 70 | prec_max = max(prec) 71 | rec_max = max(rec) 72 | 73 | if(prec_max!=0 and rec_max !=0): 74 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 75 | else: 76 | score = 0.0 77 | return score 78 | 79 | def compute_score(self, gts, res): 80 | """ 81 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 82 | Invoked by evaluate_captions.py 83 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 84 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 85 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 86 | """ 87 | assert(list(gts.keys()) == list(res.keys())) 88 | imgIds = list(gts.keys()) 89 | 90 | score = [] 91 | for id in imgIds: 92 | hypo = res[id] 93 | ref = 
gts[id] 94 | 95 | score.append(self.calc_score(hypo, ref)) 96 | 97 | # Sanity check. 98 | assert(type(hypo) is list) 99 | assert(len(hypo) == 1) 100 | assert(type(ref) is list) 101 | assert(len(ref) > 0) 102 | 103 | average_score = np.mean(np.array(score)) 104 | return average_score, np.array(score) 105 | 106 | def method(self): 107 | return "Rouge" 108 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from builtins import zip 12 | from builtins import range 13 | from builtins import object 14 | import os 15 | import sys 16 | import subprocess 17 | import tempfile 18 | import itertools 19 | 20 | # path to the stanford corenlp jar 21 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 22 | 23 | # punctuations to be removed from the sentences 24 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 25 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 26 | 27 | class PTBTokenizer(object): 28 | """Python wrapper of Stanford PTBTokenizer""" 29 | 30 | def tokenize(self, captions_for_image): 31 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 32 | 'edu.stanford.nlp.process.PTBTokenizer', \ 33 | '-preserveLines', '-lowerCase'] 34 | 35 | # ====================================================== 36 | # prepare data for PTB Tokenizer 37 | # ====================================================== 38 | final_tokenized_captions_for_image = {} 39 | image_id = [k for k, v in list(captions_for_image.items()) for _ in range(len(v))] 40 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in list(captions_for_image.items()) for c in v]) 41 | 42 | # ====================================================== 43 | # save sentences to temporary file 44 | # ====================================================== 45 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 46 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 47 | tmp_file.write(sentences.encode('UTF-8')) 48 | tmp_file.close() 49 | 50 | # ====================================================== 51 | # tokenize sentence 52 | # ====================================================== 53 | cmd.append(os.path.basename(tmp_file.name)) 54 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 55 | stdout=subprocess.PIPE, encoding='UTF-8') 56 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 57 | lines = token_lines.split('\n') 58 | # remove temp file 59 | os.remove(tmp_file.name) 60 | 61 | # ====================================================== 62 | # create dictionary for tokenized captions 63 | # ====================================================== 64 | for k, line in zip(image_id, lines): 65 | if not k in final_tokenized_captions_for_image: 66 | final_tokenized_captions_for_image[k] = [] 67 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 68 | if w not in 
PUNCTUATIONS]) 69 | final_tokenized_captions_for_image[k].append(tokenized_caption) 70 | 71 | return final_tokenized_captions_for_image 72 | -------------------------------------------------------------------------------- /LCG/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MtSomeThree/constrDecoding/5cacf8515352806a14d389b813da7f761c895761/LCG/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /LCG/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /LCG/pycocotools/coco.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from builtins import range 4 | from builtins import object 5 | from past.utils import old_div 6 | __author__ = 'tylin' 7 | __version__ = '1.0.1' 8 | # Interface for accessing the Microsoft COCO dataset. 9 | 10 | # Microsoft COCO is a large image dataset designed for object detection, 11 | # segmentation, and caption generation. pycocotools is a Python API that 12 | # assists in loading, parsing and visualizing the annotations in COCO. 13 | # Please visit http://mscoco.org/ for more information on COCO, including 14 | # for the data, paper, and tutorials. The exact format of the annotations 15 | # is also described on the COCO website. For example usage of the pycocotools 16 | # please see pycocotools_demo.ipynb. In addition to this API, please download both 17 | # the COCO images and annotations in order to run the demo. 18 | 19 | # An alternative to using the API is to load the annotations directly 20 | # into Python dictionary 21 | # Using the API provides additional utility functions. Note that this API 22 | # supports both *instance* and *caption* annotations. In the case of 23 | # captions not all functions are defined (e.g. categories are undefined). 24 | 25 | # The following API functions are defined: 26 | # COCO - COCO api class that loads COCO annotation file and prepare data structures. 27 | # decodeMask - Decode binary mask M encoded via run-length encoding. 28 | # encodeMask - Encode binary mask M using run-length encoding. 29 | # getAnnIds - Get ann ids that satisfy given filter conditions. 30 | # getCatIds - Get cat ids that satisfy given filter conditions. 31 | # getImgIds - Get img ids that satisfy given filter conditions. 32 | # loadAnns - Load anns with the specified ids. 33 | # loadCats - Load cats with the specified ids. 34 | # loadImgs - Load imgs with the specified ids. 35 | # segToMask - Convert polygon segmentation to binary mask. 36 | # showAnns - Display the specified annotations. 37 | # loadRes - Load result file and create result api object. 38 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 39 | # Help on each functions can be accessed by: "help COCO>function". 40 | 41 | # See also COCO>decodeMask, 42 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds, 43 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats, 44 | # COCO>loadImgs, COCO>segToMask, COCO>showAnns 45 | 46 | # Microsoft COCO Toolbox. Version 1.0 47 | # Data, paper, and tutorials available at: http://mscoco.org/ 48 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014. 
49 | # Licensed under the Simplified BSD License [see bsd.txt] 50 | 51 | import json 52 | import datetime 53 | #import matplotlib.pyplot as plt 54 | from matplotlib.collections import PatchCollection 55 | from matplotlib.patches import Polygon 56 | import numpy as np 57 | from skimage.draw import polygon 58 | import copy 59 | import sys 60 | 61 | 62 | class COCO(object): 63 | def __init__(self, annotation_file=None): 64 | """ 65 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 66 | :param annotation_file (str): location of annotation file 67 | :param image_folder (str): location to the folder that hosts images. 68 | :return: 69 | """ 70 | # load dataset 71 | self.dataset = {} 72 | self.anns = [] 73 | self.imgToAnns = {} 74 | self.catToImgs = {} 75 | self.imgs = [] 76 | self.cats = [] 77 | if not annotation_file == None: 78 | print('loading annotations into memory...', file=sys.stderr) 79 | time_t = datetime.datetime.utcnow() 80 | dataset = json.load(open(annotation_file, 'r')) 81 | print(datetime.datetime.utcnow() - time_t, file=sys.stderr) 82 | self.dataset = dataset 83 | self.createIndex() 84 | 85 | def createIndex(self): 86 | # create index 87 | print('creating index...', file=sys.stderr) 88 | imgToAnns = {ann['image_id']: [] for ann in self.dataset['annotations']} 89 | anns = {ann['id']: [] for ann in self.dataset['annotations']} 90 | for ann in self.dataset['annotations']: 91 | imgToAnns[ann['image_id']] += [ann] 92 | anns[ann['id']] = ann 93 | 94 | imgs = {im['id']: {} for im in self.dataset['images']} 95 | for img in self.dataset['images']: 96 | imgs[img['id']] = img 97 | 98 | cats = [] 99 | catToImgs = [] 100 | if self.dataset['type'] == 'instances': 101 | cats = {cat['id']: [] for cat in self.dataset['categories']} 102 | for cat in self.dataset['categories']: 103 | cats[cat['id']] = cat 104 | catToImgs = {cat['id']: [] for cat in self.dataset['categories']} 105 | for ann in self.dataset['annotations']: 106 | catToImgs[ann['category_id']] += [ann['image_id']] 107 | 108 | print('index created!', file=sys.stderr) 109 | 110 | # create class members 111 | self.anns = anns 112 | self.imgToAnns = imgToAnns 113 | self.catToImgs = catToImgs 114 | self.imgs = imgs 115 | self.cats = cats 116 | 117 | def info(self): 118 | """ 119 | Print information about the annotation file. 120 | :return: 121 | """ 122 | for key, value in list(self.dataset['info'].items()): 123 | print('%s: %s'%(key, value), file=sys.stderr) 124 | 125 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): 126 | """ 127 | Get ann ids that satisfy given filter conditions. default skips that filter 128 | :param imgIds (int array) : get anns for given imgs 129 | catIds (int array) : get anns for given cats 130 | areaRng (float array) : get anns for given area range (e.g. 
[0 inf]) 131 | iscrowd (boolean) : get anns for given crowd label (False or True) 132 | :return: ids (int array) : integer array of ann ids 133 | """ 134 | imgIds = imgIds if type(imgIds) == list else [imgIds] 135 | catIds = catIds if type(catIds) == list else [catIds] 136 | 137 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 138 | anns = self.dataset['annotations'] 139 | else: 140 | if not len(imgIds) == 0: 141 | anns = sum([self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],[]) 142 | else: 143 | anns = self.dataset['annotations'] 144 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] 145 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 146 | if self.dataset['type'] == 'instances': 147 | if not iscrowd == None: 148 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] 149 | else: 150 | ids = [ann['id'] for ann in anns] 151 | else: 152 | ids = [ann['id'] for ann in anns] 153 | return ids 154 | 155 | def getCatIds(self, catNms=[], supNms=[], catIds=[]): 156 | """ 157 | filtering parameters. default skips that filter. 158 | :param catNms (str array) : get cats for given cat names 159 | :param supNms (str array) : get cats for given supercategory names 160 | :param catIds (int array) : get cats for given cat ids 161 | :return: ids (int array) : integer array of cat ids 162 | """ 163 | catNms = catNms if type(catNms) == list else [catNms] 164 | supNms = supNms if type(supNms) == list else [supNms] 165 | catIds = catIds if type(catIds) == list else [catIds] 166 | 167 | if len(catNms) == len(supNms) == len(catIds) == 0: 168 | cats = self.dataset['categories'] 169 | else: 170 | cats = self.dataset['categories'] 171 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] 172 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] 173 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] 174 | ids = [cat['id'] for cat in cats] 175 | return ids 176 | 177 | def getImgIds(self, imgIds=[], catIds=[]): 178 | ''' 179 | Get img ids that satisfy given filter conditions. 180 | :param imgIds (int array) : get imgs for given ids 181 | :param catIds (int array) : get imgs with all given cats 182 | :return: ids (int array) : integer array of img ids 183 | ''' 184 | imgIds = imgIds if type(imgIds) == list else [imgIds] 185 | catIds = catIds if type(catIds) == list else [catIds] 186 | 187 | if len(imgIds) == len(catIds) == 0: 188 | ids = list(self.imgs.keys()) 189 | else: 190 | ids = set(imgIds) 191 | for catId in catIds: 192 | if len(ids) == 0: 193 | ids = set(self.catToImgs[catId]) 194 | else: 195 | ids &= set(self.catToImgs[catId]) 196 | return list(ids) 197 | 198 | def loadAnns(self, ids=[]): 199 | """ 200 | Load anns with the specified ids. 201 | :param ids (int array) : integer ids specifying anns 202 | :return: anns (object array) : loaded ann objects 203 | """ 204 | if type(ids) == list: 205 | return [self.anns[id] for id in ids] 206 | elif type(ids) == int: 207 | return [self.anns[ids]] 208 | 209 | def loadCats(self, ids=[]): 210 | """ 211 | Load cats with the specified ids. 
212 | :param ids (int array) : integer ids specifying cats 213 | :return: cats (object array) : loaded cat objects 214 | """ 215 | if type(ids) == list: 216 | return [self.cats[id] for id in ids] 217 | elif type(ids) == int: 218 | return [self.cats[ids]] 219 | 220 | def loadImgs(self, ids=[]): 221 | """ 222 | Load anns with the specified ids. 223 | :param ids (int array) : integer ids specifying img 224 | :return: imgs (object array) : loaded img objects 225 | """ 226 | if type(ids) == list: 227 | return [self.imgs[id] for id in ids] 228 | elif type(ids) == int: 229 | return [self.imgs[ids]] 230 | 231 | def showAnns(self, anns): 232 | """ 233 | Display the specified annotations. 234 | :param anns (array of object): annotations to display 235 | :return: None 236 | """ 237 | if len(anns) == 0: 238 | return 0 239 | if self.dataset['type'] == 'instances': 240 | #ax = plt.gca() 241 | polygons = [] 242 | color = [] 243 | for ann in anns: 244 | c = np.random.random((1, 3)).tolist()[0] 245 | if type(ann['segmentation']) == list: 246 | # polygon 247 | for seg in ann['segmentation']: 248 | poly = np.array(seg).reshape((old_div(len(seg),2), 2)) 249 | polygons.append(Polygon(poly, True,alpha=0.4)) 250 | color.append(c) 251 | else: 252 | # mask 253 | mask = COCO.decodeMask(ann['segmentation']) 254 | img = np.ones( (mask.shape[0], mask.shape[1], 3) ) 255 | if ann['iscrowd'] == 1: 256 | color_mask = old_div(np.array([2.0,166.0,101.0]),255) 257 | if ann['iscrowd'] == 0: 258 | color_mask = np.random.random((1, 3)).tolist()[0] 259 | for i in range(3): 260 | img[:,:,i] = color_mask[i] 261 | #ax.imshow(np.dstack( (img, mask*0.5) )) 262 | p = PatchCollection(polygons, facecolors=color, edgecolors=(0,0,0,1), linewidths=3, alpha=0.4) 263 | #ax.add_collection(p) 264 | if self.dataset['type'] == 'captions': 265 | for ann in anns: 266 | print(ann['caption'], file=sys.stderr) 267 | 268 | def loadRes(self, resFile=None, resData=None): 269 | """ 270 | Load result file and return a result api object. 271 | :param resFile (str) : file name of result file 272 | :param resData (obj) : pre-loaded result data 273 | :return: res (obj) : result api object 274 | """ 275 | assert resFile or resData, 'must be provided result data in a list or a path to result file' 276 | res = COCO() 277 | res.dataset['images'] = [img for img in self.dataset['images']] 278 | res.dataset['info'] = copy.deepcopy(self.dataset['info']) 279 | res.dataset['type'] = copy.deepcopy(self.dataset['type']) 280 | res.dataset['licenses'] = copy.deepcopy(self.dataset['licenses']) 281 | 282 | print('Loading and preparing results... 
', file=sys.stderr) 283 | time_t = datetime.datetime.utcnow() 284 | if resData: 285 | anns = resData 286 | else: 287 | anns = json.load(open(resFile)) 288 | assert type(anns) == list, 'results in not an array of objects' 289 | annsImgIds = [ann['image_id'] for ann in anns] 290 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 291 | 'Results do not correspond to current coco set' 292 | if 'caption' in anns[0]: 293 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 294 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 295 | for id, ann in enumerate(anns): 296 | ann['id'] = id 297 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 298 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 299 | for id, ann in enumerate(anns): 300 | bb = ann['bbox'] 301 | x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]] 302 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 303 | ann['area'] = bb[2]*bb[3] 304 | ann['id'] = id 305 | ann['iscrowd'] = 0 306 | elif 'segmentation' in anns[0]: 307 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 308 | for id, ann in enumerate(anns): 309 | ann['area']=sum(ann['segmentation']['counts'][2:-1:2]) 310 | ann['bbox'] = [] 311 | ann['id'] = id 312 | ann['iscrowd'] = 0 313 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()), file=sys.stderr) 314 | 315 | res.dataset['annotations'] = anns 316 | res.createIndex() 317 | return res 318 | 319 | 320 | @staticmethod 321 | def decodeMask(R): 322 | """ 323 | Decode binary mask M encoded via run-length encoding. 324 | :param R (object RLE) : run-length encoding of binary mask 325 | :return: M (bool 2D array) : decoded binary mask 326 | """ 327 | N = len(R['counts']) 328 | M = np.zeros( (R['size'][0]*R['size'][1], )) 329 | n = 0 330 | val = 1 331 | for pos in range(N): 332 | val = not val 333 | for c in range(R['counts'][pos]): 334 | R['counts'][pos] 335 | M[n] = val 336 | n += 1 337 | return M.reshape((R['size']), order='F') 338 | 339 | @staticmethod 340 | def encodeMask(M): 341 | """ 342 | Encode binary mask M using run-length encoding. 343 | :param M (bool 2D array) : binary mask to encode 344 | :return: R (object RLE) : run-length encoding of binary mask 345 | """ 346 | [h, w] = M.shape 347 | M = M.flatten(order='F') 348 | N = len(M) 349 | counts_list = [] 350 | pos = 0 351 | # counts 352 | counts_list.append(1) 353 | diffs = np.logical_xor(M[0:N-1], M[1:N]) 354 | for diff in diffs: 355 | if diff: 356 | pos +=1 357 | counts_list.append(1) 358 | else: 359 | counts_list[pos] += 1 360 | # if array starts from 1. start with 0 counts for 0 361 | if M[0] == 1: 362 | counts_list = [0] + counts_list 363 | return {'size': [h, w], 364 | 'counts': counts_list , 365 | } 366 | 367 | @staticmethod 368 | def segToMask( S, h, w ): 369 | """ 370 | Convert polygon segmentation to binary mask. 
371 | :param S (float array) : polygon segmentation mask 372 | :param h (int) : target mask height 373 | :param w (int) : target mask width 374 | :return: M (bool 2D array) : binary mask 375 | """ 376 | M = np.zeros((h,w), dtype=np.bool) 377 | for s in S: 378 | N = len(s) 379 | rr, cc = polygon(np.array(s[1:N:2]), np.array(s[0:N:2])) # (y, x) 380 | M[rr, cc] = 1 381 | return M 382 | -------------------------------------------------------------------------------- /LCG/readme.md: -------------------------------------------------------------------------------- 1 | # NADO 2 | 3 | ## Training 4 | Usage: First run 5 | ```shell 6 | python pretrain.py 7 | ``` 8 | to train the base distribution. 9 | 10 | Then run 11 | ```shell 12 | python finetune.py 13 | ``` 14 | to generate the data for training the NADO layers. After the data has been generated and dumped, the script will terminate itself. Run 15 | ```shell 16 | python finetune.py 17 | ``` 18 | again to actually train the model. 19 | 20 | ## Evaluation/Inference/Parameter tuning 21 | 22 | See the argument descriptions in each script for details. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tao Meng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Controllable Text Generation with Neurally-Decomposed Oracle 2 | This repository contains the source code to reproduce the experiments in the NeurIPS 2022 paper 3 | [Controllable Text Generation with Neurally-Decomposed Oracle](https://arxiv.org/abs/2205.14219) by [Tao Meng](https://mtsomethree.github.io/), [Sidi Lu](http://sidilu.cn/), [Nanyun Peng](https://vnpeng.net/) and [Kai-Wei Chang](http://web.cs.ucla.edu/~kwchang/). 4 | 5 | We are now working on the camera-ready version, and the codebase is not yet stable. If you run into any technical issues, please feel free to open an issue or send an email to the first author. 6 | 7 | - ### Abstract 8 | We propose a general and efficient framework to control auto-regressive generation models with NeurAlly-Decomposed Oracle (NADO). 
Given a pre-trained base language model and a sequence-level boolean oracle function, we propose to decompose the oracle function into token-level guidance to steer the base model in text generation. Specifically, the token-level guidance is approximated by a neural model trained with examples sampled from the base model, demanding no additional auxiliary labeled data. We present the closed-form optimal solution to incorporate the token-level guidance into the base model for controllable generation. We further provide a theoretical analysis of how the approximation quality of NADO affects the controllable generation results. Experiments conducted on two applications: (1) text generation with lexical constraints and (2) machine translation with formality control demonstrate that our framework efficiently guides the base model towards the given oracle while maintaining high generation quality. 9 | 10 | - ### Experiments 11 | This repository will contain both experiments described in this paper. So far the LCG part is still under construction and is expected to come out in late October. 12 | 13 | - ### Data 14 | 15 | The machine translation formality change experiments leverage the [CALLHOME Spanish-English Speech Translation Corpus](https://aclanthology.org/2013.iwslt-papers.14/) as source data, and evaluate the BLEU score against the [fluent references](https://aclanthology.org/N19-1285/). Note that LDC access is required for the first dataset. 16 | 17 | - ### Running experiments 18 | 19 | **Requirements** 20 | 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | **Running** 26 | 27 | ```bash 28 | python train_MT.py 29 | ``` 30 | 31 | The code will automatically download the [MarianMT model](https://huggingface.co/docs/transformers/model_doc/marian) and sample translated texts from the source texts of the Fisher-and-Callhome Corpus. The sampled data will be dumped into the ./dump/MT directory and labeled by a formality oracle trained in the [FUDGE paper](https://arxiv.org/abs/2104.05218). A NADO model will then be trained on these labeled samples. The translated results will be evaluated based on oracle scores and on BLEU scores compared to the fluent references. The script also accepts the following optional arguments; an example invocation is shown below the list. 32 | 33 | ```bash 34 | Alternative arguments: 35 | --sample_batch_size the batch size in sampling. Must be an integer multiple of 8. 36 | --batch_size the batch size in training. Must be a divisor of sample_batch_size 37 | --regularization the strength of regularization 38 | --max_length the maximum length accepted in training or evaluation 39 | ```
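40 | 41 | For example, a run that sets these arguments explicitly might look like the sketch below; the flag names come from the list above, while the concrete values are only illustrative assumptions, not recommended settings: 42 | 43 | ```bash 44 | # Illustrative values only: sample_batch_size is a multiple of 8 and batch_size divides it, as required above. 45 | python train_MT.py --sample_batch_size 32 --batch_size 8 --regularization 0.5 --max_length 200 46 | ``` 47 | 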
-------------------------------------------------------------------------------- /bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | 27 | 28 | def _get_ngrams(segment, max_order): 29 | """Extracts all n-grams up to a given maximum order from an input segment. 30 | 31 | Args: 32 | segment: text segment from which n-grams will be extracted. 33 | max_order: maximum length in tokens of the n-grams returned by this 34 | method. 35 | 36 | Returns: 37 | The Counter containing all n-grams up to max_order in segment 38 | with a count of how many times each n-gram occurred. 39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 6-tuple with the BLEU score, n-gram precisions, brevity penalty, 62 | length ratio, translation length and reference length. 63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) / 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. 
/ ratio) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) -------------------------------------------------------------------------------- /check_diff.py: -------------------------------------------------------------------------------- 1 | from neural_constr import NeuralConstraintFunction 2 | nado_file = open('./dump/MT/fisher_10k_8-4-0.60.log', 'r') 3 | base_file = open('./dump/MT/fisher_base.log', 'r') 4 | ref_file = open('./fluent-fisher/noids/test.noid.cleaned_0', 'r') 5 | output_file = open('different.log', 'a') 6 | 7 | constraint_function = NeuralConstraintFunction() 8 | constraint_function.init_FUDGE_formality() 9 | 10 | cnt = 1 11 | for line1, line2, line3 in zip(nado_file, base_file, ref_file): 12 | v1 = constraint_function(line1) 13 | v2 = constraint_function(line2) 14 | if v1 < v2: 15 | output_file.write("%d\n%s\n%s\n%s\n\n\n"%(cnt, line1, line2, line3)) 16 | cnt += 1 17 | -------------------------------------------------------------------------------- /constant.py: -------------------------------------------------------------------------------- 1 | PAD_TOKEN = '[PAD]' 2 | EOT_TOKEN = '<|endoftext|>' 3 | SEP = 50256 # just use the weird eot token 4 | 5 | TOPIC_MODEL_STRING = 'gpt2-medium' 6 | FORMALITY_MODEL_STRING = 'Helsinki-NLP/opus-mt-es-en' 7 | 8 | DIR_END_SPLIT_POSITIONS = 32 9 | 10 | TOPIC_VAL_SIZE = 100000 11 | FORMALITY_VAL_SIZE = 2000 12 | VOCAB_SIZE = 50000 13 | 14 | FORMALITY_MAX_LEN = 200 15 | 16 | GLOVE_PRINT_PROGRESS_FREQ = 1000000 17 | GLOVE_DIM = 300 18 | HIDDEN_DIM = 300 19 | RNN_DIM = 150 20 | 21 | MIN_SENTENCE_LENGTH = 3 22 | 23 | POETRY_LINE_SYLLABLES = 10 24 | MAX_SYLLABLES_PER_WORD = 10 # no way anything is more 25 | MAX_COUNT_SYLLABLE_DIST = 10 26 | MAX_COUNT_SYLLABLE_INPUT_LENGTH = 25 # for just a couplet, shouldn't need more 27 | COUNT_SYLLABLE_DIM = 100 28 | UNKNOWN_RHYME_GROUP = 'UNKNOWN_RHYME_GROUP' 29 | PHRASE_ENDS = '.?!' 
30 | 31 | POETRY_BANNED_TOKENS = [198, 50256, 628, 220] # newlines and eos and such -------------------------------------------------------------------------------- /constraints.py: -------------------------------------------------------------------------------- 1 | class LogicalConstraintFunction(object): 2 | def __init__(self, idx): 3 | self.idx = idx 4 | 5 | def set_constraint_id(self, idx): 6 | self.idx = idx 7 | 8 | def constraint_function1(self, sentence, input=None): #6439, 6308 || 3930 9 | flag1 = 0 10 | flag2 = 0 11 | flag3 = 0 12 | for x in sentence: 13 | if x % 16 == 0: 14 | flag1 = 1 15 | if x % 16 == 1: 16 | flag2 = 1 17 | if x % 16 == 2: 18 | flag3 = 1 19 | if flag1 + flag2 + flag3 == 3: 20 | return True 21 | return False 22 | 23 | def constraint_function2(self, sentence, input=None): 24 | for x in sentence: 25 | if x % 64 == 0: 26 | return True 27 | return False 28 | 29 | def __call__(self, sentence, input=None): 30 | if self.idx == 1: 31 | return self.constraint_function1(sentence, input) 32 | if self.idx == 2: 33 | return self.constraint_function2(sentence, input) 34 | 35 | def constraint_function1(sentence, input=None): #6439, 6308 || 3930 36 | flag1 = 0 37 | flag2 = 0 38 | flag3 = 0 39 | for x in sentence: 40 | if x % 16 == 0: 41 | flag1 = 1 42 | if x % 16 == 1: 43 | flag2 = 1 44 | if x % 16 == 2: 45 | flag3 = 1 46 | if flag1 + flag2 + flag3 == 3: 47 | return True 48 | return False 49 | 50 | def constraint_function2(sentence, input=None): 51 | for x in sentence: 52 | if x % 64 == 0: 53 | return True 54 | return False 55 | 56 | def get_constraint_function(id): 57 | if id == 1: 58 | return constraint_function1 59 | if id == 2: 60 | return constraint_function2 -------------------------------------------------------------------------------- /fudge_related.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence 8 | 9 | import numpy as np 10 | import torch 11 | from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel 12 | 13 | from constant import * 14 | 15 | class Model(nn.Module): 16 | def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True): 17 | super(Model, self).__init__() 18 | self.topic = args.task == 'topic' 19 | self.formality = args.task == 'formality' 20 | self.iambic = args.task == 'iambic' 21 | self.rhyme = args.task == 'rhyme' 22 | self.newline = args.task == 'newline' 23 | 24 | self.marian_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0) # 0 in marian is '' 25 | self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5) # want it to be causal so we can learn all positions 26 | self.out_linear = nn.Linear(HIDDEN_DIM, 1) 27 | 28 | def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False): 29 | """ 30 | inputs: token ids, batch x seq, right-padded with 0s 31 | lengths: lengths of inputs; batch 32 | future_words: batch x N words to check if not predict next token, else batch 33 | log_probs: N 34 | syllables_to_go: batch 35 | """ 36 | 37 | inputs = self.marian_embed(inputs) 38 | inputs = pack_padded_sequence(inputs.permute(1, 0, 2), 
lengths.cpu(), enforce_sorted=False) 39 | rnn_output, _ = self.rnn(inputs) 40 | rnn_output, _ = pad_packed_sequence(rnn_output) 41 | rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300 42 | return self.out_linear(rnn_output).squeeze(2) 43 | 44 | 45 | def avg_formality(preds, model, tokenizer, device='cuda:2'): 46 | probs = [] 47 | for sent in preds: 48 | encoded_input = tokenizer.encode(sent, return_tensors='pt').to(device) 49 | lengths = torch.LongTensor([encoded_input.shape[1]]).to(device) 50 | scores = model(encoded_input, lengths=lengths) # batch x seq 51 | score = scores.flatten()[-1].item() 52 | probs.append(math.exp(score) / (1 + math.exp(score))) # sigmoided score = prob 53 | return np.mean(probs) 54 | 55 | if __name__ == '__main__': 56 | 57 | 58 | device = 'cuda:2' 59 | tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-es-en') 60 | tokenizer.add_special_tokens({'pad_token': PAD_TOKEN}) 61 | pad_id = tokenizer.encode(PAD_TOKEN)[0] 62 | 63 | checkpoint = torch.load('model.pth.tar', map_location=device) 64 | model_args = checkpoint['args'] 65 | conditioning_model = Model(model_args, pad_id, 0) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway 66 | conditioning_model.load_state_dict(checkpoint['state_dict']) 67 | conditioning_model = conditioning_model.to(device) 68 | conditioning_model.eval() 69 | 70 | pred = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.'] 71 | 72 | print('avg formality prob according to model', avg_formality(pred, conditioning_model, tokenizer, device=device)) -------------------------------------------------------------------------------- /neural_constr.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import BertTokenizer, BertForSequenceClassification 3 | from transformers import XLMRobertaTokenizerFast, XLMRobertaForSequenceClassification 4 | from transformers import MarianTokenizer 5 | from torch.optim import AdamW, Adam 6 | from torch.nn.functional import one_hot 7 | from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence 8 | 9 | from constant import * 10 | 11 | import torch 12 | import numpy as np 13 | import random 14 | import math 15 | import os 16 | import argparse 17 | 18 | class FUDGEModel(torch.nn.Module): 19 | def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True): 20 | super(FUDGEModel, self).__init__() 21 | self.topic = args.task == 'topic' 22 | self.formality = args.task == 'formality' 23 | self.iambic = args.task == 'iambic' 24 | self.rhyme = args.task == 'rhyme' 25 | self.newline = args.task == 'newline' 26 | 27 | self.marian_embed = torch.nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0) # 0 in marian is '' 28 | self.rnn = torch.nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5) # want it to be causal so we can learn all positions 29 | self.out_linear = torch.nn.Linear(HIDDEN_DIM, 1) 30 | 31 | def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False): 32 | """ 33 | inputs: token ids, batch x seq, right-padded with 0s 34 | lengths: lengths of inputs; batch 35 | future_words: batch x N words to check if not predict next token, else batch 36 | log_probs: N 37 | syllables_to_go: batch 38 | """ 39 | 40 | inputs = 
self.marian_embed(inputs) 41 | inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False) 42 | rnn_output, _ = self.rnn(inputs) 43 | rnn_output, _ = pad_packed_sequence(rnn_output) 44 | rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300 45 | return self.out_linear(rnn_output).squeeze(2) 46 | 47 | class NeuralConstraintFunction(object): 48 | def __init__(self, model=None, tokenizer=None): 49 | self.model = model 50 | self.tokenizer = tokenizer 51 | self.device = 'cpu' 52 | self.batch_size = 512 53 | self.fudge = False 54 | 55 | def set_device(self, device): 56 | self.device = device 57 | self.model.to(device) 58 | 59 | def init_sentiment(self, dump_dir='./dump/sentiment.pt'): 60 | self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 61 | checkpoint = torch.load(dump_dir, map_location='cpu') 62 | self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased") 63 | self.model.load_state_dict(checkpoint['model_state_dict']) 64 | 65 | def init_formality(self, dump_dir='./dump/formality.pt'): 66 | self.tokenizer = XLMRobertaTokenizerFast.from_pretrained('SkolkovoInstitute/xlmr_formality_classifier') 67 | self.model = XLMRobertaForSequenceClassification.from_pretrained('SkolkovoInstitute/xlmr_formality_classifier') 68 | 69 | def init_GYAFC_formality(self, dump_dir='./dump/GYAFC_formality.pt'): 70 | self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 71 | checkpoint = torch.load(dump_dir, map_location='cpu') 72 | self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased") 73 | self.model.load_state_dict(checkpoint['model_state_dict']) 74 | 75 | def init_FUDGE_formality(self, dump_dir='./test_evaluator_gyafc_family_relationships/model.pth.tar'): 76 | self.fudge = True 77 | self.tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-es-en') 78 | self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN}) 79 | 80 | checkpoint = torch.load(dump_dir, map_location=self.device) 81 | model_args = checkpoint['args'] 82 | pad_id = self.tokenizer.encode(PAD_TOKEN)[0] 83 | self.model = FUDGEModel(model_args, pad_id, 0) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway 84 | self.model.load_state_dict(checkpoint['state_dict']) 85 | self.model = self.model.to(self.device) 86 | self.model.eval() 87 | 88 | def avg_formality(preds): 89 | probs = [] 90 | for sent in preds: 91 | encoded_input = self.tokenizer.encode(sent, return_tensors='pt').to(self.device) 92 | lengths = torch.LongTensor([encoded_input.shape[1]]).to(self.device) 93 | scores = self.model(encoded_input, lengths=lengths) # batch x seq 94 | score = scores.flatten()[-1].item() 95 | probs.append(math.exp(score) / (1 + math.exp(score))) # sigmoided score = prob 96 | return np.mean(probs) 97 | 98 | def __call__(self, text, return_logits=False, soft=False): 99 | if self.fudge: 100 | probs = [] 101 | for sent in text: 102 | encoded_input = self.tokenizer.encode(sent, return_tensors='pt').to(self.device) 103 | lengths = torch.LongTensor([encoded_input.shape[1]]).to(self.device) 104 | scores = self.model(encoded_input, lengths=lengths) # batch x seq 105 | score = scores.flatten()[-1].item() 106 | if soft: 107 | probs.append(math.exp(score) / (1 + math.exp(score))) # sigmoided score = prob 108 | else: 109 | if score > 0: 110 | probs.append(1) 111 | else: 112 | probs.append(0) 113 | return np.sum(probs) 114 | 115 | encoding_dict = self.tokenizer(text, return_tensors='pt', padding=True, 
truncation=True, max_length=128) 116 | input_ids = encoding_dict['input_ids'].to(self.device) 117 | attention_mask = encoding_dict['attention_mask'].to(self.device) 118 | 119 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 120 | 121 | pred_logits = outputs.logits 122 | pred = pred_logits.max(dim=1)[1] 123 | satisfied = int(pred.sum()) 124 | 125 | if return_logits: 126 | return satisfied, pred, pred_logits 127 | else: 128 | return satisfied 129 | 130 | 131 | def eval_constr_model(model, tokenizer, loader, args=None): 132 | if args is not None: 133 | device = args.device 134 | else: 135 | device = "cuda:6" 136 | 137 | model.to(device) 138 | total = 0 139 | correct = 0 140 | for samples, labels in loader: 141 | total += len(samples) 142 | labels = torch.LongTensor(labels).to(device) 143 | 144 | encoding_dict = tokenizer(samples, return_tensors='pt', padding=True, truncation=True, max_length=128) 145 | input_ids = encoding_dict['input_ids'].to(device) 146 | attention_mask = encoding_dict['attention_mask'].to(device) 147 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, 148 | labels=one_hot(labels, num_classes=2).float()) 149 | 150 | pred_logits = outputs.logits 151 | pred = pred_logits.max(dim=1)[1] 152 | 153 | correct += torch.where(pred==labels)[0].shape[0] 154 | 155 | print ("Eval done! Accuracy: %.4f, Loss: %.4f"%(float(correct) / float(total), outputs.loss.item())) 156 | 157 | return total, correct, outputs.loss.item() 158 | 159 | 160 | def train_constr_model(train_samples, train_labels, valid_samples, valid_labels, args=None): 161 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 162 | model = BertForSequenceClassification.from_pretrained("bert-base-uncased") 163 | 164 | if args is not None: 165 | num_epochs = args.num_epochs 166 | lr = args.lr 167 | device = args.device 168 | else: 169 | num_epochs = 10 170 | lr = 0.00002 171 | device = "cuda:6" 172 | 173 | model.to(device) 174 | optimizer = AdamW(model.parameters(), lr=lr) 175 | 176 | for epoch in range(num_epochs): 177 | loss_list = [] 178 | train_loader = dataset_loader(train_samples, train_labels) 179 | for samples, labels in train_loader: 180 | encoding_dict = tokenizer(samples, return_tensors='pt', padding=True, truncation=True, max_length=128) 181 | input_ids = encoding_dict['input_ids'].to(device) 182 | attention_mask = encoding_dict['attention_mask'].to(device) 183 | labels = one_hot(torch.LongTensor(labels), num_classes=2).float().to(device) 184 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 185 | 186 | optimizer.zero_grad() 187 | loss = outputs.loss 188 | loss.backward() 189 | optimizer.step() 190 | 191 | loss_list.append(loss.item()) 192 | print ("Epoch %d, avg loss: %.4f"%(epoch, torch.Tensor(loss_list).mean())) 193 | valid_loader = dataset_loader(valid_samples, valid_labels) 194 | total, correct, loss = eval_constr_model(model, tokenizer, valid_loader) 195 | 196 | return model, tokenizer 197 | 198 | def dataset_loader(samples, labels, batch_size=16): 199 | length = len(samples) 200 | 201 | shuffle_list = list(zip(samples, labels)) 202 | random.shuffle(shuffle_list) 203 | samples, labels = zip(*shuffle_list) 204 | samples = list(samples) 205 | labels = list(labels) 206 | 207 | samples_batch = [] 208 | labels_batch = [] 209 | 210 | N = int(length / batch_size) 211 | 212 | for i in range(N): 213 | yield samples[i * batch_size: (i + 1) * batch_size], labels[i * batch_size: (i + 1) * batch_size] 214 | 215 | if batch_size * N < 
length: 216 | yield samples[N * batch_size:], labels[N * batch_size:] 217 | 218 | def init_sentiment(dump_dir='./dump/sentiment.pt'): 219 | dataset = load_dataset("SetFit/sst2") 220 | if dump_dir is None or (not os.path.exists(dump_dir)): 221 | model, tokenizer = train_constr_model(dataset['train']['text'], dataset['train']['label'], 222 | dataset['validation']['text'], dataset['validation']['label']) 223 | torch.save({'model_state_dict': model.state_dict()}, dump_dir) 224 | else: 225 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 226 | checkpoint = torch.load(dump_dir, map_location='cpu') 227 | model = BertForSequenceClassification.from_pretrained("bert-base-uncased") 228 | model.load_state_dict(checkpoint['model_state_dict']) 229 | 230 | test_loader = dataset_loader(dataset['test']['text'], dataset['test']['label']) 231 | 232 | return model, tokenizer 233 | 234 | def init_formality(dump_dir='./dump/formality.pt'): 235 | tokenizer = XLMRobertaTokenizerFast.from_pretrained('SkolkovoInstitute/xlmr_formality_classifier') 236 | model = XLMRobertaForSequenceClassification.from_pretrained('SkolkovoInstitute/xlmr_formality_classifier') 237 | 238 | return model, tokenizer 239 | 240 | def load_text_list(file_dir, max_length=128): 241 | l = [] 242 | with open(file_dir, 'r') as f: 243 | for line in f: 244 | if len(line) > max_length: 245 | l.append(line[:max_length - 1]) 246 | else: 247 | l.append(line.strip()) 248 | return l 249 | 250 | def init_GYAFC_formality(dump_dir='./dump/GYAFC.pt'): 251 | if dump_dir is None or (not os.path.exists(dump_dir)): 252 | train_0 = load_text_list('Entertainment_Music/train/informal') 253 | train_1 = load_text_list('Entertainment_Music/train/formal') 254 | train_labels = [0] * len(train_0) + [1] * len(train_1) 255 | train_samples = train_0 + train_1 256 | valid_0 = load_text_list('Entertainment_Music/tune/informal') 257 | valid_1 = load_text_list('Entertainment_Music/tune/formal') 258 | valid_labels = [0] * len(valid_0) + [1] * len(valid_1) 259 | valid_samples = valid_0 + valid_1 260 | 261 | model, tokenizer = train_constr_model(train_samples, train_labels, 262 | valid_samples, valid_labels) 263 | torch.save({'model_state_dict': model.state_dict()}, dump_dir) 264 | else: 265 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 266 | checkpoint = torch.load(dump_dir, map_location='cpu') 267 | model = BertForSequenceClassification.from_pretrained("bert-base-uncased") 268 | model.load_state_dict(checkpoint['model_state_dict']) 269 | 270 | return model, tokenizer 271 | 272 | 273 | def neural_constr_function(model, tokenizer, text, device='cpu'): 274 | model.to(device) 275 | 276 | encoding_dict = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128) 277 | input_ids = encoding_dict['input_ids'].to(device) 278 | attention_mask = encoding_dict['attention_mask'].to(device) 279 | 280 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 281 | 282 | pred_logits = outputs.logits 283 | pred = pred_logits.max(dim=1)[1] 284 | 285 | return pred, pred_logits 286 | 287 | 288 | 289 | if __name__ == "__main__": 290 | fudge = NeuralConstraintFunction() 291 | fudge.init_FUDGE_formality() 292 | 293 | pred = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.'] 294 | print (fudge(pred) / 3.0, fudge(pred, soft=True) / 3.0) 295 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | aiohttp==3.8.1 2 | aiosignal==1.2.0 3 | async-timeout==4.0.2 4 | asynctest==0.13.0 5 | attrs==21.4.0 6 | brotlipy==0.7.0 7 | certifi==2021.10.8 8 | cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work 9 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 10 | click==8.0.3 11 | colorama==0.4.4 12 | cryptography @ file:///tmp/build/80754af9/cryptography_1639414570729/work 13 | datasets==1.18.2 14 | dill==0.3.4 15 | filelock @ file:///opt/conda/conda-bld/filelock_1642510437405/work 16 | frozenlist==1.3.0 17 | fsspec==2022.1.0 18 | huggingface-hub @ file:///tmp/build/80754af9/huggingface_hub_1639662742275/work 19 | idna @ file:///tmp/build/80754af9/idna_1637925883363/work 20 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1638529598726/work 21 | joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work 22 | mkl-fft==1.3.1 23 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work 24 | mkl-service==2.4.0 25 | multidict==6.0.2 26 | multiprocess==0.70.12.2 27 | nltk==3.7 28 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634106693478/work 29 | olefile==0.46 30 | packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work 31 | pandas==1.3.5 32 | Pillow==8.4.0 33 | portalocker==2.4.0 34 | pyarrow==6.0.1 35 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 36 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1635333100036/work 37 | pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work 38 | PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work 39 | python-dateutil==2.8.2 40 | pytz==2021.3 41 | PyYAML==6.0 42 | regex @ file:///opt/conda/conda-bld/regex_1642021319040/work 43 | requests @ file:///opt/conda/conda-bld/requests_1641824580448/work 44 | sacrebleu==2.0.0 45 | sacremoses @ file:///tmp/build/80754af9/sacremoses_1633107328213/work 46 | sentencepiece==0.1.96 47 | six @ file:///tmp/build/80754af9/six_1623709665295/work 48 | tabulate==0.8.9 49 | tokenizers @ file:///tmp/build/80754af9/tokenizers_1639593992616/work 50 | torch==1.7.1 51 | torchaudio==0.7.0a0+a853dff 52 | torchvision @ file:///tmp/build/80754af9/torchvision_1622219711511/work 53 | tqdm==4.62.3 54 | transformers @ file:///tmp/build/80754af9/transformers_1639665351690/work 55 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work 56 | urllib3==1.26.7 57 | xxhash==2.0.2 58 | yarl==1.7.2 59 | zipp @ file:///opt/conda/conda-bld/zipp_1641824620731/work 60 | -------------------------------------------------------------------------------- /train_rc.py: -------------------------------------------------------------------------------- 1 | from GPT2base import ConstrainedLM 2 | from constraints import LogicalConstraintFunction 3 | from neural_constr import NeuralConstraintFunction 4 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Model, GPT2Config 5 | from torch.optim import AdamW, Adam 6 | from torch.nn.functional import log_softmax, softmax 7 | import datasets 8 | import copy 9 | import torch 10 | import argparse 11 | import os 12 | 13 | def sample_from_GPT2(model, tokenizer, constraint_function, args): 14 | if os.path.exists(args.samples_file): 15 | print ("Load Sampling Data from %s..."%(args.samples_file)) 16 | # samples_list, labels_list, masks_list = torch.load(args.samples_file) 17 | # samples_list = [x.to(args.device) for x in samples_list] 18 | # labels_list = 
[x.to(args.device) for x in labels_list] 19 | # masks_list = [x.to(args.device) for x in masks_list] 20 | return torch.load(args.samples_file, map_location=args.device) 21 | 22 | print ("Initializing Sampling...") 23 | model.set_constraint_factor(0.0) 24 | 25 | sentence_prefix = ["I"] * args.sample_batch_size 26 | encodings_dict = tokenizer.batch_encode_plus(sentence_prefix) 27 | 28 | input_ids = torch.tensor(encodings_dict['input_ids']).to(args.device) 29 | attention_mask = torch.tensor(encodings_dict['attention_mask']).to(args.device) 30 | 31 | ''' 32 | input_ids = tokenizer.encode( 33 | input_ids=input_ids, 34 | attention_mask=attention_mask, 35 | add_special_tokens=False, 36 | return_tensors="pt", 37 | ).to(args.device) 38 | ''' 39 | 40 | labels_list = [] 41 | samples_list = [] 42 | masks_list = [] 43 | logprobs_list = [] 44 | 45 | print("Sampling Data...") 46 | 47 | for i in range(args.num_sample_batches): 48 | outputs = model.generate( 49 | input_ids=input_ids, 50 | attention_mask=attention_mask, 51 | do_sample=True, 52 | max_length=args.max_length, # desired output sentence length 53 | pad_token_id=model.config.eos_token_id, 54 | output_scores=True, 55 | return_dict_in_generate=True, 56 | ) 57 | 58 | output_ids = outputs.sequences 59 | scores = outputs.scores 60 | 61 | labels = [] 62 | masks = [] 63 | logprobs = [] 64 | texts = [] 65 | 66 | for j in range(args.sample_batch_size): 67 | length = args.max_length - torch.where(output_ids[j] == model.config.eos_token_id)[0].shape[0] 68 | constr_input = tokenizer.decode(output_ids[j], skip_special_tokens=True) 69 | if isinstance(constraint_function, LogicalConstraintFunction): 70 | if constraint_function(constr_input): 71 | labels.append(1) 72 | else: 73 | labels.append(0) 74 | else: 75 | texts.append(constr_input) 76 | mask = [1] * length + [0] * (args.max_length - length) 77 | masks.append(mask) 78 | 79 | #logprob = outputs.sequences_scores[j] * float(length) 80 | 81 | 82 | logprob = 0.0 83 | for k in range(length - 1): 84 | logprob += log_softmax(scores[k][j], dim=0)[output_ids[j][k + 1]] 85 | 86 | logprobs.append(logprob) 87 | 88 | labels = torch.Tensor(labels).unsqueeze(1).to(args.device) 89 | if isinstance(constraint_function, NeuralConstraintFunction): 90 | _, labels, _ = constraint_function(texts, return_logits=True) 91 | labels = labels.float().squeeze().unsqueeze(1).to(args.device) # assign back so neural-oracle labels match the shape and device of the logical-constraint branch 92 | 93 | masks = torch.Tensor(masks).to(args.device) 94 | logprobs = torch.Tensor(logprobs).to(args.device) 95 | 96 | for j in range(int(args.sample_batch_size / args.batch_size)): 97 | labels_list.append(labels[j * args.batch_size : (j + 1) * args.batch_size]) 98 | samples_list.append(output_ids[j * args.batch_size : (j + 1) * args.batch_size]) 99 | masks_list.append(masks[j * args.batch_size : (j + 1) * args.batch_size]) 100 | logprobs_list.append(logprobs[j * args.batch_size : (j + 1) * args.batch_size]) 101 | if i % 10 == 9: 102 | print ("%d sample batches..."%(i + 1)) 103 | 104 | # print ("Last batch sampled text:") 105 | # for j in range(args.sample_batch_size): 106 | # generated_text = tokenizer.decode( 107 | # output_ids[j], 108 | # clean_up_tokenization_spaces=True) 109 | # print (generated_text) 110 | torch.save((samples_list, labels_list, masks_list, logprobs_list), args.samples_file) 111 | return samples_list, labels_list, masks_list, logprobs_list 112 | 113 | def train_rc(model, samples_list, labels_list, masks_list, logprobs_list, args): 114 | print ("Start Training...") 115 | for p in model.parameters(): 116 | p.requires_grad = False 117 | 
rc_parameters = [] 118 | for n, p in model.named_parameters(): 119 | if "model_rc" in n: 120 | p.requires_grad = True 121 | rc_parameters.append(p) 122 | 123 | print ("%d parameters in total"%(sum(p.numel() for p in model.parameters()))) 124 | print ("%d parameters in rc"%(sum(p.numel() for p in model.parameters() if p.requires_grad))) 125 | 126 | optimizer = AdamW(params=rc_parameters, lr=args.lr) 127 | 128 | 129 | for epoch in range(args.num_epochs): 130 | cnt = 1 131 | loss_list = [] 132 | for samples, labels, masks, logprobs in zip(samples_list, labels_list, masks_list, logprobs_list): 133 | labels = labels.float() 134 | #probs = softmax(logprobs, dim=0) * float(labels.shape[0]) 135 | 136 | outputs = model(input_ids=samples, attention_mask=masks, rc_weights=logprobs, rc_labels=labels) 137 | loss = outputs.loss 138 | 139 | optimizer.zero_grad() 140 | loss.backward() 141 | optimizer.step() 142 | 143 | cnt += 1 144 | loss_list.append(loss.item()) 145 | 146 | print ("Epoch %d: avg loss: %.4f"%(epoch, torch.Tensor(loss_list).mean())) 147 | 148 | satisfied = test_rc(model, tokenizer, constraint_function, args, use_constr=True, sample_text=(epoch == args.num_epochs - 1)) 149 | print (float(satisfied) / float(args.num_test) / float(args.sample_batch_size)) 150 | 151 | 152 | def test_rc(model, tokenizer, constraint_function, args, use_constr=True, sample_text=False): 153 | if use_constr: 154 | model.set_constraint_factor(1.0) 155 | else: 156 | model.set_constraint_factor(0.0) 157 | model.set_temperature(1.0) 158 | sentence_prefix = ["I"] * args.sample_batch_size 159 | encodings_dict = tokenizer.batch_encode_plus(sentence_prefix) 160 | 161 | input_ids = torch.tensor(encodings_dict['input_ids']).to(args.device) 162 | attention_mask = torch.tensor(encodings_dict['attention_mask']).to(args.device) 163 | 164 | #model(input_ids=input_ids) 165 | satisfied = 0 166 | 167 | for i in range(args.num_test): 168 | output_ids = model.generate( 169 | input_ids=input_ids, 170 | attention_mask=attention_mask, 171 | do_sample=True, 172 | max_length=args.max_length, # desired output sentence length 173 | pad_token_id=model.config.eos_token_id, 174 | ) 175 | 176 | if isinstance(constraint_function, NeuralConstraintFunction): 177 | texts = [] 178 | for j in range(args.sample_batch_size): 179 | texts.append(tokenizer.decode(output_ids[j], skip_special_tokens=True)) 180 | satisfied += constraint_function(texts) 181 | 182 | else: 183 | for j in range(args.sample_batch_size): 184 | if constraint_function(output_ids[j]): 185 | satisfied += 1 186 | 187 | if sample_text: 188 | for j in range(min(20, args.sample_batch_size)): 189 | generate_text = tokenizer.decode(output_ids[j], skip_special_tokens=True) 190 | print (generate_text) 191 | if isinstance(constraint_function, NeuralConstraintFunction): 192 | print (constraint_function(generate_text)) 193 | else: 194 | print (constraint_function(output_ids[j])) 195 | 196 | model.set_temperature(args.temperature) 197 | return satisfied 198 | 199 | def fine_tune_GPT2_with_pos_samples(model, samples_list, labels_list, masks_list, logprobs_list, args): 200 | model.set_constraint_factor(0.0) 201 | 202 | fine_tune_parameters = [] 203 | for n, p in model.named_parameters(): 204 | if "model_rc" in n: 205 | p.requires_grad = False 206 | else: 207 | fine_tune_parameters.append(p) 208 | 209 | optimizer = Adam(params=fine_tune_parameters, lr=args.lr) 210 | 211 | 212 | for epoch in range(args.num_epochs): 213 | cnt = 1 214 | loss_list = [] 215 | for samples, labels, masks, logprobs in 
zip(samples_list, labels_list, masks_list, logprobs_list): 216 | labels = labels.float() 217 | if labels.sum() < 0.5: 218 | continue 219 | outputs = model(input_ids=samples, attention_mask=masks, labels=samples, rc_weights=logprobs, rc_labels=labels.squeeze(1)) 220 | loss = outputs.loss 221 | 222 | optimizer.zero_grad() 223 | loss.backward() 224 | optimizer.step() 225 | 226 | cnt += 1 227 | loss_list.append(loss.item()) 228 | 229 | print ("Epoch %d: avg loss: %.4f"%(epoch, torch.Tensor(loss_list).mean())) 230 | 231 | satisfied = test_rc(model, tokenizer, constraint_function, args, use_constr=True, sample_text=(epoch == args.num_epochs - 1)) 232 | print (float(satisfied) / float(args.num_test) / float(args.sample_batch_size)) 233 | 234 | 235 | if __name__ == "__main__": 236 | parser = argparse.ArgumentParser() 237 | parser.add_argument('--num_sample_batches', type=int, default=200) 238 | parser.add_argument('--sample_batch_size', type=int, default=512) 239 | parser.add_argument('--batch_size', type=int, default=32) 240 | parser.add_argument('--num_epochs', type=int, default=10) 241 | parser.add_argument('--lr', type=float, default=0.00003) 242 | parser.add_argument('--cuda', type=int, default=-1) 243 | parser.add_argument('--num_test', type=int, default=100) 244 | parser.add_argument('--max_length', type=int, default=30) 245 | parser.add_argument('--device', type=str, default=None) 246 | parser.add_argument('--dump_dir', type=str, default=None) 247 | parser.add_argument('--load_dir', type=str, default=None) 248 | parser.add_argument('--use_rc_transformer', action='store_true') 249 | parser.add_argument('--num_rc_layers', type=int, default=-1) 250 | parser.add_argument('--baseline_fine_tune', action='store_true') 251 | parser.add_argument('--constraint_id', type=int, default=1) 252 | parser.add_argument('--samples_file', type=str, default=None) 253 | parser.add_argument('--temperature', type=float, default=1.0) 254 | 255 | args = parser.parse_args() 256 | 257 | #constraint_function = LogicalConstraintFunction(args.constraint_id) 258 | constraint_function = NeuralConstraintFunction() 259 | constraint_function.init_formality() 260 | 261 | if args.device is None: 262 | if args.cuda == -1: 263 | args.device = 'cpu' 264 | else: 265 | args.device = "cuda:%d"%(args.cuda) 266 | 267 | if args.samples_file is None: 268 | args.samples_file = './dump/formality_%d-%d-%d.pt'%(args.num_sample_batches, args.sample_batch_size, args.batch_size) 269 | 270 | model = ConstrainedLM.from_pretrained("gpt2") 271 | model.set_temperature(args.temperature) 272 | if args.num_rc_layers != -1: 273 | new_config = copy.copy(model.config) 274 | new_config.n_layer = args.num_rc_layers 275 | model.set_model_rc_transformer(GPT2Model(new_config)) 276 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 277 | 278 | 279 | if args.baseline_fine_tune: 280 | 281 | model.to(args.device) 282 | samples_list, labels_list, masks_list, logprobs_list = sample_from_GPT2(model, tokenizer, constraint_function, args) 283 | fine_tune_GPT2_with_pos_samples(model, samples_list, labels_list, masks_list, logprobs_list, args) 284 | 285 | else: 286 | 287 | model.set_use_rc_transformer(args.use_rc_transformer) 288 | 289 | if args.load_dir is not None: 290 | model.model_rc.load_state_dict(torch.load(args.load_dir)) 291 | model.to(args.device) 292 | satisfied = test_rc(model, tokenizer, constraint_function, args, use_constr=False, sample_text=False) 293 | print (float(satisfied) / float(args.num_test) / float(args.sample_batch_size)) 294 | else: 295 | 
model.to(args.device) 296 | satisfied = test_rc(model, tokenizer, constraint_function, args, use_constr=False, sample_text=False) 297 | print (float(satisfied) / float(args.num_test) / float(args.sample_batch_size)) 298 | samples_list, labels_list, masks_list, logprobs_list = sample_from_GPT2(model, tokenizer, constraint_function, args) 299 | model.set_constraint_factor(1.0) 300 | train_rc(model, samples_list, labels_list, masks_list, logprobs_list, args) 301 | 302 | satisfied = test_rc(model, tokenizer, constraint_function, args, use_constr=True, sample_text=True) 303 | 304 | print (float(satisfied) / float(args.num_test) / float(args.sample_batch_size)) 305 | if args.dump_dir is None: 306 | args.dump_dir = './dump/%d.pt'%(int(float(10000 * satisfied) / float(args.num_test) / float(args.sample_batch_size))) 307 | if args.load_dir is None: 308 | torch.save(model.model_rc.state_dict(), args.dump_dir) 309 | 310 | 311 | 312 | --------------------------------------------------------------------------------
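For reference, an invocation of train_rc.py might look like the sketch below. The flag names are taken from the argparse definitions above; the specific values simply restate the script defaults (plus an explicit GPU index), so treat them as illustrative placeholders rather than recommended settings.

```bash
# Sketch: sample from GPT-2, train the token-level guidance head (model_rc),
# and report the constraint-satisfaction rate. Sampled data is cached via
# --samples_file (default: a .pt file under ./dump/) and reused on later runs.
python train_rc.py \
    --cuda 0 \
    --num_sample_batches 200 \
    --sample_batch_size 512 \
    --batch_size 32 \
    --num_epochs 10 \
    --max_length 30 \
    --temperature 1.0
```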