├── model ├── __init__.py ├── postprocessing.py ├── loss.py ├── entailment_score.py ├── utils.py ├── seq2seq_vocab.py └── optim.py ├── metrics ├── 3rdparty │ └── .create ├── pred.txt ├── ref.txt ├── test_metrics.py ├── README.md ├── tokenizers.py ├── multi-bleu.perl └── dstc_example.py ├── bert_score ├── __init__.py ├── rescale_baseline │ └── en │ │ └── roberta-large.tsv ├── score.py ├── scorer.py └── utils.py ├── Software.zip ├── framework.PNG ├── BT ├── average_model.sh ├── evaluate_back.sh ├── preprocess.sh ├── evaluate.sh ├── train_en-fr.sh ├── train_fr-en.sh ├── bpe_split.py ├── get_bt_input_file.py └── recover.py ├── multi_gpt2 └── config.json ├── cal_novelty.py ├── inference_multi_gpt2.sh ├── train_gpt2.sh ├── train_multi_gpt2_distilled.sh ├── train_seq2seq.sh ├── data_manipulation ├── prepare_model │ ├── extract_personas_and_responses.py │ ├── finetune_bert_and_gpt2.py │ ├── train_nli_model.py │ └── train_coherence_nli.py ├── data_diversification │ ├── get_augmented_scores.py │ └── filter_augmented_data.py └── data_distillation │ ├── get_distilled_dataset.py │ └── calculate_entailment.py ├── train_gpt2_D3.sh ├── train_seq2seq_D3.sh ├── attention_experiment ├── get_target_samples.py └── attention_exp.py ├── README.md ├── inference.py └── config.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/3rdparty/.create: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/pred.txt: -------------------------------------------------------------------------------- 1 | i lkie this thing 2 | -------------------------------------------------------------------------------- /metrics/ref.txt: -------------------------------------------------------------------------------- 1 | i do not like this 2 | -------------------------------------------------------------------------------- /bert_score/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.6" 2 | -------------------------------------------------------------------------------- /Software.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/D3/HEAD/Software.zip -------------------------------------------------------------------------------- /framework.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/D3/HEAD/framework.PNG -------------------------------------------------------------------------------- /BT/average_model.sh: -------------------------------------------------------------------------------- 1 | python ../fairseq/scripts/average_checkpoints.py \ 2 | --input checkpoints/enfr/ \ 3 | --num-epoch-checkpoints 5 \ 4 | --output ./checkpoints/enfr/model_avg.pt \ 5 | -------------------------------------------------------------------------------- /BT/evaluate_back.sh: -------------------------------------------------------------------------------- 1 | sub_num=$1 2 | CUDA_VISIBLE_DEVICES=$sub_num python ../fairseq/interactive.py ./corpus/fren/ \ 3 | -s fr -t en \ 4 | --path ./checkpoints/fren/model_avg.pt \ 5 | --beam 5 \ 6 | --nbest 5 \ 7 | --batch-size 128 \ 8 | --buffer-size 8000 \ 9 | --input ./en-fr.out$sub_num > fr-en.log$sub_num \ 10 | 
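Note on the BT/ directory: the scripts above and below are listed in repository order rather than execution order. The following is a rough sketch of the assumed back-translation workflow; only the script names and their flags come from the files in this repository, while the shard count, GPU index, and the data file names passed to get_bt_input_file.py and recover.py are illustrative placeholders.
# Assumed BT/ invocation order (sketch, not part of the original scripts)
bash preprocess.sh                  # binarize the en-fr parallel corpus with fairseq preprocess.py
bash train_en-fr.sh                 # train the forward (en -> fr) translation model
bash train_fr-en.sh                 # train the backward (fr -> en) translation model
bash average_model.sh               # average the last 5 en-fr checkpoints; an analogous step is assumed for fr-en
python get_bt_input_file.py distilled_train.json th0.99_entail_dev.txt   # dump utterances plus an *_idx.json index
python bpe_split.py 4               # apply sentence.bpe.model and split the utterances into 4 shards
bash evaluate.sh 0                  # en -> fr translation of shard 0 (one run per shard / GPU)
# the fr hypotheses in en-fr.log0 are assumed to be extracted into en-fr.out0 before the next step
bash evaluate_back.sh 0             # fr -> en back-translation of shard 0
python recover.py fr-en.log0 distilled_train.json th0.99_entail_dev.txt_idx.json bt_augmented.json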
-------------------------------------------------------------------------------- /BT/preprocess.sh: -------------------------------------------------------------------------------- 1 | ROOT=./ 2 | SRC=en 3 | TGT=fr 4 | DICT=./ 5 | 6 | python ../fairseq/preprocess.py \ 7 | --source-lang en \ 8 | --target-lang fr \ 9 | --trainpref train \ 10 | --validpref valid \ 11 | --testpref test \ 12 | --destdir ../corpus/enfr \ 13 | --thresholdtgt 0 \ 14 | --thresholdsrc 0 \ 15 | --workers 90 -------------------------------------------------------------------------------- /BT/evaluate.sh: -------------------------------------------------------------------------------- 1 | sub_num=$1 2 | CUDA_VISIBLE_DEVICES=$sub_num python ../fairseq/interactive.py ./corpus/enfr/ \ 3 | -s en -t fr \ 4 | --path ./checkpoints/enfr/model_avg.pt \ 5 | --beam 5 \ 6 | --nbest 5 \ 7 | --batch-size 128 \ 8 | --buffer-size 8000 \ 9 | --input ./th0.99_entail_dev_bpe.txt$sub_num > en-fr.log$sub_num \ 10 | -------------------------------------------------------------------------------- /multi_gpt2/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "GPT2LMHeadModel" 4 | ], 5 | "initializer_range": 0.02, 6 | "layer_norm_epsilon": 1e-05, 7 | "n_ctx": 1024, 8 | "n_embd": 768, 9 | "n_head": 12, 10 | "n_layer": 12, 11 | "n_positions": 1024, 12 | "vocab_size": 50257, 13 | "shared_module": false, 14 | "shared_attention": false, 15 | "context_size": 2 16 | } -------------------------------------------------------------------------------- /cal_novelty.py: -------------------------------------------------------------------------------- 1 | import json 2 | from metrics import cal_novelty 3 | 4 | with open('original_utterances.json', 'r') as f: 5 | original_utterances = json.load(f) 6 | with open('new_utterances.json', 'r') as f: 7 | new_utterances = json.load(f) 8 | un1, un2, un3, un4 = cal_novelty(original_utterances, new_utterances[40000:50000]) 9 | print(un1) 10 | print(un2) 11 | print(un3) 12 | print(un4) -------------------------------------------------------------------------------- /inference_multi_gpt2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python inference.py \ 2 | --test_datasets augmented_data/th0.99_self_without_response.json \ 3 | --test_datasets_cache augmented_data/th0.99_self_without_response_gpt2_cache \ 4 | --generate_file_name augmented_data/th0.99_self_with_generated_response.json \ 5 | --data_type entailment \ 6 | --load_last runs/multi_gpt2/best_model \ 7 | --max_history_size 1 \ 8 | --model_type gpt2 \ 9 | --shared_module 0 \ 10 | --shared_attention 0 \ 11 | --attention_pooling_type sw \ 12 | -------------------------------------------------------------------------------- /metrics/test_metrics.py: -------------------------------------------------------------------------------- 1 | from metrics import * 2 | from tokenizers import * 3 | 4 | # evaluation 5 | 6 | 7 | nist, bleu, meteor, entropy, diversity, avg_len = nlp_metrics( 8 | path_refs=['demo/ref0.txt', 'demo/ref1.txt'], 9 | path_hyp='demo/hyp.txt') 10 | 11 | print(nist) 12 | print(bleu) 13 | print(meteor) 14 | print(entropy) 15 | print(diversity) 16 | print(avg_len) 17 | 18 | # tokenization 19 | 20 | s = " I don't know:). 
how about this?https://github.com/golsun/deep-RL-time-series" 21 | print(clean_str(s)) 22 | -------------------------------------------------------------------------------- /train_gpt2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --train_batch_size 256 \ 3 | --batch_split 32 \ 4 | --n_epochs 15 \ 5 | --lr 6.25e-5 \ 6 | --train_datasets datasets/train_self_original.txt \ 7 | --valid_datasets datasets/valid_self_original.txt \ 8 | --test_datasets datasets/test_self_original.txt \ 9 | --train_datasets_cache datasets/th0.99_entail_gpt2_self_cache \ 10 | --valid_datasets_cache datasets/valid_gpt2_self_cache \ 11 | --test_datasets_cache datasets/test_gpt2_self_cache \ 12 | --model_type gpt2 \ 13 | --single_input \ 14 | --hits_weight 1 \ 15 | --s2s_weight 1 \ 16 | --negative_samples 1 \ 17 | --patience 5 \ 18 | --max_history_size 5 \ 19 | --entail_score_refs_file ./persona_test_entailment_idx.json \ 20 | -------------------------------------------------------------------------------- /BT/train_en-fr.sh: -------------------------------------------------------------------------------- 1 | python ../fairseq/train.py ./enfr/ \ 2 | --arch transformer_wmt_en_de \ 3 | --criterion label_smoothed_cross_entropy \ 4 | --label-smoothing 0.1 \ 5 | --lr 7e-4 \ 6 | --warmup-init-lr 1e-7 \ 7 | --min-lr 1e-9 \ 8 | --lr-scheduler inverse_sqrt \ 9 | --warmup-updates 4000 \ 10 | --optimizer adam \ 11 | --adam-betas '(0.9,0.98)' \ 12 | --adam-eps 1e-6 \ 13 | -s en \ 14 | -t fr \ 15 | --max-tokens 8192 \ 16 | --max-update 100000 \ 17 | --weight-decay 0.01 \ 18 | --seed 0 \ 19 | --save-dir ./checkpoints/enfr/ \ 20 | --ddp-backend=no_c10d \ 21 | --fp16 \ 22 | --update-freq 8 \ 23 | --dropout 0.3 \ 24 | --no-progress-bar \ 25 | --log-format simple \ 26 | --log-interval 50 \ 27 | --save-interval-updates 4000 \ 28 | --share-decoder-input-output-embed \ 29 | -------------------------------------------------------------------------------- /BT/train_fr-en.sh: -------------------------------------------------------------------------------- 1 | python ../fairseq/train.py ./fren/ \ 2 | --arch transformer_wmt_en_de \ 3 | --criterion label_smoothed_cross_entropy \ 4 | --label-smoothing 0.1 \ 5 | --lr 7e-4 \ 6 | --warmup-init-lr 1e-7 \ 7 | --min-lr 1e-9 \ 8 | --lr-scheduler inverse_sqrt \ 9 | --warmup-updates 4000 \ 10 | --optimizer adam \ 11 | --adam-betas '(0.9,0.98)' \ 12 | --adam-eps 1e-6 \ 13 | -s fr \ 14 | -t en \ 15 | --max-tokens 8192 \ 16 | --max-update 100000 \ 17 | --weight-decay 0.01 \ 18 | --seed 0 \ 19 | --save-dir ./checkpoints/fren/ \ 20 | --ddp-backend=no_c10d \ 21 | --fp16 \ 22 | --update-freq 8 \ 23 | --dropout 0.3 \ 24 | --no-progress-bar \ 25 | --log-format simple \ 26 | --log-interval 50 \ 27 | --save-interval-updates 4000 \ 28 | --share-decoder-input-output-embed \ 29 | -------------------------------------------------------------------------------- /train_multi_gpt2_distilled.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --train_batch_size 256 \ 3 | --batch_split 32 \ 4 | --lr 5e-4 \ 5 | --train_datasets datasets/th0.99_model_entail_train_data.json \ 6 | --valid_datasets datasets/th0.99_model_entail_dev_data.json \ 7 | --test_datasets datasets/th0.99_model_entail_dev_data.json \ 8 | --train_datasets_cache datasets/th0.99_model_entail_train_gpt2_cache \ 9 | --valid_datasets_cache datasets/th0.99_model_entail_valid_gpt2_cache \ 10 | 
--test_datasets_cache datasets/th0.99_model_entail_valid_gpt2_cache \ 11 | --model_type gpt2 \ 12 | --data_type entailment \ 13 | --shared_module 0 \ 14 | --shared_attention 0 \ 15 | --n_epochs 5 \ 16 | --attention_pooling_type sw \ 17 | --extra_module_lr_rate 5.0 \ 18 | --max_history_size 1 \ 19 | -------------------------------------------------------------------------------- /BT/bpe_split.py: -------------------------------------------------------------------------------- 1 | import sentencepiece as spm 2 | import sys 3 | 4 | split_num = int(sys.argv[1]) 5 | sp=spm.SentencePieceProcessor() 6 | sp.load('sentence.bpe.model') 7 | 8 | with open('th0.99_entail_dev.txt', 'r') as f: 9 | lines = f.readlines() 10 | new_lines = [] 11 | for line in lines: 12 | nl = sp.EncodeAsPieces(line.strip()) 13 | new_lines.append(' '.join(nl) + '\n') 14 | file_name = 'th0.99_entail_dev_bpe.txt' 15 | avg_size = len(new_lines) // split_num 16 | start = 0 17 | for i in range(split_num): 18 | if i != split_num - 1: 19 | sub_lines = new_lines[start:start + avg_size] 20 | else: 21 | sub_lines = new_lines[start:] 22 | start += avg_size 23 | with open(file_name + str(i), 'w') as f: 24 | f.writelines(sub_lines) 25 | 26 | -------------------------------------------------------------------------------- /train_seq2seq.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --train_batch_size 256 \ 3 | --batch_split 8 \ 4 | --n_epochs 100 \ 5 | --lr 2e-4 \ 6 | --train_datasets datasets/train_self_original.txt \ 7 | --valid_datasets datasets/valid_self_original.txt \ 8 | --test_datasets datasets/test_self_original.txt \ 9 | --train_datasets_cache datasets/train_seq2seq_self_cache \ 10 | --valid_datasets_cache datasets/valid_seq2seq_self_cache \ 11 | --test_datasets_cache datasets/test_seq2seq_self_cache \ 12 | --vocab_path datasets/vocab/persona_self_vocab.bin \ 13 | --model_type seq2seq \ 14 | --single_input \ 15 | --pointer_gen \ 16 | --label_smoothing 0.1 \ 17 | --s2s_weight 1 \ 18 | --negative_samples 0 \ 19 | --max_history_size 5 \ 20 | --patience 15 \ 21 | --entail_score_refs_file ./persona_test_entailment_idx.json \ 22 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/roberta-large.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.3712891,0.37132213,0.36826715 3 | 1,0.67176163,0.6717439,0.6703483 4 | 2,0.70031923,0.7003052,0.69969934 5 | 3,0.7080897,0.7081011,0.707698 6 | 4,0.6976306,0.69762677,0.69710517 7 | 5,0.7187199,0.71873325,0.71828526 8 | 6,0.74678195,0.74678224,0.74642223 9 | 7,0.7772428,0.7772184,0.77691925 10 | 8,0.8021733,0.8021747,0.8019093 11 | 9,0.8067641,0.80678225,0.8065291 12 | 10,0.8366976,0.8367098,0.8364913 13 | 11,0.8163513,0.816369,0.8161064 14 | 12,0.8175406,0.8175611,0.81728977 15 | 13,0.82106245,0.8210674,0.82080233 16 | 14,0.81487834,0.8148861,0.8145652 17 | 15,0.8243552,0.8243522,0.8240494 18 | 16,0.8341641,0.8341684,0.833912 19 | 17,0.83150584,0.8314941,0.83122575 20 | 18,0.8314624,0.83146274,0.8311686 21 | 19,0.82761073,0.8276117,0.8273196 22 | 20,0.799873,0.79988,0.79956234 23 | 21,0.8082163,0.80819315,0.8079286 24 | 22,0.83196104,0.83195347,0.83174026 25 | 23,0.8408042,0.8408027,0.8405716 26 | 24,0.96022236,0.96021587,0.960168 -------------------------------------------------------------------------------- /model/postprocessing.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | import nltk 5 | from nltk.corpus import wordnet 6 | 7 | 8 | def augment_replica(seq): 9 | _exceptions = ['your', 'persona'] 10 | pos2wn = {'NN': wordnet.NOUN, 11 | 'JJ': wordnet.ADJ, 12 | 'VBP': wordnet.VERB, 13 | 'RB': wordnet.ADV} 14 | 15 | synonyms = defaultdict(list) 16 | 17 | tagged_seq = seq.replace('i ', 'I ') 18 | tagged_seq = nltk.pos_tag(nltk.word_tokenize(tagged_seq)) 19 | 20 | for word, pos in tagged_seq: 21 | if pos not in pos2wn or word in _exceptions: 22 | continue 23 | 24 | pos = pos2wn[pos] 25 | synnets = wordnet.synsets(word, pos=pos) 26 | 27 | for synnet in synnets: 28 | for syn in synnet.lemma_names(): 29 | if syn != word: 30 | synonyms[word].append(syn.replace('_', ' ')) 31 | break 32 | if synonyms: 33 | for key, values in synonyms.items(): 34 | seq = seq.replace(key, random.choice(list(values))) 35 | 36 | return seq -------------------------------------------------------------------------------- /data_manipulation/prepare_model/extract_personas_and_responses.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | INPUT_FILE = '../../datasets/train_self_original.txt' 4 | 5 | with open(INPUT_FILE, 'r') as f: 6 | lines = f.readlines() 7 | personas, utterances = [], [] 8 | for line in lines: 9 | line = line.strip() 10 | if len(line) == 0: 11 | continue 12 | space_idx = line.find(' ') 13 | if space_idx == -1: 14 | dialog_idx = int(line) 15 | else: 16 | dialog_idx = int(line[:space_idx]) 17 | dialog_line = line[space_idx + 1:].split('\t') 18 | dialog_line = [l.strip() for l in dialog_line] 19 | 20 | if dialog_line[0].startswith('your persona:'): 21 | persona_info = dialog_line[0].replace('your persona: ', '') 22 | if persona_info[-1] == '.' and persona_info[-2] != ' ': 23 | persona_info = persona_info[:-1] + ' .' 
24 | personas.append(persona_info) 25 | elif len(dialog_line) > 1: 26 | utterances.append(dialog_line[1]) 27 | 28 | with open('personas.json', 'w') as f: 29 | json.dump(list(set(personas)), f) 30 | with open('responses.json', 'w') as f: 31 | json.dump(list(set(utterances)), f) -------------------------------------------------------------------------------- /train_gpt2_D3.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --train_batch_size 256 \ 3 | --batch_split 32 \ 4 | --n_epochs 15 \ 5 | --lr 6.25e-5 \ 6 | --train_datasets datasets/train_self_original.txt \ 7 | --valid_datasets datasets/valid_self_original.txt \ 8 | --test_datasets datasets/test_self_original.txt \ 9 | --train_datasets_cache datasets/train_gpt2_self_cache \ 10 | --valid_datasets_cache datasets/valid_gpt2_self_cache \ 11 | --test_datasets_cache datasets/test_gpt2_self_cache \ 12 | --model_type gpt2 \ 13 | --single_input \ 14 | --hits_weight 1 \ 15 | --s2s_weight 1 \ 16 | --negative_samples 1 \ 17 | --patience 5 \ 18 | --max_history_size 5 \ 19 | --curriculum_learning \ 20 | --curriculum_train_datasets datasets/augmented/th0.99_self_augmented.json \ 21 | --curriculum_train_datasets_cache datasets/augmented/th0.99_self_augmented_gpt2_cache \ 22 | --curriculum_valid_datasets datasets/augmented/th0.99_dev_augmented.json \ 23 | --curriculum_valid_datasets_cache datasets/augmented/th0.99_dev_augmented_gpt2_cache \ 24 | --curriculum_lr 6.25e-5 \ 25 | --curriculum_n_epochs 15 \ 26 | --curriculum_patience 5 \ 27 | --curriculum_max_history_size 1 \ 28 | --entail_score_refs_file ./persona_test_entailment_idx.json \ 29 | -------------------------------------------------------------------------------- /train_seq2seq_D3.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --train_batch_size 256 \ 3 | --batch_split 8 \ 4 | --n_epochs 100 \ 5 | --lr 2e-4 \ 6 | --train_datasets datasets/train_self_original.txt \ 7 | --valid_datasets datasets/valid_self_original.txt \ 8 | --test_datasets datasets/test_self_original.txt \ 9 | --train_datasets_cache datasets/train_seq2seq_self_cache \ 10 | --valid_datasets_cache datasets/valid_seq2seq_self_cache \ 11 | --test_datasets_cache datasets/test_seq2seq_self_cache \ 12 | --curriculum_learning \ 13 | --curriculum_train_datasets datasets/augmented/th0.99_self_augmented.json \ 14 | --curriculum_train_datasets_cache datasets/augmented/th0.99_self_augmented_seq2seq_cache \ 15 | --curriculum_valid_datasets datasets/augmented/th0.99_dev_augmented.json \ 16 | --curriculum_valid_datasets_cache datasets/augmented/th0.99_dev_augmented_seq2seq_cache \ 17 | --curriculum_lr 2e-4 \ 18 | --curriculum_n_epochs 50 \ 19 | --curriculum_patience 15 \ 20 | --curriculum_max_history_size 1 \ 21 | --curriculum_data_type entailment \ 22 | --extend_exist_vocab datasets/vocab/persona_self_vocab.bin \ 23 | --vocab_path datasets/vocab/persona_self_augmented_vocab.bin \ 24 | --model_type seq2seq \ 25 | --single_input \ 26 | --label_smoothing 0.1 \ 27 | --pointer_gen \ 28 | --s2s_weight 1 \ 29 | --patience 15 \ 30 | --negative_samples 0 \ 31 | --max_history_size 5 \ 32 | --entail_score_refs_file ./persona_test_entailment_idx.json \ 33 | -------------------------------------------------------------------------------- /BT/get_bt_input_file.py: -------------------------------------------------------------------------------- 1 | import json 2 | from sys import argv 3 | 4 |
# The input file of the distilled data (json file) or the original file 5 | input_file = argv[1] 6 | # The output file that only save the utterances of the input data for back-translation 7 | output_file = argv[2] 8 | 9 | new_lines = [] 10 | line_idx = [] 11 | if '.txt' in input_file: 12 | with open(input_file, 'r') as f: 13 | lines = f.readlines() 14 | for i, line in enumerate(lines): 15 | persona_idx = line.find('your persona: ') 16 | if persona_idx != -1 and persona_idx < 5: 17 | new_lines.append(line[persona_idx + 14:]) 18 | line_idx.append(i) 19 | else: 20 | space_idx = line.find(' ') 21 | history = line[space_idx + 1:].split('\t') 22 | new_lines.append(history[0].strip() + '\n') 23 | new_lines.append(history[1].strip() + '\n') 24 | line_idx.append(i) 25 | line_idx.append(i) 26 | if '.json' in input_file: 27 | with open(input_file, 'r') as f: 28 | data = json.load(f) 29 | for i, d in enumerate(data): 30 | new_lines.append(d[0].strip() + '\n') 31 | line_idx.append([i, 0]) 32 | for j, u in enumerate(d[1]): 33 | new_lines.append(u.strip() + '\n') 34 | line_idx.append([i, 3 + j]) 35 | new_lines.append(d[2].strip() + '\n') 36 | line_idx.append([i, 2]) 37 | with open(output_file, 'w') as f: 38 | f.writelines(new_lines) 39 | with open(output_file + '_idx.json', 'w') as f: 40 | json.dump(line_idx, f) -------------------------------------------------------------------------------- /metrics/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Adapted from the Automatic evaluation script for DSTC7 Task 2 3 | Original repo: https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling 4 | 5 | # To install 3rd party scripts 6 | Download the following 3rd-party packages and save in folder `3rdparty`: 7 | 8 | * [**mteval-v14c.pl**](ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl) to compute [NIST](http://www.mt-archive.info/HLT-2002-Doddington.pdf). You may need to install the following [perl](https://www.perl.org/get.html) modules (e.g. by `cpan install`): XML:Twig, Sort:Naturally and String:Util. 9 | * [**meteor-1.5**](http://www.cs.cmu.edu/~alavie/METEOR/download/meteor-1.5.tar.gz) to compute [METEOR](http://www.cs.cmu.edu/~alavie/METEOR/index.html). It requires [Java](https://www.java.com/en/download/help/download_options.xml). 10 | 11 | 12 | # What does it do? 13 | (Based on this [repo](https://github.com/golsun/NLP-tools) by [Sean Xiang Gao](https://www.linkedin.com/in/gxiang1228/)) 14 | 15 | * **evaluation**: calculate automated NLP metrics (BLEU, NIST, METEOR, entropy, etc...) 16 | ```python 17 | from metrics import nlp_metrics 18 | nist, bleu, meteor, entropy, diversity, avg_len = nlp_metrics( 19 | path_refs=["demo/ref0.txt", "demo/ref1.txt"], 20 | path_hyp="demo/hyp.txt") 21 | 22 | # nist = [1.8338, 2.0838, 2.1949, 2.1949] 23 | # bleu = [0.4667, 0.441, 0.4017, 0.3224] 24 | # meteor = 0.2832 25 | # entropy = [2.5232, 2.4849, 2.1972, 1.7918] 26 | # diversity = [0.8667, 1.000] 27 | # avg_len = 5.0000 28 | ``` 29 | * **tokenization**: clean string and deal with punctation, contraction, url, mention, tag, etc 30 | ```python 31 | from tokenizers import clean_str 32 | s = " I don't know:). how about this?https://github.com" 33 | clean_str(s) 34 | 35 | # i do n't know :) . how about this ? 
__url__ 36 | ``` 37 | -------------------------------------------------------------------------------- /metrics/tokenizers.py: -------------------------------------------------------------------------------- 1 | # author: Xiang Gao @ Microsoft Research, Oct 2018 2 | # clean and tokenize natural language text 3 | 4 | import re 5 | from util import * 6 | from nltk.tokenize import TweetTokenizer 7 | 8 | def clean_str(txt): 9 | #print("in=[%s]" % txt) 10 | txt = txt.lower() 11 | txt = re.sub('^',' ', txt) 12 | txt = re.sub('$',' ', txt) 13 | 14 | # url and tag 15 | words = [] 16 | for word in txt.split(): 17 | i = word.find('http') 18 | if i >= 0: 19 | word = word[:i] + ' ' + '__url__' 20 | words.append(word.strip()) 21 | txt = ' '.join(words) 22 | 23 | # remove markdown URL 24 | txt = re.sub(r'\[([^\]]*)\] \( *__url__ *\)', r'\1', txt) 25 | 26 | # remove illegal char 27 | txt = re.sub('__url__','URL',txt) 28 | txt = re.sub(r"[^A-Za-z0-9():,.!?\"\']", " ", txt) 29 | txt = re.sub('URL','__url__',txt) 30 | 31 | # contraction 32 | add_space = ["'s", "'m", "'re", "n't", "'ll","'ve","'d","'em"] 33 | tokenizer = TweetTokenizer(preserve_case=False) 34 | txt = ' ' + ' '.join(tokenizer.tokenize(txt)) + ' ' 35 | txt = txt.replace(" won't ", " will n't ") 36 | txt = txt.replace(" can't ", " can n't ") 37 | for a in add_space: 38 | txt = txt.replace(a+' ', ' '+a+' ') 39 | 40 | txt = re.sub(r'^\s+', '', txt) 41 | txt = re.sub(r'\s+$', '', txt) 42 | txt = re.sub(r'\s+', ' ', txt) # remove extra spaces 43 | 44 | #print("out=[%s]" % txt) 45 | return txt 46 | 47 | 48 | if __name__ == '__main__': 49 | ss = [ 50 | " I don't know:). how about this?https://github.com/golsun/deep-RL-time-series", 51 | "please try [ GitHub ] ( https://github.com )", 52 | ] 53 | for s in ss: 54 | print(s) 55 | print(clean_str(s)) 56 | print() 57 | 58 | -------------------------------------------------------------------------------- /data_manipulation/prepare_model/finetune_bert_and_gpt2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from sys import argv 4 | from tqdm import tqdm 5 | 6 | from transformers.modeling_gpt2 import GPT2LMHeadModel 7 | from transformers.tokenization_gpt2 import GPT2Tokenizer 8 | from transformers.modeling_bert import BertForPreTraining 9 | from transformers.tokenization_bert import BertTokenizer 10 | from transformers import AdamW 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data import TensorDataset 13 | from torch.nn import CrossEntropyLoss 14 | 15 | LR = 1e-5 16 | BATCH_SIZE = 32 17 | STEPS = 100 18 | input_file = argv[1] 19 | output_model = argv[2] 20 | base_model = 'BERT' 21 | 22 | BERT_MODEL_PATH = '../bert_model' 23 | GPT2_MODEL_PATH = '../gpt2-small' 24 | 25 | with open(input_file, 'r') as f: 26 | sentences = json.load(f) 27 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 28 | if base_model == 'GPT2': 29 | tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_PATH) 30 | tokenizer.pad_token = tokenizer.eos_token 31 | model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL_PATH) 32 | elif base_model == 'BERT': 33 | tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_PATH) 34 | tokenizer.pad_token = '[PAD]' 35 | model = BertForPreTraining.from_pretrained(BERT_MODEL_PATH) 36 | optimizer = AdamW(model.parameters(), lr=LR, correct_bias=True) 37 | model.to(device) 38 | 39 | all_inputs = tokenizer(sentences, return_tensors='pt', padding=True).data 40 | dataset = 
TensorDataset(all_inputs['input_ids'], all_inputs['attention_mask']) 41 | dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) 42 | steps = 0 43 | s2s_loss = 0 44 | while True: 45 | tqdm_data = tqdm(dataloader) 46 | for batch in tqdm_data: 47 | optimizer.zero_grad() 48 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]} 49 | for k, v in inputs.items(): 50 | inputs[k] = v.to(device) 51 | if base_model == 'GPT2': 52 | loss = model(**inputs, labels=inputs['input_ids'])[0] 53 | else: 54 | logits = model(**inputs)[0] 55 | loss_fct = CrossEntropyLoss() 56 | labels = inputs['input_ids'] 57 | labels = labels.masked_fill(labels == 0, -100) 58 | loss = loss_fct(logits.view(-1, model.config.vocab_size), labels.view(-1)) 59 | loss.backward() 60 | optimizer.step() 61 | tqdm_data.set_postfix({'s2s_loss': loss.item()}) 62 | steps += 1 63 | if steps > STEPS: 64 | break 65 | if steps > STEPS: 66 | break 67 | torch.save(model.state_dict(), output_model) -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class LabelSmoothingLoss(nn.Module): 23 | def __init__(self, n_labels, smoothing=0.0, ignore_index=-100, size_average=True): 24 | super(LabelSmoothingLoss, self).__init__() 25 | assert 0 <= smoothing <= 1 26 | 27 | self.ignore_index = ignore_index 28 | self.confidence = 1 - smoothing 29 | 30 | if smoothing > 0: 31 | self.criterion = nn.KLDivLoss(size_average=size_average) 32 | n_ignore_idxs = 1 + (ignore_index >= 0) 33 | one_hot = torch.full((1, n_labels), fill_value=(smoothing / (n_labels - n_ignore_idxs))) 34 | if ignore_index >= 0: 35 | one_hot[0, ignore_index] = 0 36 | self.register_buffer('one_hot', one_hot) 37 | else: 38 | self.criterion = nn.NLLLoss(size_average=size_average, ignore_index=ignore_index) 39 | 40 | def forward(self, log_inputs, targets): 41 | if self.confidence < 1: 42 | tdata = targets.data 43 | 44 | tmp = self.one_hot.repeat(targets.shape[0], 1) 45 | tmp.scatter_(1, tdata.unsqueeze(1), self.confidence) 46 | 47 | if self.ignore_index >= 0: 48 | mask = torch.nonzero(tdata.eq(self.ignore_index)).squeeze(-1) 49 | if mask.numel() > 0: 50 | tmp.index_fill_(0, mask, 0) 51 | 52 | targets = tmp 53 | 54 | return self.criterion(log_inputs, targets) 55 | 56 | class SoftCrossEntropyLoss(nn.Module): 57 | def __init(self): 58 | super(SoftCrossEntropyLoss).__init__() 59 | 60 | def forward(self, input, soft_targets, lengths): 61 | log_input = -F.log_softmax(input, dim=-1) 62 | loss = torch.mean(torch.sum(torch.sum(torch.mul(log_input, soft_targets), dim=-1), dim=-1) / (lengths + 0.01)) 63 | return loss 64 | -------------------------------------------------------------------------------- /model/entailment_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, TensorDataset 3 | import json 4 | from tqdm import tqdm 5 | from typing import Union 6 | 7 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelWithLMHead 8 | from transformers.data.processors.utils import InputExample, InputFeatures 9 | 10 | class EntailmentScorer: 11 | def __init__(self, pred_file, entail_idx_file, model_path, device): 12 | with open(pred_file, 'r') as f: 13 | lines = f.readlines() 14 | self.all_preds = [line.strip() for line in lines] 15 | with open(entail_idx_file, 'r') as f: 16 | refs = json.load(f) 17 | self.all_data = [] 18 | for ref in refs: 19 | persona = ref['persona'] 20 | idx = ref['index'] 21 | for i in idx: 22 | self.all_data.append([[p, self.all_preds[i]] for p in persona]) 23 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 24 | self.model = AutoModelForSequenceClassification.from_pretrained(model_path) 25 | self.model.to(device) 26 | self.device = device 27 | 28 | def _convert_examples_to_features(self, examples, tokenizer, label_list=None, max_length=128, output_mode=None,): 29 | if max_length is None: 30 | max_length = tokenizer.max_len 31 | 32 | label_map = {label: i for i, label in enumerate(label_list)} 33 | 34 | def label_from_example(example: InputExample) -> Union[int, float, None]: 35 | if example.label is None: 36 | return None 37 | if output_mode == "classification": 38 | return label_map[example.label] 39 | elif output_mode == "regression": 40 | return float(example.label) 41 | raise KeyError(output_mode) 42 | 43 | labels = [label_from_example(example) for example in examples] 44 | 45 | batch_encoding = tokenizer( 46 | [(example.text_a, 
example.text_b) for example in examples], 47 | max_length=max_length, 48 | padding="max_length", 49 | truncation=True, 50 | ) 51 | 52 | features = [] 53 | for i in range(len(examples)): 54 | inputs = {k: batch_encoding[k][i] for k in batch_encoding} 55 | feature = InputFeatures(**inputs, label=labels[i]) 56 | features.append(feature) 57 | 58 | return features 59 | 60 | def calculate_entailment_score(self): 61 | self.model.eval() 62 | entailed_results = [] 63 | with torch.no_grad(): 64 | for i in tqdm(range(len(self.all_data))): 65 | cur_data = self.all_data[i] 66 | cnt = 0 67 | input_examples = [] 68 | for sample in cur_data: 69 | input_examples.append(InputExample(str(cnt), sample[0], sample[1], '0')) 70 | cnt += 1 71 | features = self._convert_examples_to_features( 72 | input_examples, 73 | self.tokenizer, 74 | label_list=['0', '1'], 75 | max_length=128, 76 | output_mode='classification', 77 | ) 78 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(self.device) 79 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(self.device) 80 | dataset = TensorDataset(all_input_ids, all_attention_mask) 81 | train_dataloader = DataLoader(dataset, batch_size=8) 82 | all_logits = None 83 | for batch in train_dataloader: 84 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]} 85 | outputs = self.model(**inputs) 86 | if all_logits is None: 87 | all_logits = outputs[0].detach() 88 | else: 89 | all_logits = torch.cat((all_logits, outputs[0]), dim=0) 90 | results = torch.argmax(all_logits, dim=1) 91 | entailed_results.append(torch.sum(results - 1).item()) 92 | return sum(entailed_results)/len(entailed_results) 93 | -------------------------------------------------------------------------------- /data_manipulation/data_diversification/get_augmented_scores.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import math 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | from sys import argv 7 | 8 | from transformers.modeling_gpt2 import GPT2LMHeadModel 9 | from transformers.tokenization_gpt2 import GPT2Tokenizer 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data import TensorDataset 12 | from transformers import AutoModelForSequenceClassification 13 | from transformers import AutoTokenizer 14 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 15 | from transformers.data.processors.utils import InputExample 16 | 17 | input_file = argv[1] 18 | output_file = argv[2] 19 | BATCH_SIZE = 16 20 | 21 | def calculate_ppls(responses, device): 22 | print('calculate PPL scores...') 23 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2_utterance_model') 24 | model = GPT2LMHeadModel.from_pretrained('./gpt2_utterance_model') 25 | model.to(device) 26 | tokenizer.pad_token = tokenizer.eos_token 27 | ppls = [] 28 | for r in tqdm(responses): 29 | inputs = tokenizer(r, return_tensors='pt').data 30 | for k, v in inputs.items(): 31 | inputs[k] = v.to(device) 32 | loss = model(**inputs, labels=inputs['input_ids'])[0] 33 | ppls.append(math.exp(loss.item())) 34 | return ppls 35 | 36 | def calculate_nli_scores(sentences1, sentences2, model, tokenizer, device): 37 | input_examples = [] 38 | cnt = 0 39 | for i in range(len(sentences1)): 40 | input_examples.append(InputExample(str(cnt), sentences1[i], sentences2[i], '0')) 41 | cnt += 1 42 | features = convert_examples_to_features( 43 | input_examples, 44 | tokenizer, 
45 | label_list=['0', '1', '2'], 46 | max_length=128, 47 | output_mode='classification', 48 | ) 49 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device) 50 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(device) 51 | dataset = TensorDataset(all_input_ids, all_attention_mask) 52 | dataloader = DataLoader(dataset, batch_size=BATCH_SIZE) 53 | model.eval() 54 | all_probs = None 55 | const = torch.tensor([-1, 0, 1], device=device) 56 | with torch.no_grad(): 57 | for batch in tqdm(dataloader): 58 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]} 59 | outputs = model(**inputs) 60 | cur_scores = torch.sum(nn.Softmax(dim=-1)(outputs[0].detach()) * const, dim=-1) 61 | if all_probs is None: 62 | all_probs = cur_scores.cpu() 63 | else: 64 | all_probs = torch.cat((all_probs, cur_scores.cpu()), dim=0) 65 | scores = all_probs.tolist() 66 | return scores 67 | 68 | def calculate_entailment_scores(personas, responses, device): 69 | print('calculate entailment scores...') 70 | tokenizer = AutoTokenizer.from_pretrained('../roberta_mnli') 71 | model = AutoModelForSequenceClassification.from_pretrained('../roberta_mnli') 72 | entailment_scores = calculate_nli_scores(personas, responses, model, tokenizer, device) 73 | return entailment_scores 74 | 75 | def calculate_coherence_scores(history, responses, device): 76 | print('calculate coherence scores...') 77 | tokenizer = AutoTokenizer.from_pretrained('../coherence_nli_model') 78 | model = AutoModelForSequenceClassification.from_pretrained('../coherence_nli_model') 79 | coherence_scores = calculate_nli_scores(history, responses, model, tokenizer, device) 80 | return coherence_scores 81 | 82 | if __name__ == '__main__': 83 | with open(input_file, 'r', encoding='utf-8') as f: 84 | input_data = json.load(f) 85 | responses = [d[2] for d in input_data] 86 | personas = [d[0] for d in input_data] 87 | history = [d[1][-1] for d in input_data] 88 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 89 | ppls = calculate_ppls(responses, device) 90 | entailment_scores = calculate_entailment_scores(personas, responses, device) 91 | coherence_scores = calculate_coherence_scores(history, responses, device) 92 | with open(output_file, 'w') as f: 93 | json.dump({'ppls': ppls, 'entailment_scores': entailment_scores, 'coherence_scores': coherence_scores}, f) -------------------------------------------------------------------------------- /data_manipulation/data_diversification/filter_augmented_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from sys import argv 3 | from metrics import cal_novelty 4 | 5 | PPL_WEIGHT = 0.2 6 | ENTAILMENT_WEIGHT = 0.6 7 | COHERENCE_WEIGHT = 0.2 8 | PPL_NORMALIZER = 50 9 | TH = 0.15 10 | 11 | def cal_new_token(data): 12 | token_set = set() 13 | for d in data: 14 | for t in d[0].split(): 15 | token_set.add(t) 16 | for s in d[1]: 17 | for t in s.split(): 18 | token_set.add(t) 19 | for t in d[2].split(): 20 | token_set.add(t) 21 | return len(token_set) 22 | 23 | def get_novelty(original_data, new_data): 24 | original_sentence_persona, original_sentence_utterances = set(), set() 25 | new_persona, new_utterances = [], [] 26 | for d in original_data: 27 | original_sentence_persona.add(d[0]) 28 | original_sentence_utterances.add(d[2]) 29 | original_sentence_utterances.add(d[1][-1]) 30 | for d in new_data: 31 | new_persona.append(d[0]) 32 | new_utterances.append(d[2]) 
33 | new_utterances.append(d[1][-1]) 34 | n1, n2, n3, n4 = cal_novelty(list(original_sentence_persona), new_persona) 35 | un1, un2, un3, un4 = cal_novelty(list(original_sentence_utterances), new_utterances) 36 | print('111') 37 | 38 | PUNCS = [',', '.', ';', '!', '?', ':'] 39 | TYPE_MAP = {(0, 0, 0): 'M_P_G', (0, 0, 1): 'M_P_R', (0, 1, 0): 'M_O_G', (0, 1, 1): 'M_O_R', 40 | (1, 0, 0): 'P_P_G', (1, 0, 1): 'P_P_R', (1, 1, 0): 'P_O_G', (1, 1, 1): 'P_O_R'} 41 | 42 | 43 | input_prefix = argv[1] 44 | with open(input_prefix + '_with_replace_response.json', 'r') as f: 45 | replace_data = json.load(f) 46 | with open(input_prefix + '_with_replace_response_type.json', 'r') as f: 47 | replace_types = json.load(f) 48 | replace_raw_idx = [d for d in replace_types['raw_data_idx']] 49 | replace_types = [d + [1] for d in replace_types['data_type']] 50 | with open(input_prefix + '_with_generated_response.json', 'r') as f: 51 | generated_data = json.load(f) 52 | with open(input_prefix + '_without_response_type.json', 'r') as f: 53 | generated_types = json.load(f) 54 | generated_raw_idx = [d for d in generated_types['raw_data_idx']] 55 | generated_types = [d + [0] for d in generated_types['data_type']] 56 | with open(input_prefix + '_with_replace_response_scores.json', 'r') as f: 57 | replace_scores = json.load(f) 58 | with open(input_prefix + '_with_generated_response_scores.json', 'r') as f: 59 | generated_scores = json.load(f) 60 | 61 | # generated_scores = {'ppls':[], 'entailment_scores': [], 'coherence_scores': []} 62 | all_data = replace_data[:len(replace_data)] + generated_data[:len(generated_data)] 63 | all_types = replace_types[:len(replace_data)] + generated_types[:len(generated_data)] 64 | all_raw_idx = replace_raw_idx[:len(replace_data)] + generated_raw_idx[:len(generated_data)] 65 | all_scores= {} 66 | for k, v in replace_scores.items(): 67 | all_scores[k] = replace_scores[k] + generated_scores[k] 68 | 69 | weighted_scores = [] 70 | for i in range(len(all_data)): 71 | weighted_scores.append(-all_scores['ppls'][i] / PPL_NORMALIZER * PPL_WEIGHT + 72 | all_scores['entailment_scores'][i] * ENTAILMENT_WEIGHT + 73 | all_scores['coherence_scores'][i] * COHERENCE_WEIGHT) 74 | 75 | res = [] 76 | res_types = [] 77 | raw_idx_map = {} 78 | selected_raw_idx = [] 79 | for i in range(len(weighted_scores)): 80 | if not raw_idx_map.__contains__(all_raw_idx[i]): 81 | raw_idx_map[all_raw_idx[i]] = [[], []] 82 | if weighted_scores[i] > TH: 83 | res.append(all_data[i]) 84 | res_types.append(all_types[i]) 85 | if all_types[i] == [1,0,0]: 86 | raw_idx_map[all_raw_idx[i]][0].append(all_data[i]) 87 | selected_raw_idx.append(all_raw_idx[i]) 88 | else: 89 | raw_idx_map[all_raw_idx[i]][1].append(all_data[i]) 90 | type_cnt = {} 91 | for res_type in res_types: 92 | if not type_cnt.__contains__(TYPE_MAP[tuple(res_type)]): 93 | type_cnt[TYPE_MAP[tuple(res_type)]] = 0 94 | type_cnt[TYPE_MAP[tuple(res_type)]] += 1 95 | print('The new augmented data number is ' + str(len(res))) 96 | with open('base_data/th0.99_model_entail_train_self.json', 'r') as f: 97 | original_data = json.load(f) 98 | # get_novelty(original_data, res) 99 | with open('th0.99_self_augmented_no_persona_filter.json', 'w') as f: 100 | json.dump(original_data + res, f) 101 | print('111') 102 | -------------------------------------------------------------------------------- /data_manipulation/data_distillation/get_distilled_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from sys import
argv 4 | 5 | import numpy as np 6 | 7 | TH = 0.99 8 | 9 | def get_raw_data(original_lines): 10 | raw_data = [] 11 | all_personas = set() 12 | for i, line in enumerate(original_lines): 13 | line = line.strip() 14 | 15 | if len(line) == 0: 16 | continue 17 | 18 | space_idx = line.find(' ') 19 | if space_idx == -1: 20 | dialog_idx = int(line) 21 | else: 22 | dialog_idx = int(line[:space_idx]) 23 | 24 | if int(dialog_idx) == 1: 25 | raw_data.append({'persona': [], 'revised_persona': [], 'dialog': []}) 26 | 27 | dialog_line = line[space_idx + 1:].split('\t') 28 | dialog_line = [l.strip() for l in dialog_line] 29 | 30 | if dialog_line[0].startswith('your persona:'): 31 | persona_info = dialog_line[0].replace('your persona: ', '') 32 | all_personas.add(persona_info[:-1] + ' .') 33 | raw_data[-1]['persona'].append(persona_info[:-1] + ' .') 34 | elif len(dialog_line) > 1: 35 | raw_data[-1]['dialog'].append(dialog_line[0]) 36 | raw_data[-1]['dialog'].append(dialog_line[1]) 37 | return raw_data, list(all_personas) 38 | 39 | # The original train data file 40 | input_file = argv[1] 41 | # The logits given by NLI model obtained before 42 | logits_file = argv[2] 43 | # The output json file that contains all distilled samples that were determined as entailed by the NLI model 44 | output_file = argv[3] 45 | 46 | if '.json' in input_file: 47 | with open(input_file, 'r') as f: 48 | data = json.load(f) 49 | cnt1, cnt2, cnt3 = 0, 0, 0 50 | with open(logits_file, 'rb') as f: 51 | logits = pickle.load(f) 52 | entail = np.argmax(logits, axis=-1) 53 | entail_result = [] 54 | for i, d in enumerate(data): 55 | cur_logit = logits[i] 56 | if entail[i] == 0: 57 | cnt1 += 1 58 | elif entail[i] == 1: 59 | cnt2 += 1 60 | elif entail[i] == 2: 61 | softmax = np.exp(logits[i]) / np.sum(np.exp(logits[i])) 62 | if softmax[2] < TH: 63 | cnt2 += 1 64 | else: 65 | cnt3 += 1 66 | entail_result.append(data[i]) 67 | print(cnt1) 68 | print(cnt2) 69 | print(cnt3) 70 | with open(output_file, 'w') as f: 71 | json.dump(entail_result, f) 72 | else: 73 | with open(input_file, 'r', encoding='utf-8') as f: 74 | lines = f.readlines() 75 | raw_data, all_personas = get_raw_data(lines) 76 | 77 | with open(logits_file, 'rb') as f: 78 | d = pickle.load(f) 79 | cnt1, cnt2, cnt3 = 0, 0, 0 80 | entail_result = [] 81 | for dialog in d: 82 | cur_entail_result = [] 83 | for p in dialog: 84 | res = np.argmax(p, axis=-1) 85 | res = list(res) 86 | for i, r in enumerate(res): 87 | if r == 2: 88 | softmax = np.exp(p[i]) / np.sum(np.exp(p[i])) 89 | if softmax[2] < TH: 90 | res[i] = 1 91 | cur_entail_result.append(list(res)) 92 | entail_result.append(cur_entail_result) 93 | contradict_list = [] 94 | entail_list = [] 95 | neutral_list = [] 96 | entail_sample_idx = [] 97 | sample_idx = 0 98 | for i, dialog in enumerate(entail_result): 99 | for j, p in enumerate(dialog): 100 | for k, u in enumerate(p): 101 | if u == 0: 102 | cnt1 += 1 103 | contradict_list.append((i, j, k)) 104 | if u == 1: 105 | cnt2 += 1 106 | if u == 2: 107 | cnt3 += 1 108 | entail_list.append((i, j, k)) 109 | entail_sample_idx.append(sample_idx + k) 110 | sample_idx += len(p) 111 | print(cnt1) 112 | print(cnt2) 113 | print(cnt3) 114 | 115 | entail_data = [] 116 | for idx in entail_list: 117 | persona = raw_data[idx[0]]['persona'][idx[1]] 118 | response_idx = idx[2] * 2 + 1 119 | response = raw_data[idx[0]]['dialog'][response_idx] 120 | history = raw_data[idx[0]]['dialog'][max(0, response_idx - 3): response_idx] 121 | if len(history) == 0: 122 | history = ['__SILENCE__'] 123 | 
entail_data.append([persona, history, response]) 124 | with open(output_file, 'w', encoding='utf-8') as f: 125 | json.dump(entail_data, f) 126 | with open(output_file + 'raw_idx', 'w', encoding='utf-8') as f: 127 | json.dump(entail_sample_idx, f) 128 | -------------------------------------------------------------------------------- /data_manipulation/data_distillation/calculate_entailment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from sys import argv 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import TensorDataset 8 | from tqdm import tqdm 9 | from transformers import AutoModelForSequenceClassification 10 | from transformers import AutoTokenizer 11 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 12 | from transformers.data.processors.utils import InputExample 13 | 14 | NLI_MODEL_PATH = './persona_nli' 15 | 16 | # The original train file 17 | input_file = argv[1] 18 | # The output file that saves the NLI logits given the train samples 19 | output_file = argv[2] 20 | 21 | def get_dataloader(input_examples, tokenizer, device): 22 | features = convert_examples_to_features( 23 | input_examples, 24 | tokenizer, 25 | label_list=['0', '1'], 26 | max_length=128, 27 | output_mode='classification', 28 | ) 29 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device) 30 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(device) 31 | dataset = TensorDataset(all_input_ids, all_attention_mask) 32 | dataloader = DataLoader(dataset, batch_size=6) 33 | return dataloader 34 | 35 | 36 | def read_txt_data(input_file): 37 | with open(input_file, 'r') as f: 38 | lines = f.readlines() 39 | cur_persona1, cur_persona2 = [], [] 40 | cur_dialogs1, cur_dialogs2 = [], [] 41 | personas, dialogs = [], [] 42 | sentence_pairs = [] 43 | start = True 44 | for line in lines: 45 | if 'your persona:' in line or 'partner\'s persona' in line: 46 | if start and len(cur_persona1) > 0: 47 | personas.append(cur_persona1) 48 | personas.append(cur_persona2) 49 | dialogs.append(cur_dialogs1) 50 | dialogs.append(cur_dialogs2) 51 | cur_persona1, cur_persona2 = [], [] 52 | cur_dialogs1, cur_dialogs2 = [], [] 53 | start = False 54 | if 'your persona:' in line: 55 | persona_index = line.find('your persona:') 56 | persona = line[persona_index + 14: -1] 57 | cur_persona1.append(persona) 58 | elif 'partner\'s persona' in line: 59 | persona_index = line.find('partner\'s persona:') 60 | persona = line[persona_index + 19: -1] 61 | cur_persona2.append(persona) 62 | else: 63 | start = True 64 | space_index = line.find(' ') 65 | sents = line[space_index + 1:].split('\t') 66 | cur_dialogs1.append(sents[1]) 67 | cur_dialogs2.append(sents[0]) 68 | return personas, dialogs 69 | 70 | 71 | def read_json_data(input_file): 72 | with open(input_file, 'r') as f: 73 | data = json.load(f) 74 | examples = [] 75 | cnt = 0 76 | for d in data: 77 | examples.append(InputExample(str(cnt), d[0], d[2], '0')) 78 | cnt += 1 79 | return examples 80 | 81 | tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_PATH) 82 | model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL_PATH) 83 | 84 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 85 | model.to(device) 86 | model.eval() 87 | pred_results = [] 88 | if '.json' in input_file: 89 | all_logits = None 90 | input_examples = 
read_json_data(input_file) 91 | train_dataloader = get_dataloader(input_examples, tokenizer, device) 92 | with torch.no_grad(): 93 | for batch in tqdm(train_dataloader): 94 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]} 95 | outputs = model(**inputs) 96 | if all_logits is None: 97 | all_logits = outputs[0].cpu().detach() 98 | else: 99 | all_logits = torch.cat((all_logits, outputs[0].cpu().detach()), dim=0) 100 | all_logits = all_logits.numpy() 101 | with open(output_file, 'wb') as f: 102 | pickle.dump(all_logits, f) 103 | else: 104 | personas, dialogs = read_txt_data(input_file) 105 | entailed_results = [] 106 | with torch.no_grad(): 107 | for i in tqdm(range(len(personas))): 108 | cur_persona = personas[i] 109 | cur_dialogs = dialogs[i] 110 | cnt = 0 111 | cur_pred_results = [] 112 | for persona in cur_persona: 113 | input_examples = [] 114 | for dialog in cur_dialogs: 115 | input_examples.append(InputExample(str(cnt), persona, dialog, '0')) 116 | cnt += 1 117 | train_dataloader = get_dataloader(input_examples, tokenizer, device) 118 | all_logits = None 119 | for batch in train_dataloader: 120 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]} 121 | outputs = model(**inputs) 122 | if all_logits is None: 123 | all_logits = outputs[0].detach() 124 | else: 125 | all_logits = torch.cat((all_logits, outputs[0].detach()), dim=0) 126 | results = torch.argmax(all_logits, dim=1) 127 | for j, r in enumerate(results): 128 | if r == 2: 129 | entailed_results.append((persona, cur_dialogs[j])) 130 | cur_pred_results.append(all_logits.cpu()) 131 | pred_results.append(cur_pred_results) 132 | with open('entailed_sentences.json', 'w') as f: 133 | json.dump(entailed_results, f) 134 | torch.save(pred_results, 'entailment_scores.bin') 135 | -------------------------------------------------------------------------------- /data_manipulation/prepare_model/train_nli_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 3 | import json 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup 8 | from transformers.data.processors.utils import InputExample 9 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 10 | 11 | EPOCHS = 2 12 | LR = 1e-5 13 | WEIGHT_DECAY = 0.0 14 | WARMUP_RATIO = 0.05 15 | EVAL_INTERVAL = 1000 16 | BATCH_SIZE = 32 17 | MAX_GRAD_NORM = 1.0 18 | 19 | INPUT_MODEL_PATH = './roberta_mnli' 20 | OUTPUT_MODEL_FILE = 'best_model.bin' 21 | 22 | def get_input_examples(data): 23 | input_examples = [] 24 | label_dict = {'negative': '0', 'neutral': '1', 'positive': '2'} 25 | for d in data: 26 | input_examples.append(InputExample(d['id'], d['sentence1'], d['sentence2'], label_dict[d['label']])) 27 | return input_examples 28 | 29 | def eval_model(model, dev_dataloader, prev_best, step): 30 | dev_tqdm_data = tqdm(dev_dataloader, desc='Evaluation (step #{})'.format(step)) 31 | eval_loss = 0 32 | model.eval() 33 | preds, out_label_ids = None, None 34 | eval_step = 0 35 | with torch.no_grad(): 36 | for batch in dev_tqdm_data: 37 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]} 38 | outputs = model(**inputs) 39 | tmp_eval_loss, logits = outputs[:2] 40 | eval_step += 1 41 | eval_loss += tmp_eval_loss.mean().item() 42 | dev_tqdm_data.set_postfix({'loss': 
eval_loss / eval_step}) 43 | if preds is None: 44 | preds = logits.detach().cpu().numpy() 45 | out_label_ids = inputs["labels"].detach().cpu().numpy() 46 | else: 47 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 48 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 49 | preds = np.argmax(preds, axis=1) 50 | accuracy = (preds == out_label_ids).astype(np.float32).mean().item() 51 | if accuracy > prev_best: 52 | print('Current model BEATS the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy)) 53 | torch.save(model.state_dict(), OUTPUT_MODEL_FILE) 54 | prev_best = accuracy 55 | else: 56 | print('Current model CANNOT BEAT the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy)) 57 | return prev_best 58 | 59 | with open('dialogue_nli_dataset/dialogue_nli_train.jsonl', 'r') as f: 60 | train_data = json.load(f) 61 | with open('dialogue_nli_dataset/dialogue_nli_dev.jsonl', 'r') as f: 62 | dev_data = json.load(f) 63 | train_examples = get_input_examples(train_data) 64 | dev_examples = get_input_examples(dev_data) 65 | 66 | tokenizer = AutoTokenizer.from_pretrained(INPUT_MODEL_PATH) 67 | model = AutoModelForSequenceClassification.from_pretrained(INPUT_MODEL_PATH) 68 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 69 | model.to(device) 70 | if torch.cuda.device_count() > 1: 71 | device = torch.device('cuda:0') 72 | model = model.to(device) 73 | model = torch.nn.parallel.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) 74 | 75 | train_features = convert_examples_to_features( 76 | train_examples, 77 | tokenizer, 78 | label_list=['0', '1', '2'], 79 | max_length=128, 80 | output_mode='classification', 81 | ) 82 | dev_features = convert_examples_to_features( 83 | dev_examples, 84 | tokenizer, 85 | label_list=['0', '1', '2'], 86 | max_length=128, 87 | output_mode='classification', 88 | ) 89 | train_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).to(device) 90 | train_attention_mask = torch.tensor([f.attention_mask for f in train_features], dtype=torch.long).to(device) 91 | train_labels = torch.tensor([f.label for f in train_features], dtype=torch.long).to(device) 92 | train_dataset = TensorDataset(train_input_ids, train_attention_mask, train_labels) 93 | train_sampler = RandomSampler(train_dataset) 94 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler) 95 | dev_input_ids = torch.tensor([f.input_ids for f in dev_features], dtype=torch.long).to(device) 96 | dev_attention_mask = torch.tensor([f.attention_mask for f in dev_features], dtype=torch.long).to(device) 97 | dev_labels = torch.tensor([f.label for f in dev_features], dtype=torch.long).to(device) 98 | dev_dataset = TensorDataset(dev_input_ids, dev_attention_mask, dev_labels) 99 | dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE) 100 | 101 | #eval_model(model, dev_dataloader, 0, 0) 102 | 103 | t_total = len(train_dataloader) * EPOCHS 104 | no_decay = ["bias", "LayerNorm.weight"] 105 | optimizer_grouped_parameters = [ 106 | { 107 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 108 | "weight_decay": WEIGHT_DECAY, 109 | }, 110 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 111 | ] 112 | optimizer = AdamW(optimizer_grouped_parameters, lr=LR, eps=1e-8) 113 | 
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_RATIO, num_training_steps=t_total) 114 | 115 | prev_best = 0 116 | for epoch in range(EPOCHS): 117 | total_loss = 0.0 118 | tqdm_data = tqdm(train_dataloader, desc='Train (epoch #{})'.format(epoch + 1)) 119 | step = 0 120 | for batch in tqdm_data: 121 | model.train() 122 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]} 123 | outputs = model(**inputs) 124 | loss = outputs[0] 125 | loss = loss.mean() 126 | loss.backward() 127 | total_loss += loss.item() 128 | step += 1 129 | tqdm_data.set_postfix({'loss': total_loss / step}) 130 | torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM) 131 | optimizer.step() 132 | scheduler.step() 133 | optimizer.zero_grad() 134 | if step % EVAL_INTERVAL == 0: 135 | prev_best = eval_model(model, dev_dataloader, prev_best, step) 136 | prev_best = eval_model(model, dev_dataloader, prev_best, step) 137 | 138 | -------------------------------------------------------------------------------- /metrics/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." 
= abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | print "$length \n"; 82 | for(my $n=1;$n<=4;$n++) { 83 | my %REF_NGRAM_N = (); 84 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 85 | my $ngram = "$n"; 86 | for(my $w=0;$w<$n;$w++) { 87 | $ngram .= " ".$WORD[$start+$w]; 88 | } 89 | $REF_NGRAM_N{$ngram}++; 90 | } 91 | foreach my $ngram (keys %REF_NGRAM_N) { 92 | if (!defined($REF_NGRAM{$ngram}) || 93 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 94 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 95 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 96 | } 97 | } 98 | } 99 | } 100 | $length_translation += $length_translation_this_sentence; 101 | $length_reference += $closest_length; 102 | for(my $n=1;$n<=4;$n++) { 103 | my %T_NGRAM = (); 104 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 105 | my $ngram = "$n"; 106 | for(my $w=0;$w<$n;$w++) { 107 | $ngram .= " ".$WORD[$start+$w]; 108 | } 109 | $T_NGRAM{$ngram}++; 110 | } 111 | foreach my $ngram (keys %T_NGRAM) { 112 | $ngram =~ /^(\d+) /; 113 | # my $n = $1; 114 | # my $corr = 0; 115 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 116 | $TOTAL[$n] += $T_NGRAM{$ngram}; 117 | if (defined($REF_NGRAM{$ngram})) { 118 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 119 | $CORRECT[$n] += $T_NGRAM{$ngram}; 120 | # $corr = $T_NGRAM{$ngram}; 121 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 122 | } 123 | else { 124 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 125 | # $corr = $REF_NGRAM{$ngram}; 126 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 127 | } 128 | } 129 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 130 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 131 | } 132 | } 133 | $s++; 134 | } 135 | my $brevity_penalty = 1; 136 | my $bleu = 0; 137 | 138 | my @bleu=(); 139 | 140 | for(my $n=1;$n<=4;$n++) { 141 | if (defined ($TOTAL[$n])){ 142 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 143 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 144 | }else{ 145 | $bleu[$n]=0; 146 | } 147 | } 148 | 149 | if ($length_reference==0){ 150 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 151 | exit(1); 152 | } 153 | 154 | if ($length_translation<$length_reference) { 155 | $brevity_penalty = exp(1-$length_reference/$length_translation); 156 | } 157 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 158 | my_log( $bleu[2] ) + 159 | my_log( $bleu[3] ) + 160 | my_log( $bleu[4] ) ) / 4) ; 161 | printf "BLEU = %.5f, %.5f/%.5f/%.5f/%.5f (BP=%.5f, ratio=%.5f, hyp_len=%d, ref_len=%d)\n", 162 | 100*$bleu, 163 | 100*$bleu[1], 164 | 100*$bleu[2], 165 | 100*$bleu[3], 166 | 100*$bleu[4], 167 | $brevity_penalty, 168 | $length_translation / $length_reference, 169 | $length_translation, 170 | $length_reference; 171 | 172 | 173 | print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 174 | 175 | sub my_log { 176 | return -9999999999 unless $_[0]; 177 | return log($_[0]); 178 | } -------------------------------------------------------------------------------- /BT/recover.py: -------------------------------------------------------------------------------- 1 | import json 2 | from sys import argv 3 | 4 | from nltk.translate.bleu_score import sentence_bleu 5 | from tqdm import tqdm 6 | 7 | bt_file = argv[1] 8 | original_file = argv[2] 9 | index_file = argv[3] 10 | output_file = argv[4] 11 | 12 | SYMBOLS = ['.', ',', '!', '?', ';', '"'] 13 | SPECIAL1 = '’' 14 | 15 | def find_the_most_different_replace(original_sentence, candidates): 16 | min_bleu = 100 17 | res = candidates[0] 18 | for c in candidates: 19 | bleu = sentence_bleu(original_sentence.lower(), c.lower()) 20 | if bleu < min_bleu and abs(len(original_sentence) - len(c)) < len(original_sentence) * 0.4: 21 | res = c.lower() 22 | min_bleu = bleu 23 | return res 24 | 25 | def clean_data_core(line): 26 | line_list = [] 27 | for c in line: 28 | if c in SYMBOLS or c == SPECIAL1: 29 | if len(line_list) > 0 and line_list[-1] != ' ': 30 | line_list.append(' ') 31 | if c == SPECIAL1: 32 | c = '\'' 33 | line_list.append(c) 34 | new_sentence = ''.join(line_list) 35 | end_idx = len(new_sentence) 36 | for i in range(len(new_sentence)): 37 | cur_idx = len(new_sentence) - 1 - i 38 | if new_sentence[cur_idx] not in SYMBOLS and new_sentence[cur_idx] != ' ': 39 | break 40 | else: 41 | if new_sentence[cur_idx] in SYMBOLS: 42 | end_idx = cur_idx + 1 43 | new_line = new_sentence[:end_idx] 44 | return new_line 45 | 46 | def recover_persona_end(cleaned_persona): 47 | if len(cleaned_persona) >= 2 and cleaned_persona[-1] in SYMBOLS and cleaned_persona[-2] == ' ': 48 | j = len(cleaned_persona) - 2 49 | while cleaned_persona[j] == ' ': 50 | j -= 1 51 | cleaned_persona = cleaned_persona[:j + 1] + 
'.' 52 | elif len(cleaned_persona) > 0 and cleaned_persona[-1] not in SYMBOLS and cleaned_persona[-1] != ' ': 53 | cleaned_persona = cleaned_persona + '.' 54 | return cleaned_persona 55 | 56 | def clean_data(data, is_json=False): 57 | if is_json: 58 | for i, sample in enumerate(data): 59 | persona, history, response = sample[0], sample[1], sample[2] 60 | cleaned_persona = clean_data_core(persona) 61 | cleaned_persona = recover_persona_end(cleaned_persona) 62 | for j, h in enumerate(history): 63 | if h != '__SILENCE__': 64 | history[j] = clean_data_core(h) 65 | cleaned_response = clean_data_core(response) 66 | data[i] = [cleaned_persona, history, cleaned_response] 67 | else: 68 | for i, line in enumerate(data): 69 | line = line.lower().strip() 70 | if 'your persona: ' in line: 71 | cleaned_persona = clean_data_core(line) 72 | if cleaned_persona[-1] in SYMBOLS and cleaned_persona[-2] == ' ': 73 | j = len(cleaned_persona) - 2 74 | while cleaned_persona[j] == ' ': 75 | j -= 1 76 | cleaned_persona = cleaned_persona[:j + 1] + '.' 77 | elif cleaned_persona[-1] not in SYMBOLS and cleaned_persona[-1] != ' ': 78 | cleaned_persona = cleaned_persona + '.' 79 | data[i] = cleaned_persona + '\n' 80 | else: 81 | items = line.split('\t') 82 | for j, s in enumerate(items[:2]): 83 | items[j] = clean_data_core(s) 84 | data[i] = '\t'.join(items) + '\n' 85 | return data 86 | 87 | beam_size = 25 88 | bt_sentences = [] 89 | with open(bt_file, 'r', encoding='utf-8') as f: 90 | lines = f.readlines() 91 | i = 0 92 | while i < len(lines): 93 | cur_sentences = [s.strip() for s in lines[i: i + beam_size]] 94 | i += beam_size 95 | bt_sentences.append(cur_sentences) 96 | with open(index_file, 'r') as f: 97 | indices = json.load(f) 98 | if '.txt' in original_file: 99 | with open(original_file, 'r') as f: 100 | original_lines = f.readlines() 101 | prev_line_idx = -1 102 | for i, line_idx in enumerate(tqdm(indices)): 103 | cur_bt_sentences = bt_sentences[i] 104 | line = original_lines[line_idx].strip() 105 | if 'your persona: ' in line: 106 | start_index = line.find('your persona: ') 107 | original_sentence = line[start_index + 14:] 108 | replace_start, replace_end = start_index + 14, len(line) 109 | else: 110 | space_index = line.find(' ') 111 | items = line[space_index + 1:].split('\t') 112 | if line_idx != prev_line_idx: 113 | original_sentence = items[0] 114 | replace_start, replace_end = space_index + 1, space_index + len(original_sentence) 115 | else: 116 | original_sentence = items[1] 117 | t_index = line.find('\t') 118 | replace_start, replace_end = t_index + 1, t_index + len(original_sentence) + 1 119 | replace_sentence = find_the_most_different_replace(original_sentence, cur_bt_sentences) 120 | original_lines[line_idx] = line[:replace_start] + replace_sentence + line[replace_end:] + '\n' 121 | prev_line_idx = line_idx 122 | cleaned_lines = clean_data(original_lines) 123 | with open(output_file, 'w', encoding='utf-8') as f: 124 | f.writelines(original_lines) 125 | if '.json' in original_file: 126 | with open(original_file, 'r', encoding='utf-8') as f: 127 | original_data = json.load(f) 128 | for i, data_index in enumerate(tqdm(indices)): 129 | cur_bt_sentences = bt_sentences[i] 130 | sample_index = data_index[0] 131 | original_sample = original_data[sample_index] 132 | if data_index[1] < 3: 133 | original_sentence = original_sample[data_index[1]] 134 | replace_sentence = find_the_most_different_replace(original_sentence, cur_bt_sentences) 135 | original_sample[data_index[1]] = replace_sentence 136 | else: 137 | 
original_sentence = original_sample[1][data_index[1] - 3] 138 | if original_sentence != '__SILENCE__': 139 | replace_sentence = find_the_most_different_replace(original_sentence, cur_bt_sentences) 140 | original_sample[1][data_index[1] - 3] = replace_sentence 141 | original_data[sample_index] = original_sample 142 | cleaned_data = clean_data(original_data, is_json=True) 143 | with open(output_file, 'w', encoding='utf-8') as f: 144 | json.dump(original_data, f) 145 | -------------------------------------------------------------------------------- /data_manipulation/prepare_model/train_coherence_nli.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 3 | import json 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup 8 | from transformers.data.processors.utils import InputExample 9 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 10 | 11 | EPOCHS = 3 12 | LR = 1e-5 13 | WEIGHT_DECAY = 0.0 14 | WARMUP_RATIO = 0.05 15 | EVAL_INTERVAL = 2000 16 | BATCH_SIZE = 96 17 | MAX_GRAD_NORM = 1.0 18 | 19 | INPUT_MODEL_PATH = './roberta_mnli' 20 | OUTPUT_MODEL_FILE = 'best_model.bin' 21 | 22 | def get_input_examples(data): 23 | input_examples = [] 24 | label_dict = {'contradiction': '0', 'neutral': '1', 'entailment': '2'} 25 | for d in data: 26 | input_examples.append(InputExample(d['pairID'], d['sentence1'], d['sentence2'], label_dict[d['gold_label']])) 27 | return input_examples 28 | 29 | def eval_model(model, dev_dataloader, prev_best, step): 30 | dev_tqdm_data = tqdm(dev_dataloader, desc='Evaluation (step #{})'.format(step)) 31 | eval_loss = 0 32 | model.eval() 33 | preds, out_label_ids = None, None 34 | eval_step = 0 35 | with torch.no_grad(): 36 | for batch in dev_tqdm_data: 37 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]} 38 | outputs = model(**inputs) 39 | tmp_eval_loss, logits = outputs[:2] 40 | eval_step += 1 41 | eval_loss += tmp_eval_loss.mean().item() 42 | dev_tqdm_data.set_postfix({'loss': eval_loss / eval_step}) 43 | if preds is None: 44 | preds = logits.detach().cpu().numpy() 45 | out_label_ids = inputs["labels"].detach().cpu().numpy() 46 | else: 47 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 48 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 49 | preds = np.argmax(preds, axis=1) 50 | accuracy = (preds == out_label_ids).astype(np.float32).mean().item() 51 | if accuracy > prev_best: 52 | print('Current model BEATS the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy)) 53 | torch.save(model.state_dict(), OUTPUT_MODEL_FILE) 54 | prev_best = accuracy 55 | else: 56 | print('Current model CANNOT BEAT the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy)) 57 | return prev_best 58 | 59 | with open('convai_nli_valid.jsonl', 'r', encoding='utf-8') as f: 60 | lines = f.readlines() 61 | train_data = [] 62 | for line in lines: 63 | train_data.append(json.loads(line.strip())) 64 | with open('convai_nli_valid.jsonl', 'r', encoding='utf-8') as f: 65 | lines = f.readlines() 66 | dev_data = [] 67 | for line in lines: 68 | dev_data.append(json.loads(line.strip())) 69 | train_examples = get_input_examples(train_data) 70 | 
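# Note: both train_data and dev_data above are read from convai_nli_valid.jsonl; presumably the
# training split (e.g. convai_nli_train.jsonl) is intended for train_data.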
dev_examples = get_input_examples(dev_data) 71 | 72 | tokenizer = AutoTokenizer.from_pretrained(INPUT_MODEL_PATH) 73 | model = AutoModelForSequenceClassification.from_pretrained(INPUT_MODEL_PATH) 74 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 75 | model.to(device) 76 | if torch.cuda.device_count() > 1: 77 | device = torch.device('cuda:0') 78 | model = model.to(device) 79 | model = torch.nn.parallel.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) 80 | 81 | train_features = convert_examples_to_features( 82 | train_examples, 83 | tokenizer, 84 | label_list=['0', '1', '2'], 85 | max_length=128, 86 | output_mode='classification', 87 | ) 88 | dev_features = convert_examples_to_features( 89 | dev_examples, 90 | tokenizer, 91 | label_list=['0', '1', '2'], 92 | max_length=128, 93 | output_mode='classification', 94 | ) 95 | train_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).to(device) 96 | train_attention_mask = torch.tensor([f.attention_mask for f in train_features], dtype=torch.long).to(device) 97 | train_labels = torch.tensor([f.label for f in train_features], dtype=torch.long).to(device) 98 | train_dataset = TensorDataset(train_input_ids, train_attention_mask, train_labels) 99 | train_sampler = RandomSampler(train_dataset) 100 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler) 101 | dev_input_ids = torch.tensor([f.input_ids for f in dev_features], dtype=torch.long).to(device) 102 | dev_attention_mask = torch.tensor([f.attention_mask for f in dev_features], dtype=torch.long).to(device) 103 | dev_labels = torch.tensor([f.label for f in dev_features], dtype=torch.long).to(device) 104 | dev_dataset = TensorDataset(dev_input_ids, dev_attention_mask, dev_labels) 105 | dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE) 106 | 107 | t_total = len(train_dataloader) * EPOCHS 108 | no_decay = ["bias", "LayerNorm.weight"] 109 | optimizer_grouped_parameters = [ 110 | { 111 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 112 | "weight_decay": WEIGHT_DECAY, 113 | }, 114 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 115 | ] 116 | optimizer = AdamW(optimizer_grouped_parameters, lr=LR, eps=1e-8) 117 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_RATIO, num_training_steps=t_total) 118 | 119 | prev_best = 0 120 | for epoch in range(EPOCHS): 121 | total_loss = 0.0 122 | tqdm_data = tqdm(train_dataloader, desc='Train (epoch #{})'.format(epoch + 1)) 123 | step = 0 124 | prev_best = eval_model(model, dev_dataloader, prev_best, step) 125 | for batch in tqdm_data: 126 | model.train() 127 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]} 128 | outputs = model(**inputs) 129 | loss = outputs[0] 130 | loss = loss.mean() 131 | loss.backward() 132 | total_loss += loss.item() 133 | step += 1 134 | tqdm_data.set_postfix({'loss': total_loss / step}) 135 | torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM) 136 | optimizer.step() 137 | scheduler.step() 138 | optimizer.zero_grad() 139 | if step % EVAL_INTERVAL == 0: 140 | prev_best = eval_model(model, dev_dataloader, prev_best, step) 141 | prev_best = eval_model(model, dev_dataloader, prev_best, step) 142 | -------------------------------------------------------------------------------- /metrics/dstc_example.py: 
-------------------------------------------------------------------------------- 1 | # author: Xiang Gao @ Microsoft Research, Oct 2018 2 | # evaluate DSTC-task2 submissions. https://github.com/DSTC-MSR-NLP/DSTC7-End-to-End-Conversation-Modeling 3 | 4 | from util import * 5 | from metrics import * 6 | from tokenizers import * 7 | 8 | def extract_cells(path_in, path_hash): 9 | keys = [line.strip('\n') for line in open(path_hash)] 10 | cells = dict() 11 | for line in open(path_in, encoding='utf-8'): 12 | c = line.strip('\n').split('\t') 13 | k = c[0] 14 | if k in keys: 15 | cells[k] = c[1:] 16 | return cells 17 | 18 | 19 | def extract_hyp_refs(raw_hyp, raw_ref, path_hash, fld_out, n_refs=6, clean=False, vshuman=-1): 20 | cells_hyp = extract_cells(raw_hyp, path_hash) 21 | cells_ref = extract_cells(raw_ref, path_hash) 22 | if not os.path.exists(fld_out): 23 | os.makedirs(fld_out) 24 | 25 | def _clean(s): 26 | if clean: 27 | return clean_str(s) 28 | else: 29 | return s 30 | 31 | keys = sorted(cells_hyp.keys()) 32 | with open(fld_out + '/hash.txt', 'w', encoding='utf-8') as f: 33 | f.write(unicode('\n'.join(keys))) 34 | 35 | lines = [_clean(cells_hyp[k][-1]) for k in keys] 36 | path_hyp = fld_out + '/hyp.txt' 37 | with open(path_hyp, 'w', encoding='utf-8') as f: 38 | f.write(unicode('\n'.join(lines))) 39 | 40 | lines = [] 41 | for _ in range(n_refs): 42 | lines.append([]) 43 | for k in keys: 44 | refs = cells_ref[k] 45 | for i in range(n_refs): 46 | idx = i % len(refs) 47 | if idx == vshuman: 48 | idx = (idx + 1) % len(refs) 49 | lines[i].append(_clean(refs[idx].split('|')[1])) 50 | 51 | path_refs = [] 52 | for i in range(n_refs): 53 | path_ref = fld_out + '/ref%i.txt'%i 54 | with open(path_ref, 'w', encoding='utf-8') as f: 55 | f.write(unicode('\n'.join(lines[i]))) 56 | path_refs.append(path_ref) 57 | 58 | return path_hyp, path_refs 59 | 60 | 61 | def eval_one_system(submitted, keys, multi_ref, n_refs=6, n_lines=None, clean=False, vshuman=-1, PRINT=True): 62 | 63 | print('evaluating %s' % submitted) 64 | 65 | fld_out = submitted.replace('.txt','') 66 | if clean: 67 | fld_out += '_cleaned' 68 | path_hyp, path_refs = extract_hyp_refs(submitted, multi_ref, keys, fld_out, n_refs, clean=clean, vshuman=vshuman) 69 | nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(path_refs, path_hyp, fld_out, n_lines=n_lines) 70 | 71 | if n_lines is None: 72 | n_lines = len(open(path_hyp, encoding='utf-8').readlines()) 73 | 74 | if PRINT: 75 | print('n_lines = '+str(n_lines)) 76 | print('NIST = '+str(nist)) 77 | print('BLEU = '+str(bleu)) 78 | print('METEOR = '+str(meteor)) 79 | print('entropy = '+str(entropy)) 80 | print('diversity = ' + str(div)) 81 | print('avg_len = '+str(avg_len)) 82 | 83 | return [n_lines] + nist + bleu + [meteor] + entropy + div + [avg_len] 84 | 85 | 86 | def eval_all_systems(files, path_report, keys, multi_ref, n_refs=6, n_lines=None, clean=False, vshuman=False): 87 | # evaluate all systems (*.txt) in each folder `files` 88 | 89 | with open(path_report, 'w') as f: 90 | f.write('\t'.join( 91 | ['fname', 'n_lines'] + \ 92 | ['nist%i'%i for i in range(1, 4+1)] + \ 93 | ['bleu%i'%i for i in range(1, 4+1)] + \ 94 | ['meteor'] + \ 95 | ['entropy%i'%i for i in range(1, 4+1)] +\ 96 | ['div1','div2','avg_len'] 97 | ) + '\n') 98 | 99 | for fl in files: 100 | if fl.endswith('.txt'): 101 | submitted = fl 102 | results = eval_one_system(submitted, keys=keys, multi_ref=multi_ref, n_refs=n_refs, clean=clean, n_lines=n_lines, vshuman=vshuman, PRINT=False) 103 | with open(path_report, 'a') as f: 104 | 
f.write('\t'.join(map(str, [submitted] + results)) + '\n') 105 | else: 106 | for fname in os.listdir(fl): 107 | if fname.endswith('.txt'): 108 | submitted = fl + '/' + fname 109 | results = eval_one_system(submitted, keys=keys, multi_ref=multi_ref, n_refs=n_refs, clean=clean, n_lines=n_lines, vshuman=vshuman, PRINT=False) 110 | with open(path_report, 'a') as f: 111 | f.write('\t'.join(map(str, [submitted] + results)) + '\n') 112 | 113 | print('report saved to: '+path_report, file=sys.stderr) 114 | 115 | 116 | if __name__ == '__main__': 117 | 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('submitted') # if 'all' or '*', eval all teams listed in dstc/teams.txt 120 | # elif endswith '.txt', eval this single file 121 | # else, eval all *.txt in folder `submitted_fld` 122 | 123 | parser.add_argument('--clean', '-c', action='store_true') # whether to clean ref and hyp before eval 124 | parser.add_argument('--n_lines', '-n', type=int, default=-1) # eval all lines (default) or top n_lines (e.g., for fast debugging) 125 | parser.add_argument('--n_refs', '-r', type=int, default=6) # number of references 126 | parser.add_argument('--vshuman', '-v', type=int, default='1') # when evaluating against human performance (N in refN.txt that should be removed) 127 | # in which case we need to remove human output from refs 128 | parser.add_argument('--refs', '-g', default='dstc/test.refs') 129 | parser.add_argument('--keys', '-k', default='keys/test.2k.txt') 130 | parser.add_argument('--teams', '-i', type=str, default='dstc/teams.txt') 131 | parser.add_argument('--report', '-o', type=str, default=None) 132 | args = parser.parse_args() 133 | print('Args: %s\n' % str(args), file=sys.stderr) 134 | 135 | if args.n_lines < 0: 136 | n_lines = None # eval all lines 137 | else: 138 | n_lines = args.n_lines # just eval top n_lines 139 | 140 | if args.submitted.endswith('.txt'): 141 | eval_one_system(args.submitted, keys=args.keys, multi_ref=args.refs, clean=args.clean, n_lines=n_lines, n_refs=args.n_refs, vshuman=args.vshuman) 142 | else: 143 | fname_report = 'report_ref%i'%args.n_refs 144 | if args.clean: 145 | fname_report += '_cleaned' 146 | fname_report += '.tsv' 147 | if args.submitted == 'all' or args.submitted == '*': 148 | files = ['dstc/' + line.strip('\n') for line in open(args.teams)] 149 | path_report = 'dstc/' + fname_report 150 | else: 151 | files = [args.submitted] 152 | path_report = args.submitted + '/' + fname_report 153 | if args.report != None: 154 | path_report = args.report 155 | eval_all_systems(files, path_report, keys=args.keys, multi_ref=args.refs, clean=args.clean, n_lines=n_lines, n_refs=args.n_refs, vshuman=args.vshuman) 156 | -------------------------------------------------------------------------------- /attention_experiment/get_target_samples.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | from transformers.tokenization_gpt2 import GPT2Tokenizer 5 | 6 | from itertools import chain 7 | 8 | IGNORE_TOKENS = ['i', 'my', 'he', 'she', '.', 'am', 'was', 'is', 'are', 'have', 'has', 'had'] 9 | 10 | def read_txt_data(input_file): 11 | with open(input_file, 'r', encoding='utf-8') as f: 12 | lines = f.readlines() 13 | data = [] 14 | for line in lines: 15 | line = line.strip() 16 | if len(line) == 0: 17 | continue 18 | 19 | space_idx = line.find(' ') 20 | if space_idx == -1: 21 | dialog_idx = int(line) 22 | else: 23 | dialog_idx = int(line[:space_idx]) 24 | 25 | if int(dialog_idx) == 1: 26 | 
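# a leading dialogue index of 1 marks the start of a new dialogue in the ConvAI2 txt format,
# so a fresh record is opened here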
data.append({'persona_info': [], 'dialog': []}) 27 | 28 | dialog_line = line[space_idx + 1:].split('\t') 29 | dialog_line = [l.strip() for l in dialog_line] 30 | 31 | if dialog_line[0].startswith('your persona:'): 32 | persona_info = dialog_line[0].replace('your persona: ', '') 33 | if persona_info[-1] == '.' and persona_info[-2] != ' ': 34 | persona_info = persona_info[:-1] + ' .' 35 | data[-1]['persona_info'].append(persona_info) 36 | elif len(dialog_line) > 1: 37 | data[-1]['dialog'].append(dialog_line[0]) 38 | data[-1]['dialog'].append(dialog_line[1]) 39 | return data 40 | 41 | def get_dataset_and_history_positions(data): 42 | dataset = [] 43 | positions = [] 44 | for chat in data: 45 | persona_info = [s.split() for s in chat['persona_info']] 46 | dialog = [] 47 | for i, replica in enumerate(chat['dialog'], 1): 48 | dialog.append(replica.split()) 49 | if not i % 2: 50 | dataset.append((persona_info, dialog[:], [])) 51 | persona_len = [len(x) for x in persona_info] 52 | dialog_len = [len(x) for x in dialog] 53 | persona_pos, history_pos = [], [] 54 | p = 1 55 | for l in persona_len: 56 | p = p + l 57 | persona_pos.append(p) 58 | for j in range(max(len(dialog_len) - 6, 0), len(dialog_len) - 1): 59 | p = p + 1 + dialog_len[j] 60 | history_pos.append(p) 61 | positions.append([persona_pos, history_pos]) 62 | for i, data in enumerate(dataset): 63 | dataset[i] = [[' '.join(p) for p in data[0]], [' '.join(u) for u in data[1]]] 64 | return dataset, positions 65 | 66 | def _get_entail_index_matched_token_positions(entail_data, tokenizer): 67 | all_attention_positions = [] 68 | matched_token_positions = [] 69 | for i in range(len(entail_data)): 70 | entail_sample = entail_data[i] 71 | raw_idx = raw_sample_idx[i] 72 | if tokenizer is None: 73 | persona = entail_sample[0].split() 74 | response = entail_sample[2].split() 75 | else: 76 | persona = [t[1:] if t[0] == 'Ġ' else t for t in tokenizer.tokenize(entail_sample[0])] 77 | response = [t[1:] if t[0] == 'Ġ' else t for t in tokenizer.tokenize(entail_sample[2])] 78 | target_positions = [] 79 | for i in range(len(persona)): 80 | for j in range(len(response)): 81 | if persona[i] not in IGNORE_TOKENS and persona[i] == response[j]: 82 | target_positions.append([i, j]) 83 | all_attention_positions.append([raw_idx, entail_sample[0]]) 84 | matched_token_positions.append(target_positions) 85 | return all_attention_positions, matched_token_positions 86 | 87 | def get_dataset_and_persona_positions(raw_data, entail_data, tokenizer): 88 | all_attention_positions, all_matched_token_positions = _get_entail_index_matched_token_positions(entail_data, tokenizer) 89 | raw_sample_idx_set = set([s[0] for s in all_attention_positions]) 90 | all_dataset = [] 91 | index = 0 92 | for chat in raw_data: 93 | persona_info = [s for s in chat['persona_info']] 94 | dialog = [] 95 | for i, replica in enumerate(chat['dialog'], 1): 96 | dialog.append(replica) 97 | if not i % 2: 98 | all_dataset.append((persona_info, dialog[:], [])) 99 | new_dataset = [all_dataset[i] for i in [x[0] for x in all_attention_positions]] 100 | 101 | all_target_persona_sentence_positions = [] 102 | all_persona_positions = [] 103 | for i in range(len(all_attention_positions)): 104 | target_persona = all_attention_positions[i][1] 105 | cur_data = new_dataset[i] 106 | target_persona_text_index = 0 107 | while target_persona != cur_data[0][target_persona_text_index]: 108 | target_persona_text_index += 1 109 | if tokenizer is None: 110 | tokenized_personas = [p.split() for p in cur_data[0]] 111 | else: 112 | 
tokenized_personas = [tokenizer.tokenize(p) for p in cur_data[0]] 113 | target_persona_start_token_index = 1 + sum([len(x) for x in tokenized_personas[:target_persona_text_index]]) 114 | target_persona_end_token_index = target_persona_start_token_index + \ 115 | len(tokenized_personas[target_persona_text_index]) 116 | all_target_persona_sentence_positions.append([target_persona_start_token_index, target_persona_end_token_index]) 117 | all_persona_length = 1 + sum([len(x) for x in tokenized_personas]) 118 | all_persona_positions.append([1, all_persona_length]) 119 | if len(all_matched_token_positions[i]) > 0: 120 | for j, positions in enumerate(all_matched_token_positions[i]): 121 | all_matched_token_positions[i][j][0] = all_matched_token_positions[i][j][0] + \ 122 | target_persona_start_token_index 123 | return new_dataset, all_target_persona_sentence_positions, all_persona_positions, all_matched_token_positions 124 | 125 | with open('base_data/th0.99_dev_self.json', 'r') as f: 126 | entail_data = json.load(f) 127 | with open('base_data/th0.99_dev_raw_sample_idx.json', 'r') as f: 128 | raw_sample_idx = json.load(f) 129 | data = read_txt_data('../datasets/ConvAI2/valid_self_original.txt') 130 | gpt2_special_sumbol = 'Ġ' 131 | 132 | MODE = 'persona' 133 | MODEL = 'gpt2' 134 | 135 | if MODE == 'persona': 136 | tokenizer = None 137 | if MODEL == 'gpt2': 138 | tokenizer = GPT2Tokenizer.from_pretrained('../gpt2-small') 139 | new_dataset, target_persona_sentence_position, all_persona_positions, matched_token_positions = \ 140 | get_dataset_and_persona_positions(data, entail_data, tokenizer) 141 | with open('th0.99_consistent_dataset.json', 'w') as f: 142 | json.dump(new_dataset, f) 143 | with open('th0.99_consistent_positions.json', 'w') as f: 144 | json.dump({'token_positions': matched_token_positions, 145 | 'target_persona_positions': target_persona_sentence_position, 146 | 'whole_persona_positions': all_persona_positions}, f) 147 | else: 148 | data, positions = get_dataset_and_history_positions(data) 149 | with open('attention_dev_data.json', 'w', encoding='utf-8') as f: 150 | json.dump(data, f) 151 | with open('attention_dev_position.json', 'w', encoding='utf-8') as f: 152 | json.dump(positions, f) 153 | 154 | print('111') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # D3 2 | ### The implementation for ACL 2022 paper 3 | 4 | ### A Model-Agnostic Data Manipulation Method for Persona-based Dialogue Generation 5 | 6 | ![Framework](https://github.com/caoyu-noob/D3/blob/main/framework.PNG) 7 | 8 | --- 9 | 10 | ### File structure 11 | 1. The main entrance to train the model is in `train.py` in the root directory. We also provide some example shells 12 | for running under different conditions. 13 | 2. The code related to our data manipulation method D3 is under `./data_manipulation`, where you can obtain augmented 14 | data by using code under this directory. 15 | 3. `./attention_experiment` contains scripts for the attention experiments (like Appendix C.1 and C.4) in our paper 16 | 4. `./model` contains scripts for all other necessary parts to run experiment, including models, optimizer, data interface 17 | and so on. 18 | 19 | --- 20 | 21 | ### Requirements 22 | 1. python == 3.7.0 23 | 2. torch==1.5.0 24 | 3. transformers==3.1.0 25 | 4. spacy==2.2.4 26 | 5. fairseq==0.9.0 (I downloaded the source code into the root directory) 27 | 6. 
sentencepiece==0.1.94 28 | 29 | For evaluating the generated responses, you need to install `java-1.8.0` and `perl`, as well as 30 | Perl libraries including XML::Twig, Sort::Naturally, and String::Util (I use cpanm to install them on Linux). 31 | 32 | [METEOR](https://www.cs.cmu.edu/~alavie/METEOR/download/meteor-1.5.tar.gz) is also needed for evaluating the quality 33 | of responses; please unzip it and put it under `./metrics/`. 34 | 35 | We also use [BERTScore](https://github.com/Tiiiger/bert_score) as a metric in our experiments; you may need to download 36 | a proper BERT model for a successful evaluation. Here we use a roberta-large model. To rescale the score, we have put 37 | the baseline file under `./bert_score/rescale_baseline/en`. If you want to rescale the BERTScore, please add 38 | `--rescale_with_baseline` in the training shell. 39 | 40 | --- 41 | ## Run the code 42 | 43 | In practice, applying only step 3 (Data Distillation) already achieves satisfactory performance. Step 4 (Data Diversification) 44 | contributes less to the final results and is more complex. 45 | 46 | ## 1. Obtain PersonaChat Dataset 47 | 48 | Obtain the PersonaChat dataset via [ParlAI](https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/personachat) 49 | or [our zipped version](https://drive.google.com/file/d/1zQVO5MuEy3wBUfZpM39uYmloD3-Rld4T/view) and put it into the `./datasets` directory. 50 | 51 | ## 2. Prepare models 52 | First, we have to obtain all the trained models needed for data manipulation in the experiments. 53 | Go to `./data_manipulation/prepare_model`. 54 | 55 | ##### 1) NLI model for evaluating persona consistency 56 | Download the [DialogueNLI dataset](https://wellecks.github.io/dialogue_nli/) 57 | and put it under this directory. Also download the large-size [RoBERTa MNLI model](https://huggingface.co/roberta-large-mnli) 58 | and put it under this directory, renaming its folder to `roberta_mnli/`. 59 | 60 | Then you can train the NLI model on this dataset with the script `train_nli_model.py`. 61 | 62 | After obtaining the best trained model, rename the file `best_model.bin` to `pytorch_model.bin` for the following 63 | use. Define the path that saves the trained NLI model for persona consistency as `PERSONA_NLI`. 64 | 65 | We also provide our trained [NLI model](https://drive.google.com/file/d/1QnT8V2Yj4Zl2yW2rnQIi2p56I_wbN3Ee/view?usp=sharing) 66 | for download. 67 | 68 | ##### 2) NLI model for evaluating coherence of dialogue history 69 | 70 | Use the same RoBERTa MNLI model as in 1) and `train_coherence_nli.py` to train it on the [InferConvAI2 dataset](https://github.com/nouhadziri/DialogEntailment). 71 | It is a dialogue NLI dataset designed for evaluating the coherence of dialogue history. 72 | 73 | Save the obtained model and define the path containing it as `COHERENCE_NLI`. 74 | 75 | ##### 3) BERT and GPT2 model used in data diversification 76 | 77 | First use `extract_personas_and_responses.py` to extract persona and response texts into two json files. 78 | 79 | Download the [bert-base-uncased model](https://huggingface.co/bert-base-uncased) and [gpt2-small model](https://huggingface.co/gpt2) 80 | and put them under directories of your choice. 81 | Then use `finetune_bert_and_gpt2.py` to fine-tune the BERT and GPT2 models on `personas.json`, obtaining BERTper and 82 | GPT2per, then fine-tune GPT2 on `responses.json` to obtain GPT2res, editing the code to point to the BERT and GPT2 model paths 83 | you just defined before. A rough sketch of the GPT2 fine-tuning step is shown below.
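The snippet below is only a minimal sketch of what the GPT2 fine-tuning step may look like, assuming `responses.json` is a flat JSON list of response strings; the local paths (`./gpt2-small`, `./gpt2_res`) and hyperparameters are placeholders, and the actual logic lives in `finetune_bert_and_gpt2.py`.

```python
# Minimal sketch of causal-LM fine-tuning on responses.json (assumption:
# responses.json is a flat JSON list of response strings; paths are placeholders).
import json
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW

GPT2_PATH = './gpt2-small'   # hypothetical local path to the downloaded gpt2 model
OUTPUT_PATH = './gpt2_res'   # hypothetical output directory for GPT2res

tokenizer = GPT2Tokenizer.from_pretrained(GPT2_PATH)
model = GPT2LMHeadModel.from_pretrained(GPT2_PATH)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

with open('responses.json', 'r', encoding='utf-8') as f:
    responses = json.load(f)

# Encode each response separately and train with a language-modeling objective
# (labels == input_ids); batching sentences of similar length is omitted for brevity.
encoded = [tokenizer.encode(r, max_length=64, truncation=True) for r in responses]
optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(1):
    for ids in encoded:
        input_ids = torch.tensor([ids], dtype=torch.long).to(device)
        loss = model(input_ids, labels=input_ids)[0]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

model.save_pretrained(OUTPUT_PATH)
tokenizer.save_pretrained(OUTPUT_PATH)
```

The same pattern can be applied to `personas.json` to obtain GPT2per, while BERTper is presumably fine-tuned with a masked-LM objective instead of the causal-LM loss used here.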
84 | 85 | ##### 4) Back translation model for dialogue history diversification 86 | 87 | Go to the directory `./BT`. 88 | 89 | Download the [WMT14 en-fr corpus](http://statmt.org/wmt14/translation-task.html#Download) and pre-process it with 90 | BPE from sentencepiece using `preprocess.sh`, obtaining `sentence.bpe.model`. 91 | 92 | Train the en-fr and fr-en translation models using `train_en-fr.sh` and `train_fr-en.sh` under this directory, then average the last 5 checkpoints using 93 | `average_model.sh`. Define the obtained model checkpoints as `BT_EN-FR` and `BT-FR-EN`. 94 | 95 | ## 3. Data Distillation 96 | 97 | Go to `./data_manipulation/data_distillation`. 98 | 99 | Use `calculate_entailment.py` to obtain the prediction results given by the NLI model under `PERSONA_NLI` 100 | that you obtained before. 101 | 102 | Then use `get_distilled_dataset.py` to obtain the distilled dataset from the logits previously produced by the NLI model. 103 | Assume the obtained distilled data file is `DISTILL_DATA`. 104 | 105 | ## 4. Data diversification 106 | 107 | ##### 1) Obtain the Multi-GPT2 model for response alignment under new personas 108 | First, obtain a Multi-GPT2 model trained on the distilled samples. You can use the shell script 109 | `train_multi_gpt2_distilled.sh` in the root directory. Set the training data to `DISTILL_DATA` 110 | according to the definitions in `config.py`. Note that you should use the `config.json` under `multi_gpt2` to replace the 111 | original `config.json` in the initial model weight path to train this model. 112 | 113 | ##### 2) Augment dialogue history 114 | Then augment the dialogue history. Go to `./BT` and use `get_bt_input_file.py` to transform the distilled data 115 | `DISTILL_DATA` into the format for back translation. Then use `bpe_split.py` to pre-process the newly obtained txt file with BPE. 116 | 117 | Using `evaluate.sh` and `evaluate_back.sh` you can translate all utterances into French and then back to English. 118 | 119 | Finally, use `recover.py` to recover the txt file into its original distilled data format as a json file. 120 | 121 | ##### 3) Edit personas and align responses 122 | Go to `./data_manipulation/data_diversification`. Using `generate_new_personas_and_edit_responses.py` you can obtain 123 | new personas as well as some samples with edited new responses, where applicable. 124 | 125 | Using `inference_multi_gpt2.sh` in the root directory you can get the predicted responses for the remaining samples. 126 | 127 | Using `get_augmented_scores.py` you can get the filter scores for each new sample. 128 | 129 | Using `filter_augmented_data.py` you can get the filtered diversified samples along with the distilled ones. Together they form 130 | the augmented dataset used as an easy curriculum for training. 131 | 132 | ## 5. Train model 133 | 134 | Put the obtained augmented dataset into `./datasets/augmented/` and then you can train two models using 135 | `train_seq2seq_D3.sh` and `train_gpt2_D3.sh`. -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version.
8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | import copy 18 | import io 19 | import json 20 | import os 21 | import random 22 | import re 23 | import sys 24 | import logging 25 | from collections import Counter, namedtuple 26 | 27 | import numpy as np 28 | import torch 29 | import torch.nn as nn 30 | from scipy.interpolate import RectBivariateSpline 31 | from torch.utils.checkpoint import checkpoint 32 | 33 | py_version = sys.version.split('.')[0] 34 | if py_version == '2': 35 | open = io.open 36 | unicode = unicode 37 | else: 38 | unicode = str 39 | open = open 40 | 41 | 42 | def set_seed(seed): 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | random.seed(seed) 46 | 47 | 48 | def repeat_along_dim1(obj, repetitions): 49 | """ repeat (a possibly nested object of) tensors from (batch, ...) to (batch * repetitions, ...) """ 50 | if isinstance(obj, tuple): 51 | return tuple(repeat_along_dim1(o, repetitions) for o in obj) 52 | if isinstance(obj, list): 53 | return list(repeat_along_dim1(o, repetitions) for o in obj) 54 | 55 | obj = obj.unsqueeze(1).repeat([1, repetitions] + [1] * len(obj.size()[1:])) 56 | return obj.view(-1, *obj.size()[2:]) 57 | 58 | 59 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False): 60 | # assuming trailing dimensions and type of all the Tensors 61 | # in sequences are same and fetching those from sequences[0] 62 | if not len(sequences): 63 | return torch.empty(0) 64 | trailing_dims = sequences[0].size()[1:] 65 | max_len = max([s.size(0) for s in sequences]) 66 | if batch_first: 67 | out_dims = (len(sequences), max_len) + trailing_dims 68 | else: 69 | out_dims = (max_len, len(sequences)) + trailing_dims 70 | 71 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 72 | for i, tensor in enumerate(sequences): 73 | length = tensor.size(0) 74 | s_slice = slice(-length, None) if left else slice(None, length) 75 | s_slice = (i, s_slice) if batch_first else (s_slice, i) 76 | out_tensor[s_slice] = tensor 77 | 78 | return out_tensor 79 | 80 | 81 | def checkpoint_sequential(functions, segments, *inputs): 82 | def run_function(start, end, functions): 83 | def forward(*inputs): 84 | for j in range(start, end + 1): 85 | inputs = functions[j](*inputs) 86 | return inputs 87 | return forward 88 | 89 | if isinstance(functions, torch.nn.Sequential): 90 | functions = list(functions.children()) 91 | 92 | segment_size = len(functions) // segments 93 | # the last chunk has to be non-volatile 94 | end = -1 95 | for start in range(0, segment_size * (segments - 1), segment_size): 96 | end = start + segment_size - 1 97 | inputs = checkpoint(run_function(start, end, functions), *inputs) 98 | if not isinstance(inputs, tuple): 99 | inputs = (inputs,) 100 | return run_function(end + 1, len(functions) - 1, functions)(*inputs) 101 | 102 | 103 | def f1_score(predictions, targets, average=True): 104 | def f1_score_items(pred_items, gold_items): 105 | common = Counter(gold_items) & Counter(pred_items) 106 | num_same = sum(common.values()) 107 | 108 | if num_same == 0: 109 | return 0 110 | 111 | precision = num_same / len(pred_items) 112 | recall = num_same / len(gold_items) 113 | f1 
= (2 * precision * recall) / (precision + recall) 114 | 115 | return f1 116 | 117 | scores = [f1_score_items(p, t) for p, t in zip(predictions, targets)] 118 | 119 | if average: 120 | return sum(scores) / len(scores) 121 | 122 | return scores 123 | 124 | 125 | def openai_transformer_config(): 126 | class dotdict(dict): 127 | __getattr__ = dict.get 128 | __setattr__ = dict.__setitem__ 129 | __delattr__ = dict.__delitem__ 130 | 131 | cfg = dotdict({'n_layers': 12, 'n_embeddings': 40477, 'n_pos_embeddings': 512, 132 | 'embeddings_size': 768, 'n_heads': 12, 'dropout': 0.1, 133 | 'embed_dropout': 0.1, 'attn_dropout': 0.1, 'ff_dropout': 0.1}) 134 | 135 | return cfg 136 | 137 | 138 | def load_openai_weights(model, directory, n_special_tokens=0, use_tokenizer=False): 139 | # TODO: add check of shapes 140 | 141 | parameters_names_path = os.path.join(directory, 'parameters_names.json') 142 | parameters_shapes_path = os.path.join(directory, 'parameters_shapes.json') 143 | parameters_weights_paths = [os.path.join(directory, 'params_{}.npy'.format(n)) for n in range(10)] 144 | 145 | with open(parameters_names_path, 'r') as parameters_names_file: 146 | parameters_names = json.load(parameters_names_file) 147 | 148 | with open(parameters_shapes_path, 'r') as parameters_shapes_file: 149 | parameters_shapes = json.load(parameters_shapes_file) 150 | 151 | parameters_weights = [np.load(path) for path in parameters_weights_paths] 152 | parameters_offsets = np.cumsum([np.prod(shape) for shape in parameters_shapes]) 153 | parameters_weights = np.split(np.concatenate(parameters_weights, 0), parameters_offsets)[:-1] 154 | parameters_weights = [p.reshape(s) for p, s in zip(parameters_weights, parameters_shapes)] 155 | 156 | if not use_tokenizer: 157 | parameters_weights[1] = parameters_weights[1][1:] # skip 0 - 158 | 159 | if isinstance(model.pos_embeddings, nn.Embedding): 160 | if model.pos_embeddings.num_embeddings - 1 > parameters_weights[0].shape[0]: 161 | xx = np.linspace(0, parameters_weights[0].shape[0], model.pos_embeddings.num_embeddings - 1) 162 | new_kernel = RectBivariateSpline(np.arange(parameters_weights[0].shape[0]), 163 | np.arange(parameters_weights[0].shape[1]), 164 | parameters_weights[0]) 165 | parameters_weights[0] = new_kernel(xx, np.arange(parameters_weights[0].shape[1])) 166 | 167 | # parameters_weights[0] = parameters_weights[0][:model.pos_embeddings.num_embeddings - 1] 168 | # model.pos_embeddings.weight.data[1:] = torch.from_numpy(parameters_weights[0]) 169 | model.pos_embeddings.weight.data = torch.from_numpy(parameters_weights[0]) 170 | 171 | 172 | if use_tokenizer: 173 | model.embeddings.weight.data[-n_special_tokens + 1:] = 0 174 | model.embeddings.weight.data[: -n_special_tokens + 1] = torch.from_numpy(parameters_weights[1]) 175 | else: 176 | parameters_weights[1] = parameters_weights[1][:model.embeddings.num_embeddings - n_special_tokens] 177 | model.embeddings.weight.data[:n_special_tokens] = 0 178 | model.embeddings.weight.data[n_special_tokens:] = torch.from_numpy(parameters_weights[1]) 179 | 180 | parameters_weights = parameters_weights[2:] 181 | 182 | for name, weights in zip(parameters_names, parameters_weights): 183 | name = name[6:] # skip "model/" 184 | assert name[-2:] == ':0' 185 | name = name[:-2] 186 | name = name.split('/') 187 | 188 | pointer = model 189 | for m_name in name: 190 | if re.fullmatch(r'[A-Za-z]+\d+', m_name): 191 | l = re.split(r'(\d+)', m_name) 192 | else: 193 | l = [m_name] 194 | 195 | pointer = getattr(pointer, l[0]) 196 | 197 | if len(l) >= 2: 198 | 
num = int(l[1]) 199 | pointer = pointer[num] 200 | 201 | if len(weights.shape) == 3: # conv1d to linear 202 | weights = weights[0].transpose((1, 0)) 203 | 204 | pointer.data[...] = torch.from_numpy(weights) 205 | 206 | # Initialize shared attention layer is necessary 207 | for layer in model.layers: 208 | attn_state = layer.attn.state_dict() 209 | for context_attn in layer.context_attns: 210 | context_attn.load_state_dict(copy.deepcopy(attn_state), strict=False) 211 | 212 | def config_logger(log_path): 213 | logger = logging.getLogger() 214 | logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s', 215 | level=logging.INFO) 216 | file_handler = logging.FileHandler(log_path, mode='w') 217 | file_handler.setLevel(logging.INFO) 218 | file_handler.setFormatter( 219 | logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')) 220 | logger.addHandler(file_handler) 221 | return logger 222 | -------------------------------------------------------------------------------- /model/seq2seq_vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import json 4 | 5 | class Seq2seqTokenizer: 6 | def __init__(self): 7 | self.word2idx = {"": 0, "": 1, "": 2, "": 3, '': 4, '': 5} 8 | self.idx2word = {0: "", 1: "", 2: "", 3: "", 4: "", 5: ""} 9 | self.n_words = 6 10 | self.all_special_ids = [0, 1, 2, 3, 4, 5] 11 | self.pad_id = 1 12 | self.bos_id = 2 13 | self.eos_id = 3 14 | self.talker1_bos_id = 4 15 | self.talker2_bos_id = 5 16 | 17 | def tokenize(self, str): 18 | res = str.strip().split(' ') 19 | res = [x.lower() for x in res] 20 | return res 21 | 22 | def encode(self, tokenized_str): 23 | res = [] 24 | for token in tokenized_str: 25 | if self.word2idx.__contains__(token): 26 | res.append(self.word2idx[token]) 27 | return res 28 | 29 | def decode(self, ids, skip_special_tokens=True, clean_up_tokenization_spaces=False): 30 | res = [] 31 | for id in ids: 32 | if skip_special_tokens and id in self.all_special_ids: 33 | continue 34 | res.append(self.idx2word[id]) 35 | text = ' '.join(res) 36 | return text 37 | 38 | def index_words(self, sentence): 39 | for word in sentence.split(' '): 40 | self.index_word(word) 41 | 42 | def index_word(self, word): 43 | if not self.word2idx.__contains__(word): 44 | self.word2idx[word] = self.n_words 45 | self.idx2word[self.n_words] = word 46 | self.n_words += 1 47 | 48 | class Seq2seqVocab: 49 | def __init__(self, train_dataset_path, valid_dataset_path, test_dataset_path, vocab_path, data_type='persona', 50 | extra_train_data_path=None, extra_data_type='persona', extend_exist_vocab=None): 51 | if (os.path.exists(vocab_path)): 52 | with open(vocab_path, 'rb') as f: 53 | cached_data = pickle.load(f) 54 | self.vocab = cached_data[0] 55 | self.all_data = cached_data[1] 56 | else: 57 | if extend_exist_vocab: 58 | with open(extend_exist_vocab, 'rb') as f: 59 | cached_data = pickle.load(f) 60 | self.vocab = cached_data[0] 61 | print('loaded vocab size' + str(self.vocab.n_words)) 62 | else: 63 | self.vocab = Seq2seqTokenizer() 64 | self.all_data = self._parse_data(train_dataset_path, valid_dataset_path, test_dataset_path, data_type) 65 | if extra_train_data_path: 66 | extra_data = self._parse_data(extra_train_data_path, None, None, extra_data_type) 67 | self.all_data.extend(extra_data) 68 | self.parse_vocab(self.all_data, self.vocab) 69 | with open(vocab_path, 'wb') as f: 70 | pickle.dump([self.vocab, []], f) 71 | 72 | def 
_parse_data(self, train_dataset_path, valid_dataset_path, test_dataset_path, data_type): 73 | data = None 74 | if data_type == 'persona': 75 | data = self.parse_data_persona(train_dataset_path, valid_dataset_path, test_dataset_path) 76 | elif data_type == 'emoji': 77 | data = self.parse_data_emoji(train_dataset_path, valid_dataset_path, test_dataset_path) 78 | elif data_type == 'daily': 79 | data = self.parse_data_daily(train_dataset_path, valid_dataset_path, test_dataset_path) 80 | elif data_type == 'entailment': 81 | data = self.parse_data_entailment(train_dataset_path, valid_dataset_path, test_dataset_path) 82 | return data 83 | 84 | def parse_data_persona(self, train_dataset_path, valid_dataset_path, test_dataset_path): 85 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path] 86 | all_data = [] 87 | for subset in subsets: 88 | data = [] 89 | if subset is None or len(subset) == 0: 90 | all_data.append(data) 91 | continue 92 | with open(subset, 'r', encoding='utf-8') as f: 93 | for line in f.readlines(): 94 | line = line.strip() 95 | space_idx = line.find(' ') 96 | if space_idx == -1: 97 | dialog_idx = int(line) 98 | else: 99 | dialog_idx = int(line[:space_idx]) 100 | 101 | if int(dialog_idx) == 1: 102 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 103 | 104 | dialog_line = line[space_idx + 1:].split('\t') 105 | dialog_line = [l.strip() for l in dialog_line] 106 | 107 | if dialog_line[0].startswith('your persona:'): 108 | persona_info = dialog_line[0].replace('your persona: ', '') 109 | if persona_info[-1] == '.' and persona_info[-2] != ' ': 110 | persona_info = persona_info[:-1] + ' .' 111 | data[-1]['persona_info'].append(persona_info) 112 | if dialog_line[0].startswith('partner\'s person'): 113 | if not data[-1].__contains__('partner_persona_info'): 114 | data[-1]['partner_persona_info'] = [] 115 | persona_info = dialog_line[0].replace('partner\'s persona: ', '') 116 | if persona_info[-1] == '.' and persona_info[-2] != ' ': 117 | persona_info = persona_info[:-1] + ' .' 
118 | data[-1]['partner_persona_info'].append(persona_info) 119 | elif len(dialog_line) > 1: 120 | data[-1]['dialog'].append(dialog_line[0]) 121 | data[-1]['dialog'].append(dialog_line[1]) 122 | if len(dialog_line) == 4: 123 | data[-1]['candidates'].append(dialog_line[3].split('|')[:-1]) # the last candidate is a duplicate of the good answer (dialog_line[1]) 124 | 125 | all_data.append(data) 126 | return all_data 127 | 128 | def parse_data_emoji(self, train_dataset_path, valid_dataset_path, test_dataset_path): 129 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path] 130 | all_data = [] 131 | for subset in subsets: 132 | data = [] 133 | if subset is None or len(subset) == 0: 134 | all_data.append(data) 135 | continue 136 | with open(subset, 'r', encoding='utf-8') as f: 137 | for line in f.readlines(): 138 | line = line.strip() 139 | items = line.split('\t') 140 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 141 | data[-1]['persona_info'].append(items[0]) 142 | data[-1]['dialog'].append(items[1]) 143 | data[-1]['dialog'].append(items[2]) 144 | all_data.append(data) 145 | return all_data 146 | 147 | def parse_data_daily(self, train_dataset_path, valid_dataset_path, test_dataset_path): 148 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path] 149 | all_data = [] 150 | for subset in subsets: 151 | data = [] 152 | if subset is None or len(subset) == 0: 153 | all_data.append(data) 154 | continue 155 | with open(subset, 'r', encoding='utf-8') as f: 156 | for line in f.readlines(): 157 | line = line.strip() 158 | items = line.split('\t') 159 | items = [item.strip().lower() for item in items] 160 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 161 | data[-1]['persona_info'].append(items[0]) 162 | for i in range(1, len(items)): 163 | data[-1]['dialog'].append(items[i]) 164 | all_data.append(data) 165 | return all_data 166 | 167 | def parse_data_entailment(self, train_dataset_path, valid_dataset_path, test_dataset_path): 168 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path] 169 | all_data = [] 170 | for subset in subsets: 171 | data = [] 172 | if subset is None or len(subset) == 0: 173 | all_data.append(data) 174 | continue 175 | try: 176 | with open(subset, 'r', encoding='utf-8') as f: 177 | data =[] 178 | list = json.load(f) 179 | for item in list: 180 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 181 | data[-1]['persona_info'].append(item[0]) 182 | data[-1]['dialog'].extend(item[1]) 183 | data[-1]['dialog'].append(item[2]) 184 | except: 185 | print('Incorrect data format ' + subset) 186 | all_data.append(data) 187 | return all_data 188 | 189 | def parse_vocab(self, all_data, vocab): 190 | for data in all_data: 191 | for p in data: 192 | for s in p['persona_info']: 193 | vocab.index_words(s) 194 | for s in p['dialog']: 195 | vocab.index_words(s) 196 | for c in p['candidates']: 197 | for s in c: 198 | vocab.index_words(s) -------------------------------------------------------------------------------- /model/optim.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | 22 | logger = logging.getLogger(__file__) 23 | 24 | 25 | class Adam(torch.optim.Optimizer): 26 | """Implements Adam algorithm. 27 | This implementation is modified from torch.optim.Adam based on: 28 | `Fixed Weight Decay Regularization in Adam` 29 | (see https://arxiv.org/abs/1711.05101) 30 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-3) 35 | betas (Tuple[float, float], optional): coefficients used for computing 36 | running averages of gradient and its square (default: (0.9, 0.999)) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-8) 39 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 40 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 41 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 42 | .. _Adam\: A Method for Stochastic Optimization: 43 | https://arxiv.org/abs/1412.6980 44 | .. _On the Convergence of Adam and Beyond: 45 | https://openreview.net/forum?id=ryQu7f-RZ 46 | """ 47 | 48 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): 49 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) 50 | super(Adam, self).__init__(params, defaults) 51 | 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | Arguments: 55 | closure (callable, optional): A closure that reevaluates the model 56 | and returns the loss. 57 | """ 58 | loss = None 59 | if closure is not None: 60 | loss = closure() 61 | 62 | for group in self.param_groups: 63 | for p in group['params']: 64 | if p.grad is None: 65 | continue 66 | grad = p.grad.data 67 | 68 | amsgrad = group['amsgrad'] 69 | 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['step'] = 0 75 | # Exponential moving average of gradient values 76 | state['exp_avg'] = torch.zeros_like(p.data) 77 | # Exponential moving average of squared gradient values 78 | state['exp_avg_sq'] = torch.zeros_like(p.data) 79 | if amsgrad: 80 | # Maintains max of all exp. moving avg. of sq. grad. 
values 81 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 82 | 83 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 84 | if amsgrad: 85 | max_exp_avg_sq = state['max_exp_avg_sq'] 86 | beta1, beta2 = group['betas'] 87 | 88 | state['step'] += 1 89 | 90 | if grad.is_sparse: 91 | grad = grad.coalesce() # the update is non-linear so indices must be unique 92 | grad_indices = grad._indices() 93 | grad_values = grad._values() 94 | size = grad.size() 95 | 96 | def make_sparse(values): 97 | constructor = grad.new 98 | if grad_indices.dim() == 0 or values.dim() == 0: 99 | return constructor().resize_as_(grad) 100 | return constructor(grad_indices, values, size) 101 | 102 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 103 | beta1, beta2 = group['betas'] 104 | 105 | # Decay the first and second moment running average coefficient 106 | # old <- b * old + (1 - b) * new 107 | # <==> old += (1 - b) * (new - old) 108 | old_exp_avg_values = exp_avg.sparse_mask(grad)._values() 109 | exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) 110 | exp_avg.add_(make_sparse(exp_avg_update_values)) 111 | old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() 112 | exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) 113 | exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) 114 | 115 | # Dense addition again is intended, avoiding another sparse_mask 116 | numer = exp_avg_update_values.add_(old_exp_avg_values) 117 | exp_avg_sq_update_values.add_(old_exp_avg_sq_values) 118 | denom = exp_avg_sq_update_values.sqrt_().add_(group['eps']) 119 | del exp_avg_update_values, exp_avg_sq_update_values 120 | 121 | bias_correction1 = 1 - beta1 ** state['step'] 122 | bias_correction2 = 1 - beta2 ** state['step'] 123 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 124 | 125 | p.data.add_(make_sparse(-step_size * numer.div_(denom))) 126 | else: 127 | # Decay the first and second moment running average coefficient 128 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 129 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 130 | if amsgrad: 131 | # Maintains the maximum of all 2nd moment running avg. till now 132 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 133 | # Use the max. for normalizing running avg. 
of gradient 134 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 135 | else: 136 | denom = exp_avg_sq.sqrt().add_(group['eps']) 137 | 138 | bias_correction1 = 1 - beta1 ** state['step'] 139 | bias_correction2 = 1 - beta2 ** state['step'] 140 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 141 | 142 | if group['weight_decay'] != 0: 143 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 144 | 145 | p.data.addcdiv_(-step_size, exp_avg, denom) 146 | 147 | return loss 148 | 149 | def backward(self, losses): 150 | with torch.autograd.set_detect_anomaly(True): 151 | if not isinstance(losses, (tuple, list)): 152 | losses = [losses] 153 | full_loss = sum(losses, 0) 154 | full_loss.backward() 155 | return full_loss 156 | 157 | class NoamOpt: 158 | def __init__(self, embeddings_size, warmup, optimizer, linear_schedule=False, lr=None, total_steps=None, 159 | apex_level=None, loss_weight=None, extra_module_lr_rate=1.0): 160 | self.embeddings_size = embeddings_size 161 | self.warmup = warmup 162 | self.optimizer = optimizer 163 | self.linear_schedule = linear_schedule 164 | self.apex_level = apex_level 165 | self.lr = lr 166 | self.total_steps = total_steps 167 | self.loss_weight = loss_weight 168 | self.extra_module_lr_rate = extra_module_lr_rate 169 | 170 | self._step = 0 171 | 172 | def state_dict(self): 173 | return {'step': self._step, 174 | 'optimizer': self.optimizer.state_dict()} 175 | 176 | def load_state_dict(self, state_dict): 177 | self._step = state_dict['step'] 178 | try: 179 | self.optimizer.load_state_dict(state_dict['optimizer']) 180 | except ValueError as e: 181 | logger.info("Optimizer cannot be loaded from checkpoint: {}".format(e)) 182 | except KeyError as e: 183 | logger.info("Optimizer cannot be loaded from checkpoint: {}".format(e)) 184 | 185 | def backward(self, losses): 186 | if not isinstance(losses, (tuple, list)): 187 | losses = [losses] 188 | if self.loss_weight is None: 189 | full_loss = sum(losses, 0) 190 | else: 191 | full_loss = torch.sum(torch.stack(losses, 0) * torch.exp(self.loss_weight[1])) + torch.sum(self.loss_weight[1]) 192 | 193 | if self.apex_level is not None: 194 | try: 195 | from apex.amp import scale_loss 196 | except ImportError: 197 | raise ImportError("Please install apex.") 198 | 199 | for loss_id, loss in enumerate(losses): 200 | with scale_loss(loss, self.optimizer, loss_id=loss_id) as scaled_loss: 201 | scaled_loss.backward() 202 | else: 203 | full_loss.backward() 204 | return full_loss 205 | 206 | def zero_grad(self): 207 | return self.optimizer.zero_grad() 208 | 209 | def get_lr(self): 210 | return self.optimizer.param_groups[0]['lr'] 211 | 212 | @property 213 | def param_groups(self): 214 | return self.optimizer.param_groups 215 | 216 | def step(self): 217 | self._step += 1 218 | rate = self.rate_linear() if self.linear_schedule else self.rate() 219 | for p in self.optimizer.param_groups: 220 | if p.__contains__('extra'): 221 | p['lr'] = rate * self.extra_module_lr_rate 222 | else: 223 | p['lr'] = rate 224 | self.optimizer.step() 225 | 226 | def rate(self, step=None): 227 | if step is None: 228 | step = self._step 229 | 230 | return self.lr * (self.embeddings_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))) 231 | 232 | @staticmethod 233 | def warmup_linear(x, warmup=0.002): 234 | if x < warmup: 235 | return x/warmup 236 | return 1.0 - x 237 | 238 | def rate_linear(self, step=None): 239 | if step is None: 240 | step = self._step 241 | assert self.lr is not None and self.total_steps is not 
None 242 | 243 | return self.lr * self.warmup_linear(step/self.total_steps, self.warmup) 244 | -------------------------------------------------------------------------------- /bert_score/score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pathlib 5 | import torch 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.axes_grid1 import make_axes_locatable 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from collections import defaultdict 12 | from transformers import AutoTokenizer 13 | 14 | from .utils import ( 15 | get_model, 16 | get_idf_dict, 17 | bert_cos_score_idf, 18 | get_bert_embedding, 19 | lang2model, 20 | model2layers, 21 | get_hash, 22 | cache_scibert, 23 | sent_encode, 24 | ) 25 | 26 | 27 | __all__ = ["score", "plot_example"] 28 | 29 | 30 | def get_bert_score( 31 | cands, 32 | refs, 33 | model_type=None, 34 | num_layers=None, 35 | verbose=False, 36 | idf=False, 37 | device=None, 38 | batch_size=64, 39 | nthreads=4, 40 | all_layers=False, 41 | lang=None, 42 | return_hash=False, 43 | rescale_with_baseline=False, 44 | baseline_path=None, 45 | ): 46 | """ 47 | BERTScore metric. 48 | 49 | Args: 50 | - :param: `cands` (list of str): candidate sentences 51 | - :param: `refs` (list of str or list of list of str): reference sentences 52 | - :param: `model_type` (str): bert specification, default using the suggested 53 | model for the target langauge; has to specify at least one of 54 | `model_type` or `lang` 55 | - :param: `num_layers` (int): the layer of representation to use. 56 | default using the number of layer tuned on WMT16 correlation data 57 | - :param: `verbose` (bool): turn on intermediate status update 58 | - :param: `idf` (bool or dict): use idf weighting, can also be a precomputed idf_dict 59 | - :param: `device` (str): on which the contextual embedding model will be allocated on. 60 | If this argument is None, the model lives on cuda:0 if cuda is available. 61 | - :param: `nthreads` (int): number of threads 62 | - :param: `batch_size` (int): bert score processing batch size 63 | - :param: `lang` (str): language of the sentences; has to specify 64 | at least one of `model_type` or `lang`. `lang` needs to be 65 | specified when `rescale_with_baseline` is True. 66 | - :param: `return_hash` (bool): return hash code of the setting 67 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline 68 | - :param: `baseline_path` (str): customized baseline file 69 | 70 | Return: 71 | - :param: `(P, R, F)`: each is of shape (N); N = number of input 72 | candidate reference pairs. if returning hashcode, the 73 | output will be ((P, R, F), hashcode). If a candidate have 74 | multiple references, the returned score of this candidate is 75 | the *best* score among all references. 
76 | """ 77 | assert len(cands) == len(refs), "Different number of candidates and references" 78 | 79 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified" 80 | 81 | ref_group_boundaries = None 82 | if not isinstance(refs[0], str): 83 | ref_group_boundaries = [] 84 | ori_cands, ori_refs = cands, refs 85 | cands, refs = [], [] 86 | count = 0 87 | for cand, ref_group in zip(ori_cands, ori_refs): 88 | cands += [cand] * len(ref_group) 89 | refs += ref_group 90 | ref_group_boundaries.append((count, count + len(ref_group))) 91 | count += len(ref_group) 92 | 93 | if rescale_with_baseline: 94 | assert lang is not None, "Need to specify Language when rescaling with baseline" 95 | 96 | if model_type is None: 97 | lang = lang.lower() 98 | model_type = lang2model[lang] 99 | if num_layers is None: 100 | num_layers = model2layers[model_type] 101 | 102 | if model_type.startswith("scibert"): 103 | tokenizer = AutoTokenizer.from_pretrained(cache_scibert(model_type)) 104 | else: 105 | tokenizer = AutoTokenizer.from_pretrained(model_type) 106 | 107 | model = get_model(model_type, num_layers, all_layers) 108 | if device is None: 109 | device = "cuda" if torch.cuda.is_available() else "cpu" 110 | model.to(device) 111 | 112 | if not idf: 113 | idf_dict = defaultdict(lambda: 1.0) 114 | # set idf for [SEP] and [CLS] to 0 115 | idf_dict[tokenizer.sep_token_id] = 0 116 | idf_dict[tokenizer.cls_token_id] = 0 117 | elif isinstance(idf, dict): 118 | if verbose: 119 | print("using predefined IDF dict...") 120 | idf_dict = idf 121 | else: 122 | if verbose: 123 | print("preparing IDF dict...") 124 | start = time.perf_counter() 125 | idf_dict = get_idf_dict(refs, tokenizer, nthreads=nthreads) 126 | if verbose: 127 | print("done in {:.2f} seconds".format(time.perf_counter() - start)) 128 | 129 | if verbose: 130 | print("calculating scores...") 131 | start = time.perf_counter() 132 | all_preds = bert_cos_score_idf( 133 | model, 134 | refs, 135 | cands, 136 | tokenizer, 137 | idf_dict, 138 | verbose=verbose, 139 | device=device, 140 | batch_size=batch_size, 141 | all_layers=all_layers, 142 | ).cpu() 143 | 144 | if ref_group_boundaries is not None: 145 | max_preds = [] 146 | for beg, end in ref_group_boundaries: 147 | max_preds.append(all_preds[beg:end].max(dim=0)[0]) 148 | all_preds = torch.stack(max_preds, dim=0) 149 | 150 | use_custom_baseline = baseline_path is not None 151 | if rescale_with_baseline: 152 | if baseline_path is None: 153 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv") 154 | if os.path.isfile(baseline_path): 155 | if not all_layers: 156 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float() 157 | else: 158 | baselines = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float() 159 | 160 | all_preds = (all_preds - baselines) / (1 - baselines) 161 | else: 162 | print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr) 163 | 164 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F 165 | 166 | if verbose: 167 | time_diff = time.perf_counter() - start 168 | print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec") 169 | 170 | if return_hash: 171 | return tuple( 172 | [out, get_hash(model_type, num_layers, idf, rescale_with_baseline, use_custom_baseline=use_custom_baseline)] 173 | ) 174 | 175 | return out 176 | 177 | 178 | def plot_example( 179 | 
candidate, 180 | reference, 181 | model_type=None, 182 | num_layers=None, 183 | lang=None, 184 | rescale_with_baseline=False, 185 | baseline_path=None, 186 | fname="", 187 | ): 188 | """ 189 | BERTScore metric. 190 | 191 | Args: 192 | - :param: `candidate` (str): a candidate sentence 193 | - :param: `reference` (str): a reference sentence 194 | - :param: `verbose` (bool): turn on intermediate status update 195 | - :param: `model_type` (str): bert specification, default using the suggested 196 | model for the target langauge; has to specify at least one of 197 | `model_type` or `lang` 198 | - :param: `num_layers` (int): the layer of representation to use 199 | - :param: `lang` (str): language of the sentences; has to specify 200 | at least one of `model_type` or `lang`. `lang` needs to be 201 | specified when `rescale_with_baseline` is True. 202 | - :param: `return_hash` (bool): return hash code of the setting 203 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline 204 | - :param: `fname` (str): path to save the output plot 205 | """ 206 | assert isinstance(candidate, str) 207 | assert isinstance(reference, str) 208 | 209 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified" 210 | 211 | if rescale_with_baseline: 212 | assert lang is not None, "Need to specify Language when rescaling with baseline" 213 | 214 | if model_type is None: 215 | lang = lang.lower() 216 | model_type = lang2model[lang] 217 | if num_layers is None: 218 | num_layers = model2layers[model_type] 219 | 220 | if model_type.startswith("scibert"): 221 | tokenizer = AutoTokenizer.from_pretrained(cache_scibert(model_type)) 222 | else: 223 | tokenizer = AutoTokenizer.from_pretrained(model_type) 224 | model = get_model(model_type, num_layers) 225 | device = "cuda" if torch.cuda.is_available() else "cpu" 226 | model.to(device) 227 | 228 | idf_dict = defaultdict(lambda: 1.0) 229 | # set idf for [SEP] and [CLS] to 0 230 | idf_dict[tokenizer.sep_token_id] = 0 231 | idf_dict[tokenizer.cls_token_id] = 0 232 | 233 | hyp_embedding, masks, padded_idf = get_bert_embedding( 234 | [candidate], model, tokenizer, idf_dict, device=device, all_layers=False 235 | ) 236 | ref_embedding, masks, padded_idf = get_bert_embedding( 237 | [reference], model, tokenizer, idf_dict, device=device, all_layers=False 238 | ) 239 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1)) 240 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1)) 241 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 242 | sim = sim.squeeze(0).cpu() 243 | 244 | # remove [CLS] and [SEP] tokens 245 | r_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, reference)][1:-1] 246 | h_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, candidate)][1:-1] 247 | sim = sim[1:-1, 1:-1] 248 | 249 | if rescale_with_baseline: 250 | if baseline_path is None: 251 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv") 252 | if os.path.isfile(baseline_path): 253 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float() 254 | sim = (sim - baselines[2].item()) / (1 - baselines[2].item()) 255 | else: 256 | print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr) 257 | 258 | fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens))) 259 | im = ax.imshow(sim, cmap="Blues", vmin=0, vmax=1) 260 | 261 | # We want to show all 
ticks... 262 | ax.set_xticks(np.arange(len(r_tokens))) 263 | ax.set_yticks(np.arange(len(h_tokens))) 264 | # ... and label them with the respective list entries 265 | ax.set_xticklabels(r_tokens, fontsize=10) 266 | ax.set_yticklabels(h_tokens, fontsize=10) 267 | ax.grid(False) 268 | plt.xlabel("Reference (tokenized)", fontsize=14) 269 | plt.ylabel("Candidate (tokenized)", fontsize=14) 270 | title = "Similarity Matrix" 271 | if rescale_with_baseline: 272 | title += " (after Rescaling)" 273 | plt.title(title, fontsize=14) 274 | 275 | divider = make_axes_locatable(ax) 276 | cax = divider.append_axes("right", size="2%", pad=0.2) 277 | fig.colorbar(im, cax=cax) 278 | 279 | # Rotate the tick labels and set their alignment. 280 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 281 | 282 | # Loop over data dimensions and create text annotations. 283 | for i in range(len(h_tokens)): 284 | for j in range(len(r_tokens)): 285 | text = ax.text( 286 | j, 287 | i, 288 | "{:.3f}".format(sim[i, j].item()), 289 | ha="center", 290 | va="center", 291 | color="k" if sim[i, j].item() < 0.5 else "w", 292 | ) 293 | 294 | fig.tight_layout() 295 | if fname != "": 296 | plt.savefig(fname, dpi=100) 297 | print("Saved figure to file: ", fname) 298 | plt.show() 299 | -------------------------------------------------------------------------------- /bert_score/scorer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pathlib 5 | import torch 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.axes_grid1 import make_axes_locatable 8 | import numpy as np 9 | import pandas as pd 10 | import warnings 11 | 12 | from collections import defaultdict 13 | from transformers import AutoTokenizer 14 | 15 | from .utils import ( 16 | get_model, 17 | get_idf_dict, 18 | bert_cos_score_idf, 19 | get_bert_embedding, 20 | lang2model, 21 | model2layers, 22 | get_hash, 23 | cache_scibert, 24 | sent_encode, 25 | ) 26 | 27 | 28 | class BERTScorer: 29 | """ 30 | BERTScore Scorer Object. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_type=None, 36 | num_layers=None, 37 | batch_size=64, 38 | nthreads=4, 39 | all_layers=False, 40 | idf=False, 41 | idf_sents=None, 42 | device=None, 43 | lang=None, 44 | rescale_with_baseline=False, 45 | baseline_path=None, 46 | ): 47 | """ 48 | Args: 49 | - :param: `model_type` (str): contexual embedding model specification, default using the suggested 50 | model for the target langauge; has to specify at least one of 51 | `model_type` or `lang` 52 | - :param: `num_layers` (int): the layer of representation to use. 53 | default using the number of layer tuned on WMT16 correlation data 54 | - :param: `verbose` (bool): turn on intermediate status update 55 | - :param: `idf` (bool): a booling to specify whether to use idf or not (this should be True even if `idf_sents` is given) 56 | - :param: `idf_sents` (List of str): list of sentences used to compute the idf weights 57 | - :param: `device` (str): on which the contextual embedding model will be allocated on. 58 | If this argument is None, the model lives on cuda:0 if cuda is available. 59 | - :param: `batch_size` (int): bert score processing batch size 60 | - :param: `nthreads` (int): number of threads 61 | - :param: `lang` (str): language of the sentences; has to specify 62 | at least one of `model_type` or `lang`. `lang` needs to be 63 | specified when `rescale_with_baseline` is True. 
64 | - :param: `return_hash` (bool): return hash code of the setting 65 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline 66 | - :param: `baseline_path` (str): customized baseline file 67 | """ 68 | 69 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified" 70 | 71 | if rescale_with_baseline: 72 | assert lang is not None, "Need to specify Language when rescaling with baseline" 73 | 74 | if device is None: 75 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 76 | else: 77 | self.device = device 78 | 79 | self._lang = lang 80 | self._rescale_with_baseline = rescale_with_baseline 81 | self._idf = idf 82 | self.batch_size = batch_size 83 | self.nthreads = nthreads 84 | self.all_layers = all_layers 85 | 86 | if model_type is None: 87 | lang = lang.lower() 88 | self._model_type = lang2model[lang] 89 | else: 90 | self._model_type = model_type 91 | 92 | if num_layers is None: 93 | self._num_layers = model2layers[self.model_type] 94 | else: 95 | self._num_layers = num_layers 96 | 97 | # Building model and tokenizer 98 | 99 | if self.model_type.startswith("scibert"): 100 | self._tokenizer = AutoTokenizer.from_pretrained(cache_scibert(self.model_type)) 101 | else: 102 | self._tokenizer = AutoTokenizer.from_pretrained(self.model_type) 103 | 104 | self._model = get_model(self.model_type, self.num_layers, self.all_layers) 105 | self._model.to(self.device) 106 | 107 | self._idf_dict = None 108 | if idf_sents is not None: 109 | self.compute_idf(idf_sents) 110 | 111 | self._baseline_vals = None 112 | self.baseline_path = baseline_path 113 | self.use_custom_baseline = self.baseline_path is not None 114 | if self.baseline_path is None: 115 | self.baseline_path = os.path.join( 116 | os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv" 117 | ) 118 | 119 | @property 120 | def lang(self): 121 | return self._lang 122 | 123 | @property 124 | def idf(self): 125 | return self._idf 126 | 127 | @property 128 | def model_type(self): 129 | return self._model_type 130 | 131 | @property 132 | def num_layers(self): 133 | return self._num_layers 134 | 135 | @property 136 | def rescale_with_baseline(self): 137 | return self._rescale_with_baseline 138 | 139 | @property 140 | def baseline_vals(self): 141 | if self._baseline_vals is None: 142 | if os.path.isfile(self.baseline_path): 143 | if not self.all_layers: 144 | self._baseline_vals = torch.from_numpy( 145 | pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy() 146 | )[1:].float() 147 | else: 148 | self._baseline_vals = ( 149 | torch.from_numpy(pd.read_csv(self.baseline_path).to_numpy())[:, 1:].unsqueeze(1).float() 150 | ) 151 | else: 152 | raise ValueError(f"Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}") 153 | 154 | return self._baseline_vals 155 | 156 | @property 157 | def hash(self): 158 | return get_hash( 159 | self.model_type, self.num_layers, self.idf, self.rescale_with_baseline, self.use_custom_baseline 160 | ) 161 | 162 | def compute_idf(self, sents): 163 | """ 164 | Args: 165 | 166 | """ 167 | if self._idf_dict is not None: 168 | warnings.warn("Overwriting the previous importance weights.") 169 | 170 | self._idf_dict = get_idf_dict(sents, self._tokenizer, nthreads=self.nthreads) 171 | 172 | def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False): 173 | """ 174 | Args: 175 | - :param: `cands` (list of str): candidate sentences 176 | - :param: `refs` (list of str or list of 
list of str): reference sentences 177 | 178 | Return: 179 | - :param: `(P, R, F)`: each is of shape (N); N = number of input 180 | candidate reference pairs. if returning hashcode, the 181 | output will be ((P, R, F), hashcode). If a candidate have 182 | multiple references, the returned score of this candidate is 183 | the *best* score among all references. 184 | """ 185 | 186 | ref_group_boundaries = None 187 | if not isinstance(refs[0], str): 188 | ref_group_boundaries = [] 189 | ori_cands, ori_refs = cands, refs 190 | cands, refs = [], [] 191 | count = 0 192 | for cand, ref_group in zip(ori_cands, ori_refs): 193 | cands += [cand] * len(ref_group) 194 | refs += ref_group 195 | ref_group_boundaries.append((count, count + len(ref_group))) 196 | count += len(ref_group) 197 | 198 | if verbose: 199 | print("calculating scores...") 200 | start = time.perf_counter() 201 | 202 | if self.idf: 203 | assert self._idf_dict, "IDF weights are not computed" 204 | idf_dict = self._idf_dict 205 | else: 206 | idf_dict = defaultdict(lambda: 1.0) 207 | idf_dict[self._tokenizer.sep_token_id] = 0 208 | idf_dict[self._tokenizer.cls_token_id] = 0 209 | 210 | all_preds = bert_cos_score_idf( 211 | self._model, 212 | refs, 213 | cands, 214 | self._tokenizer, 215 | idf_dict, 216 | verbose=verbose, 217 | device=self.device, 218 | batch_size=batch_size, 219 | all_layers=self.all_layers, 220 | ).cpu() 221 | 222 | if ref_group_boundaries is not None: 223 | max_preds = [] 224 | for start, end in ref_group_boundaries: 225 | max_preds.append(all_preds[start:end].max(dim=0)[0]) 226 | all_preds = torch.stack(max_preds, dim=0) 227 | 228 | if self.rescale_with_baseline: 229 | all_preds = (all_preds - self.baseline_vals) / (1 - self.baseline_vals) 230 | 231 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F 232 | 233 | if verbose: 234 | time_diff = time.perf_counter() - start 235 | print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec") 236 | 237 | if return_hash: 238 | out = tuple([out, self.hash]) 239 | 240 | return out 241 | 242 | def plot_example(self, candidate, reference, fname=""): 243 | """ 244 | Args: 245 | - :param: `candidate` (str): a candidate sentence 246 | - :param: `reference` (str): a reference sentence 247 | - :param: `fname` (str): path to save the output plot 248 | """ 249 | 250 | assert isinstance(candidate, str) 251 | assert isinstance(reference, str) 252 | 253 | idf_dict = defaultdict(lambda: 1.0) 254 | idf_dict[self._tokenizer.sep_token_id] = 0 255 | idf_dict[self._tokenizer.cls_token_id] = 0 256 | 257 | hyp_embedding, masks, padded_idf = get_bert_embedding( 258 | [candidate], self._model, self._tokenizer, idf_dict, device=self.device, all_layers=False 259 | ) 260 | ref_embedding, masks, padded_idf = get_bert_embedding( 261 | [reference], self._model, self._tokenizer, idf_dict, device=self.device, all_layers=False 262 | ) 263 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1)) 264 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1)) 265 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 266 | sim = sim.squeeze(0).cpu() 267 | 268 | r_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, reference)][1:-1] 269 | h_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, candidate)][1:-1] 270 | sim = sim[1:-1, 1:-1] 271 | 272 | if self.rescale_with_baseline: 273 | sim = (sim - self.baseline_vals[2].item()) / (1 - self.baseline_vals[2].item()) 274 | 275 | fig, ax = 
plt.subplots(figsize=(len(r_tokens), len(h_tokens))) 276 | im = ax.imshow(sim, cmap="Blues", vmin=0, vmax=1) 277 | 278 | # We want to show all ticks... 279 | ax.set_xticks(np.arange(len(r_tokens))) 280 | ax.set_yticks(np.arange(len(h_tokens))) 281 | # ... and label them with the respective list entries 282 | ax.set_xticklabels(r_tokens, fontsize=10) 283 | ax.set_yticklabels(h_tokens, fontsize=10) 284 | ax.grid(False) 285 | plt.xlabel("Reference (tokenized)", fontsize=14) 286 | plt.ylabel("Candidate (tokenized)", fontsize=14) 287 | title = "Similarity Matrix" 288 | if self.rescale_with_baseline: 289 | title += " (after Rescaling)" 290 | plt.title(title, fontsize=14) 291 | 292 | divider = make_axes_locatable(ax) 293 | cax = divider.append_axes("right", size="2%", pad=0.2) 294 | fig.colorbar(im, cax=cax) 295 | 296 | # Rotate the tick labels and set their alignment. 297 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 298 | 299 | # Loop over data dimensions and create text annotations. 300 | for i in range(len(h_tokens)): 301 | for j in range(len(r_tokens)): 302 | text = ax.text( 303 | j, 304 | i, 305 | "{:.3f}".format(sim[i, j].item()), 306 | ha="center", 307 | va="center", 308 | color="k" if sim[i, j].item() < 0.5 else "w", 309 | ) 310 | 311 | fig.tight_layout() 312 | if fname != "": 313 | plt.savefig(fname, dpi=100) 314 | print("Saved figure to file: ", fname) 315 | plt.show() 316 | 317 | def __repr__(self): 318 | return f"{self.__class__.__name__}(hash={self.hash}, batch_size={self.batch_size}, nthreads={self.nthreads})" 319 | 320 | def __str__(self): 321 | return self.__repr__() 322 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import torch 5 | import torch.nn as nn 6 | 7 | from config import get_trainer_config, InputConfig 8 | from model.dataset import FacebookDataset 9 | from model.trainer import Trainer 10 | from model.gpt2_model import GPT2EncoderDecoderModel, GPT2DoubleHeadsModel 11 | from model.utils import f1_score, open, set_seed, config_logger 12 | from transformers.tokenization_gpt2 import GPT2Tokenizer 13 | from metrics import nlp_metrics 14 | from model.seq2seq import TransformerSeq2Seq 15 | from model.seq2seq_vocab import Seq2seqVocab 16 | from model.entailment_score import EntailmentScorer 17 | from bert_score.score import get_bert_score 18 | 19 | PADDING_IDX = 0 20 | 21 | def modify_tokenizer(tokenizer, data_type): 22 | additional_special_tokens = ['', '', '', '', '', 23 | ''] 24 | if data_type == 'emoji': 25 | with open('datasets/emoji_talk/emojis.json', 'r') as f: 26 | emojis = json.load(f)['emojis'] 27 | additional_special_tokens.extend(emojis) 28 | tokenizer.add_special_tokens({'pad_token': '', 'bos_token': '', 'eos_token': '', 29 | 'additional_special_tokens': additional_special_tokens}) 30 | tokenizer.eos_id, tokenizer.bos_id, tokenizer.pad_id = tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.pad_token_id 31 | tokenizer.sent_dialog_id = tokenizer.bos_token_id 32 | tokenizer.info_dialog_id, tokenizer.info_bos_id = tokenizer.added_tokens_encoder[''], \ 33 | tokenizer.added_tokens_encoder[ 34 | ''] 35 | tokenizer.info_eos_id = tokenizer.added_tokens_encoder[''] 36 | tokenizer.talker1_dialog_id, tokenizer.talker1_bos_id = tokenizer.added_tokens_encoder[''], \ 37 | tokenizer.added_tokens_encoder[''] 38 | tokenizer.talker1_eos_id = 
tokenizer.added_tokens_encoder[''] 39 | tokenizer.talker2_dialog_id, tokenizer.talker2_bos_id = tokenizer.added_tokens_encoder[''], \ 40 | tokenizer.added_tokens_encoder[''] 41 | tokenizer.talker2_eos_id = tokenizer.added_tokens_encoder[''] 42 | return tokenizer, len(additional_special_tokens) + 3 43 | 44 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False): 45 | # assuming trailing dimensions and type of all the Tensors 46 | # in sequences are same and fetching those from sequences[0] 47 | if not len(sequences): 48 | return torch.empty(0) 49 | trailing_dims = sequences[0].size()[1:] 50 | max_len = max([s.size(0) for s in sequences]) 51 | if batch_first: 52 | out_dims = (len(sequences), max_len) + trailing_dims 53 | else: 54 | out_dims = (max_len, len(sequences)) + trailing_dims 55 | 56 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 57 | for i, tensor in enumerate(sequences): 58 | length = tensor.size(0) 59 | s_slice = slice(-length, None) if left else slice(None, length) 60 | s_slice = (i, s_slice) if batch_first else (s_slice, i) 61 | out_tensor[s_slice] = tensor 62 | 63 | return out_tensor 64 | 65 | def main(): 66 | args = InputConfig().args 67 | 68 | trainer_config = get_trainer_config(args) 69 | 70 | set_seed(trainer_config.seed) 71 | device = torch.device(trainer_config.device) 72 | save_path = trainer_config.load_last[:trainer_config.load_last.rfind('/')] 73 | generate_file_name = args.generate_file_name 74 | 75 | logger = config_logger(os.path.join(save_path, 'inference.log')) 76 | 77 | parsed_valid_data, parsed_test_data = None, None 78 | if args.model_type == 'gpt2': 79 | if args.single_input: 80 | model = GPT2DoubleHeadsModel.from_pretrained('./gpt2-small') 81 | else: 82 | model = GPT2EncoderDecoderModel.from_pretrained('./gpt2-small') 83 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-small') 84 | elif args.model_type == 'seq2seq' or args.model_type == 'rnn-seq2seq': 85 | seq2seq_vocab = Seq2seqVocab(trainer_config.train_datasets, trainer_config.valid_datasets, 86 | trainer_config.test_datasets, args.vocab_path, data_type=args.data_type) 87 | tokenizer = seq2seq_vocab.vocab 88 | if args.model_type == 'seq2seq': 89 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size, 90 | args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger, 91 | multi_input=not args.single_input, 92 | attention_pooling_type=args.attention_pooling_type) 93 | else: 94 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size, 95 | args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger, 96 | base_model='gru') 97 | args.dialog_embeddings = False 98 | 99 | model.shared_attention = (args.shared_attention == 1) 100 | model.shared_module = (args.shared_module == 1) 101 | model.attention_pooling_type = args.attention_pooling_type 102 | if args.model_type in ['gpt', 'dialogpt', 'gpt2']: 103 | tokenizer, additional_length = modify_tokenizer(tokenizer, args.data_type) 104 | model.embeddings_size = 768 105 | model.n_embeddings = len(tokenizer) 106 | model.shared_attention = (args.shared_attention == 1) 107 | model.shared_module = (args.shared_module == 1) 108 | model.attention_pooling_type = args.attention_pooling_type 109 | model.single_input = args.single_input 110 | if args.model_type == 'gpt': 111 | model_embedding_weight = model.transformer.tokens_embed.weight 112 | model.transformer.tokens_embed = 
nn.Embedding(model.n_embeddings, 768) 113 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False) 114 | model.transformer.tokens_embed.weight.data[:-additional_length, :] = model_embedding_weight.data 115 | model.transformer.tokens_embed.weight.data[-additional_length:, :] = 0 116 | model.lm_head.weight = model.transformer.tokens_embed.weight 117 | else: 118 | model_embedding_weight = model.transformer.wte.weight 119 | model.transformer.wte = nn.Embedding(model.n_embeddings, 768) 120 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False) 121 | model.transformer.wte.weight.data[:-additional_length, :] = model_embedding_weight.data 122 | model.transformer.wte.weight.data[-additional_length:, :] = 0 123 | model.lm_head.weight = model.transformer.wte.weight 124 | 125 | if not args.single_input: 126 | model.reload_module_dict() 127 | model.sent_dialog_id = tokenizer.sent_dialog_id 128 | 129 | model.padding_idx = tokenizer.pad_id 130 | model.n_pos_embeddings = 512 131 | 132 | model.talker1_id = tokenizer.talker1_bos_id 133 | model.talker2_id = tokenizer.talker2_bos_id 134 | model.bos_id = tokenizer.bos_id 135 | model.eos_id = tokenizer.eos_id 136 | model.beam_size = args.beam_size 137 | model.diversity_groups = 1 138 | model.max_seq_len = 32 139 | model.dialog_embeddings = args.dialog_embeddings 140 | model.bs_temperature = args.bs_temperature 141 | model.bs_nucleus_p = args.bs_nucleus_p 142 | model.annealing_topk = args.annealing_topk 143 | model.length_penalty_coef = args.length_penalty 144 | model.vocab = None 145 | model.annealing = args.annealing 146 | model.diversity_coef = args.diversity_coef 147 | model.sample = False 148 | model.inference_mode = args.inference_mode 149 | model.response_k = args.response_k 150 | 151 | logger.info('loading datasets') 152 | valid_dataset = None 153 | test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer, 154 | max_lengths=(model.n_pos_embeddings - 1) // (3 if args.single_input else 1), # A bit restrictive here 155 | dialog_embeddings=args.dialog_embeddings, 156 | cache=trainer_config.test_datasets_cache, 157 | use_start_end=args.use_start_end, 158 | negative_samples=0, # Keep all negative samples 159 | augment=False, 160 | aug_syn_proba=0.0, 161 | limit_size=trainer_config.limit_eval_size, 162 | single_input=args.single_input, 163 | data_type=args.data_type, 164 | parsed_data=parsed_test_data) 165 | logger.info(f'test dataset {(len(test_dataset))}') 166 | 167 | state_dict = torch.load(trainer_config.load_last, map_location=device) 168 | if state_dict.__contains__('model'): 169 | model.load_state_dict(state_dict['model'], strict=False) 170 | else: 171 | model.load_state_dict(state_dict) 172 | model.to(device) 173 | logger.info('Weights loaded from {}'.format(trainer_config.load_last)) 174 | 175 | trainer = Trainer(model, 176 | test_dataset, 177 | trainer_config, 178 | None, 179 | logger=logger, 180 | test_dataset=test_dataset, 181 | valid_dataset=valid_dataset, 182 | n_jobs=trainer_config.n_jobs, 183 | device=device, 184 | ignore_idxs=tokenizer.all_special_ids, 185 | local_rank=args.local_rank, 186 | apex_level=None, 187 | apex_loss_scale=trainer_config.apex_loss_scale, 188 | evaluate_full_sequences=trainer_config.evaluate_full_sequences, 189 | full_input=trainer_config.full_input, 190 | uncertainty_loss=args.uncertainty_loss) 191 | 192 | def external_metrics_func(full_references, full_predictions, epoch, metric=None, generate_entail=False): 193 | if epoch == -1: 194 | references_file_path = os.path.join(save_path, 
'test_references_file.txt') 195 | predictions_file_path = os.path.join(save_path, 'test_predictions_file.txt') 196 | else: 197 | references_file_path = os.path.join(save_path, 'eval_references_file.txt') 198 | predictions_file_path = os.path.join(save_path, 'eval_predictions_file.txt') 199 | with open(references_file_path, 'w', encoding='utf-8') as f: 200 | f.write('\n'.join(full_references)) 201 | with open(predictions_file_path, 'w', encoding='utf-8') as f: 202 | f.write('\n'.join(full_predictions)) 203 | 204 | bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, entropy, meteor, \ 205 | rouge_l, f1_score, avg_length = nlp_metrics(references_file_path, predictions_file_path) 206 | 207 | metrics = {'meteor': meteor * 100, 'avg_len': avg_length, 'rouge-l': rouge_l * 100, 'bleu': bleu, 'nist': nist, 208 | 'nist-bleu': nist_bleu, 'f1': f1_score * 100} 209 | for name, metric in ( 210 | ('bleu', bleu_list), ('nist', nist_list), ('nist_bleu', nist_bleu_list), ('entropy', entropy), 211 | ('sentence_div', s_dist), ('corpus_div', c_dist)): 212 | for i, m in enumerate(metric, 1): 213 | if name == 'sentence_div' or name == 'corpus_div': 214 | metrics['{}_{}'.format(name, i)] = m * 100 215 | else: 216 | metrics['{}_{}'.format(name, i)] = m 217 | if args.entail_score_refs_file and epoch == -1: 218 | entailment_scorer = EntailmentScorer(predictions_file_path, args.entail_score_refs_file, 219 | args.entail_model_path, device) 220 | metrics['entail_score'] = entailment_scorer.calculate_entailment_score() 221 | if args.bert_score_model_path is not None and epoch == -1: 222 | all_preds = get_bert_score( 223 | full_predictions, 224 | full_references, 225 | model_type=args.bert_score_model_path, 226 | num_layers=16, 227 | batch_size=16, 228 | ) 229 | metrics['bert_score_p'] = torch.mean(all_preds[0]).item() 230 | metrics['bert_score_r'] = torch.mean(all_preds[1]).item() 231 | metrics['bert_score_f'] = torch.mean(all_preds[2]).item() 232 | for k, v in metrics.items(): 233 | metrics[k] = round(v, 6) 234 | return metrics 235 | 236 | def external_metrics_func_entail_data(full_predictions, raw_entail_data): 237 | for i, prediction in enumerate(full_predictions): 238 | raw_entail_data[i][2] = prediction 239 | with open(generate_file_name, 'w') as f: 240 | json.dump(raw_entail_data, f) 241 | 242 | metric_funcs = {'f1_score': f1_score} 243 | # trainer.test(metric_funcs, external_metrics_func, epoch=0, inference=True) 244 | if args.data_type == 'entailment': 245 | with open(args.test_datasets, 'r') as f: 246 | raw_entail_data = json.load(f) 247 | trainer.test(metric_funcs, external_metrics_func_entail_data, epoch=-1, inference=True, 248 | raw_entail_data=raw_entail_data) 249 | else: 250 | trainer.test(metric_funcs, external_metrics_func, epoch=-1, inference=True) 251 | 252 | if __name__ == '__main__': 253 | main() 254 | -------------------------------------------------------------------------------- /attention_experiment/attention_exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | 4 | import json 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import torch.nn as nn 9 | 10 | from config import get_trainer_config, InputConfig 11 | from model.dataset import FacebookDataset 12 | from model.trainer import Trainer 13 | from model.gpt2_model import GPT2DoubleHeadsModel 14 | from transformers.tokenization_gpt2 import GPT2Tokenizer 15 | from model.utils import open, set_seed, config_logger 
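# The rest of this file implements the attention experiment: main() loads a trained model,
# runs Trainer.test_attention() over the test set and dumps the raw attention tensors to
# disk, while analysis_attention() / analysis_attention_history() reload those tensors
# together with the pre-computed token and persona-span positions and report the per-layer
# attention mass on the matched persona tokens/sentences against a uniform-attention
# baseline. The __main__ block at the bottom currently runs analysis_attention() only.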
16 | from model.seq2seq import TransformerSeq2Seq 17 | from model.seq2seq_vocab import Seq2seqVocab 18 | 19 | PADDING_IDX = 0 20 | 21 | def modify_tokenizer(tokenizer, data_type): 22 | additional_special_tokens = ['', '', '', '', '', 23 | ''] 24 | if data_type == 'emoji': 25 | with open('datasets/emoji_talk/emojis.json', 'r') as f: 26 | emojis = json.load(f)['emojis'] 27 | additional_special_tokens.extend(emojis) 28 | tokenizer.add_special_tokens({'pad_token': '', 'bos_token': '', 'eos_token': '', 29 | 'additional_special_tokens': additional_special_tokens}) 30 | tokenizer.eos_id, tokenizer.bos_id, tokenizer.pad_id = tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.pad_token_id 31 | tokenizer.sent_dialog_id = tokenizer.bos_token_id 32 | tokenizer.info_dialog_id, tokenizer.info_bos_id = tokenizer.added_tokens_encoder[''], \ 33 | tokenizer.added_tokens_encoder[ 34 | ''] 35 | tokenizer.info_eos_id = tokenizer.added_tokens_encoder[''] 36 | tokenizer.talker1_dialog_id, tokenizer.talker1_bos_id = tokenizer.added_tokens_encoder[''], \ 37 | tokenizer.added_tokens_encoder[''] 38 | tokenizer.talker1_eos_id = tokenizer.added_tokens_encoder[''] 39 | tokenizer.talker2_dialog_id, tokenizer.talker2_bos_id = tokenizer.added_tokens_encoder[''], \ 40 | tokenizer.added_tokens_encoder[''] 41 | tokenizer.talker2_eos_id = tokenizer.added_tokens_encoder[''] 42 | return tokenizer, len(additional_special_tokens) + 3 43 | 44 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False): 45 | # assuming trailing dimensions and type of all the Tensors 46 | # in sequences are same and fetching those from sequences[0] 47 | if not len(sequences): 48 | return torch.empty(0) 49 | trailing_dims = sequences[0].size()[1:] 50 | max_len = max([s.size(0) for s in sequences]) 51 | if batch_first: 52 | out_dims = (len(sequences), max_len) + trailing_dims 53 | else: 54 | out_dims = (max_len, len(sequences)) + trailing_dims 55 | 56 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 57 | for i, tensor in enumerate(sequences): 58 | length = tensor.size(0) 59 | s_slice = slice(-length, None) if left else slice(None, length) 60 | s_slice = (i, s_slice) if batch_first else (s_slice, i) 61 | out_tensor[s_slice] = tensor 62 | 63 | return out_tensor 64 | 65 | def collate_func(data): 66 | persona_info, h, y, distractors_batch = zip(*data) 67 | 68 | contexts = [] 69 | 70 | if max(map(len, persona_info)) > 0: 71 | persona_info = [torch.tensor(d, dtype=torch.long) for d in persona_info] 72 | contexts.append(persona_info) 73 | 74 | if max(map(len, h)) > 0: 75 | h = [torch.tensor(d, dtype=torch.long) for d in h] 76 | contexts.append(h) 77 | 78 | y_out = [torch.tensor(d, dtype=torch.long) for d in y] 79 | 80 | distractors = [torch.tensor(d, dtype=torch.long) for distractors in distractors_batch for d in distractors] 81 | 82 | # Pad now so we pad correctly when we have only a single input (context concatenated with y) 83 | y_out = pad_sequence(y_out, batch_first=True, padding_value=PADDING_IDX) 84 | distractors = pad_sequence(distractors, batch_first=True, padding_value=PADDING_IDX) 85 | contexts = [pad_sequence(c, batch_first=True, padding_value=PADDING_IDX) for c in contexts] 86 | 87 | return contexts, y_out, distractors 88 | 89 | def _s2s_loss(targets, enc_contexts, model): 90 | hidden_state, padding_mask = None, None 91 | 92 | nexts = targets[:, 1:].contiguous() if targets.dim() == 2 else targets[:, 1:, 0].contiguous() 93 | outputs = model.decode(targets[:, :-1].contiguous(), enc_contexts) 94 | 95 | 
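# Teacher-forced LM loss: the decoder consumed targets[:, :-1] above, so its logits are
# scored against the shifted gold tokens in `nexts` (targets[:, 1:]); both are flattened
# below so CrossEntropyLoss with ignore_index=PADDING_IDX is applied per token.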
outputs = outputs.view(-1, outputs.shape[-1]).float() 96 | nexts = nexts.view(-1) 97 | 98 | lm_criterion = torch.nn.CrossEntropyLoss(ignore_index=PADDING_IDX) 99 | loss = lm_criterion(outputs, nexts) 100 | return loss, hidden_state, padding_mask 101 | 102 | def _lm_loss(contexts, enc_contexts, model, ignore_idxs, device): 103 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=device) 104 | 105 | for context in contexts: 106 | enc_context = model.encode(context.clone()) 107 | enc_contexts.append(enc_context) 108 | 109 | context_outputs = model.generate(enc_context[0]) 110 | ignore_mask = torch.stack([context == idx for idx in ignore_idxs], dim=-1).any(dim=-1) 111 | context.masked_fill_(ignore_mask, PADDING_IDX) 112 | prevs = context_outputs[:, :-1, :].contiguous() 113 | nexts = context[:, 1:].contiguous() if context.dim() == 2 else context[:, 1:, 0].contiguous() 114 | lm_criterion = torch.nn.CrossEntropyLoss(ignore_index=PADDING_IDX) 115 | batch_lm_loss += lm_criterion(prevs.view(-1, prevs.shape[-1]).float(), nexts.view(-1)) / len(contexts) 116 | return batch_lm_loss 117 | 118 | 119 | 120 | def main(): 121 | args = InputConfig().args 122 | 123 | trainer_config = get_trainer_config(args) 124 | 125 | set_seed(trainer_config.seed) 126 | device = torch.device(trainer_config.device) 127 | save_path = trainer_config.load_last[:trainer_config.load_last.rfind('/')] 128 | generate_file_name = args.generate_file_name 129 | 130 | logger = config_logger(os.path.join(save_path, 'inference.log')) 131 | 132 | parsed_valid_data, parsed_test_data = None, None 133 | if args.model_type == 'seq2seq': 134 | seq2seq_vocab = Seq2seqVocab(trainer_config.train_datasets, trainer_config.valid_datasets, 135 | trainer_config.test_datasets, args.vocab_path, data_type=args.data_type) 136 | tokenizer = seq2seq_vocab.vocab 137 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size, 138 | args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger, 139 | multi_input=not args.single_input, 140 | attention_pooling_type=args.attention_pooling_type) 141 | args.dialog_embeddings = False 142 | else: 143 | model = GPT2DoubleHeadsModel.from_pretrained('./gpt2-small') 144 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-small') 145 | tokenizer, additional_length = modify_tokenizer(tokenizer, args.data_type) 146 | model.embeddings_size = 768 147 | model.n_embeddings = len(tokenizer) 148 | model.shared_attention = (args.shared_attention == 1) 149 | model.shared_module = (args.shared_module == 1) 150 | model.attention_pooling_type = args.attention_pooling_type 151 | model.single_input = args.single_input 152 | model_embedding_weight = model.transformer.wte.weight 153 | model.transformer.wte = nn.Embedding(model.n_embeddings, 768) 154 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False) 155 | model.transformer.wte.weight.data[:-additional_length, :] = model_embedding_weight.data 156 | model.transformer.wte.weight.data[-additional_length:, :] = 0 157 | model.lm_head.weight = model.transformer.wte.weight 158 | 159 | model.padding_idx = tokenizer.pad_id 160 | model.n_pos_embeddings = 512 161 | 162 | model.talker1_id = tokenizer.talker1_bos_id 163 | model.talker2_id = tokenizer.talker2_bos_id 164 | model.bos_id = tokenizer.bos_id 165 | model.eos_id = tokenizer.eos_id 166 | model.beam_size = args.beam_size 167 | model.diversity_groups = 1 168 | model.max_seq_len = 32 169 | model.dialog_embeddings = args.dialog_embeddings 170 | model.bs_temperature = 
args.bs_temperature 171 | model.bs_nucleus_p = args.bs_nucleus_p 172 | model.annealing_topk = args.annealing_topk 173 | model.length_penalty_coef = args.length_penalty 174 | model.vocab = None 175 | model.annealing = args.annealing 176 | model.diversity_coef = args.diversity_coef 177 | model.sample = False 178 | model.inference_mode = args.inference_mode 179 | model.response_k = args.response_k 180 | 181 | logger.info('loading datasets') 182 | valid_dataset = None 183 | test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer, 184 | max_lengths=(model.n_pos_embeddings - 1) // (3 if args.single_input else 1), # A bit restrictive here 185 | dialog_embeddings=args.dialog_embeddings, 186 | cache=trainer_config.test_datasets_cache, 187 | use_start_end=args.use_start_end, 188 | negative_samples=0, # Keep all negative samples 189 | augment=False, 190 | aug_syn_proba=0.0, 191 | limit_size=trainer_config.limit_eval_size, 192 | max_history_size=args.max_history_size, 193 | single_input=args.single_input, 194 | data_type=args.data_type, 195 | parsed_data=parsed_test_data) 196 | # logger.info(f'valid dataset {len(valid_dataset)} test dataset {(len(test_dataset))}') 197 | logger.info(f'test dataset {(len(test_dataset))}') 198 | 199 | state_dict = torch.load(trainer_config.load_last, map_location=device) 200 | if state_dict.__contains__('model'): 201 | model.load_state_dict(state_dict['model'], strict=False) 202 | else: 203 | model.load_state_dict(state_dict) 204 | model.to(device) 205 | logger.info('Weights loaded from {}'.format(trainer_config.load_last)) 206 | 207 | trainer = Trainer(model, 208 | test_dataset, 209 | trainer_config, 210 | None, 211 | logger=logger, 212 | test_dataset=test_dataset, 213 | valid_dataset=valid_dataset, 214 | n_jobs=trainer_config.n_jobs, 215 | device=device, 216 | ignore_idxs=tokenizer.all_special_ids, 217 | local_rank=args.local_rank, 218 | apex_level=None, 219 | apex_loss_scale=trainer_config.apex_loss_scale, 220 | full_input=trainer_config.full_input, 221 | uncertainty_loss=args.uncertainty_loss) 222 | 223 | _, all_attention, all_attention_inference = trainer.test_attention() 224 | torch.save(all_attention, 'augmentation/gpt2_th0.99_mix_all_attention.bin') 225 | 226 | def analysis_attention(): 227 | all_attention = torch.load('augmentation/gpt2_th0.99_raw_all_attention.bin') 228 | with open('augmentation/th0.99_gpt2_positions.json', 'r') as f: 229 | consistent_pos = json.load(f) 230 | token_pos, target_persona_pos, whole_persona_pos = consistent_pos['token_positions'], \ 231 | consistent_pos['target_persona_positions'], \ 232 | consistent_pos['whole_persona_positions'] 233 | token_level_attentions, sentence_level_attentions = [], [] 234 | persona_sentence_ratio, avg_token_number = [], [] 235 | for i in tqdm(range(len(target_persona_pos))): 236 | cur_attention = (all_attention[i] / torch.sum(all_attention[i], dim=-1, keepdim=True)).numpy() 237 | persona_sentence_attention = cur_attention[:, :, target_persona_pos[i][0]: target_persona_pos[i][1]] 238 | sentence_level_attentions.append(np.mean(np.sum(persona_sentence_attention, axis=-1), axis=-1)) 239 | cur_token_attention = [] 240 | for p in token_pos[i]: 241 | cur_token_attention.append(cur_attention[:, p[1], p[0]]) 242 | token_level_attentions.append(cur_token_attention) 243 | persona_sentence_ratio.append((target_persona_pos[i][1] - target_persona_pos[i][0]) / cur_attention.shape[2]) 244 | avg_token_number.append(cur_attention.shape[2]) 245 | token_by_layer, sentence_by_layer = [[] for _ in range(12)], 
[[] for _ in range(12)] 246 | for i in range(len(all_attention)): 247 | for j in range(12): 248 | sentence_by_layer[j].append(sentence_level_attentions[i][j]) 249 | if len(token_level_attentions[i]) > 0: 250 | token_by_layer[j].append(np.mean([a[j] for a in token_level_attentions[i]])) 251 | for i in range(12): 252 | print(str(i) + ' layer token-level: ' + str(np.mean(token_by_layer[i]))) 253 | print(str(i) + ' layer sentence-level: ' + str(np.mean(sentence_by_layer[i]))) 254 | print('mean token level value: ' + str(1/ np.mean(avg_token_number))) 255 | print('mean sentence level value: ' + str(np.mean(persona_sentence_ratio))) 256 | # cur_whole_attention, cur_persona_attention = [], [] 257 | # for p in token_pos[i]: 258 | # cur_whole_attention.append(F.softmax(all_attention[i][:, p[1], :], dim=-1).numo'tpy()) 259 | # cur_persona_attention.append(F.softmax(all_attention[i][:, p[1], :whole_persona_pos[i][1]], dim=-1).numpy()) 260 | # whole_matched_response_attentions.append(cur_whole_attention) 261 | # within_persona_matched_response_attentions.append(cur_persona_attention) 262 | # response_persona_attention = F.softmax(all_attention[i], dim=-1).numpy()[:, :, target_persona_pos[i][0]: target_persona_pos[i][1]] 263 | # sentence_level_attention.append(np.mean(np.sum(response_persona_attention, axis=-1), axis=-1)) 264 | # all_probs, all_probs_within_persona, all_prob_target_persona = [[] for _ in range(6)], [[] for _ in range(6)], \ 265 | # [[] for _ in range(6)] 266 | # all_acc, all_acc_within_persona = [0] * 6, [0] * 6 267 | # lengths, persona_lengths, target_persona_lengths = [], [], [] 268 | # for i in tqdm(range(len(target_persona_pos))): 269 | # for j, pos in enumerate(token_pos[i]): 270 | # for m in range(6): 271 | # all_probs[m].append(whole_matched_response_attentions[i][j][m][pos[0]]) 272 | # index = np.argsort(-whole_matched_response_attentions[i][j][m]) 273 | # if index[0] == pos[0]: 274 | # all_acc[m] += 1 275 | # all_probs_within_persona[m].append(within_persona_matched_response_attentions[i][j][m][pos[0]]) 276 | # index = np.argsort(-within_persona_matched_response_attentions[i][j][m]) 277 | # if index[0] == pos[0]: 278 | # all_acc_within_persona[m] += 1 279 | # lengths.append(whole_matched_response_attentions[i][j][0].shape[0]) 280 | # persona_lengths.append(within_persona_matched_response_attentions[i][j][0].shape[0]) 281 | # target_persona_lengths.append(target_persona_pos[i][1] - target_persona_pos[i][0]) 282 | # for m in range(6): 283 | # all_prob_target_persona[m].append(sentence_level_attention[i][m]) 284 | # for i in range(6): 285 | # print(str(i) + ' layer prob values: ' + str(np.mean(all_probs[i]))) 286 | # print(str(i) + ' layer acc: ' + str(all_acc[i] / len(all_probs[i]))) 287 | # print(str(i) + ' layer prob values within persona: ' + str(np.mean(all_probs_within_persona[i]))) 288 | # print(str(i) + ' layer acc within persona: ' + str(all_acc_within_persona[i] / len(all_probs[i]))) 289 | # print(str(i) + ' layer sentence-level prob: ' + str(np.mean(all_prob_target_persona[i]))) 290 | # print('mean prob value: ' + str(1 / np.mean(lengths))) 291 | # print('mean prob value within persona: ' + str(1 / np.mean(persona_lengths))) 292 | # print('mean prob value sentence-level: ' + str(1 / np.mean(target_persona_lengths))) 293 | 294 | def analysis_attention_history(): 295 | all_attention = torch.load('augmentation/gpt2_th0.99_raw_attention.bin') 296 | with open('augmentation/th0.99_gpt2_positions.json', 'r') as f: 297 | positions = json.load(f) 298 | avgs = [] 299 | lengths = 
[] 300 | attentions = [[] for _ in range(6)] 301 | for i in range(len(positions)): 302 | cur_attention = F.softmax(all_attention[i], dim=-1) 303 | cur_positions = positions[i] 304 | # if len(cur_positions[1]) >= 3: 305 | # start = cur_positions[0][-1] 306 | # if len(cur_positions[1]) > 3: 307 | # start = cur_positions[1][-4] 308 | # else: 309 | # start = cur_positions[0][-1] 310 | # else: 311 | # start = cur_positions[0][-1] 312 | start = cur_positions[1][-2] if len(cur_positions[1]) > 1 else cur_positions[0][-1] 313 | end = cur_positions[1][-1] 314 | target_attention = cur_attention[:, :, start: end] 315 | avg_attention = 1 / cur_attention.size()[2] * (end - start) 316 | for j in range(6): 317 | attentions[j].append(torch.mean(torch.sum(target_attention[j], dim=-1)).item()) 318 | avgs.append(avg_attention) 319 | lengths.append(cur_attention.size()[2]) 320 | print(np.mean(attentions[0])) 321 | print(np.mean(attentions[1])) 322 | print(np.mean(attentions[2])) 323 | print(np.mean(attentions[3])) 324 | print(np.mean(attentions[4])) 325 | print(np.mean(attentions[5])) 326 | print(np.mean(avgs)) 327 | print(1 / np.mean(lengths)) 328 | print('111') 329 | 330 | if __name__ == '__main__': 331 | # main() 332 | analysis_attention() 333 | # analysis_attention_history() 334 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from attrdict import AttrDict 2 | from copy import deepcopy 3 | import torch 4 | from model.utils import openai_transformer_config 5 | import git 6 | import argparse 7 | 8 | 9 | repo = git.Repo(search_parent_directories=True) 10 | 11 | 12 | def cast2(type_): 13 | return lambda val: val if val is None else type_(val) 14 | 15 | 16 | def get_model_config(args): 17 | default_config = openai_transformer_config() 18 | config = AttrDict({'bpe_vocab_path': './parameters/bpe.vocab', 19 | 'bpe_codes_path': './parameters/bpe.code', 20 | 'checkpoint_path': './checkpoints/last_checkpoint', # Keep the checpoint folder for the checkpoints of the agents 21 | 'n_layers': default_config.n_layers, 22 | 'n_pos_embeddings': 512, 23 | 'embeddings_size': default_config.embeddings_size, 24 | 'n_heads': default_config.n_heads, 25 | 'dropout': default_config.dropout, 26 | 'embed_dropout': default_config.embed_dropout, 27 | 'attn_dropout': default_config.attn_dropout, 28 | 'ff_dropout': default_config.ff_dropout, 29 | 'normalize_embeddings': args.normalize_embeddings, 30 | 'max_seq_len': 128, 31 | 'beam_size': args.beam_size, 32 | 'diversity_coef': args.diversity_coef, 33 | 'diversity_groups': args.diversity_groups, 34 | 'annealing_topk': args.annealing_topk, 35 | 'annealing': args.annealing, 36 | 'length_penalty': args.length_penalty, 37 | 'n_segments': None, 38 | 'constant_embedding': args.constant_embedding, 39 | 'multiple_choice_head': args.multiple_choice_head, 40 | 'share_models': True, 41 | 'successive_attention': args.successive_attention, 42 | 'sparse_embeddings': args.sparse_embeddings, 43 | 'shared_attention': args.shared_attention, 44 | 'dialog_embeddings': args.dialog_embeddings, 45 | 'single_input': args.single_input, 46 | 'use_start_end': args.use_start_end, 47 | 'apex_level': args.apex_level, # 'O0', 'O1', 'O2', 'O3', 48 | 'bs_temperature': args.bs_temperature, 49 | 'bs_nucleus_p': args.bs_nucleus_p, 50 | 'same_embedding_lm': args.same_embedding_lm, 51 | }) 52 | 53 | return config 54 | 55 | 56 | def get_trainer_config(args, curriculum_config=False): 57 | config = 
AttrDict({'n_epochs': args.n_epochs, 58 | 'writer_comment': args.writer_comment, 59 | 'train_batch_size': args.train_batch_size, 60 | 'meta_batch_size': args.meta_batch_size, 61 | 'batch_split': args.batch_split, 62 | 'test_batch_size': args.test_batch_size, 63 | 'lr': args.lr, 64 | 'lr_warmup': args.lr_warmup, # a fraction of total training (epoch * train_set_length) if linear_schedule == True 65 | 'weight_decay': 0.01, 66 | 's2s_weight': args.s2s_weight, 67 | 'lm_weight': args.lm_weight, 68 | 'risk_weight': args.risk_weight, 69 | 'hits_weight': args.hits_weight, 70 | 'negative_samples': args.negative_samples, 71 | 'n_jobs': 4, 72 | 'label_smoothing': args.label_smoothing, 73 | 'clip_grad': args.clip_grad, 74 | 'alpha_clip_grad': args.alpha_clip_grad, 75 | 'test_period': 1, 76 | 'seed': args.seed, 77 | 'device': 'cuda', 78 | 'zero_shot': args.zero_shot, 79 | 'persona_augment': args.persona_augment, 80 | 'persona_aug_syn_proba': args.persona_aug_syn_proba, 81 | 'apex_loss_scale': args.apex_loss_scale, # e.g. '128', 'dynamic' 82 | 'linear_schedule': args.linear_schedule, 83 | 'evaluate_full_sequences': args.evaluate_full_sequences, 84 | 'limit_eval_size': args.limit_eval_size, 85 | 'limit_train_size': args.limit_train_size, 86 | 'risk_metric': args.risk_metric, 87 | 'load_last': args.load_last, #./checkpoints/last_checkpoint', # Now that we save several experiments you can put the path of the checpoint file you want to load here 88 | 'load_alpha_last': args.load_alpha_last, 89 | 'repo_id': str(repo), 90 | 'repo_sha': str(repo.head.object.hexsha), 91 | 'repo_branch': str(repo.active_branch), 92 | 'openai_parameters_dir': './parameters', 93 | 'last_checkpoint_path': 'last_checkpoint', # there are now in the ./runs/XXX/ experiments folders 94 | 'eval_references_file': 'eval_references_file', 95 | 'eval_predictions_file': 'eval_predictions_file', 96 | 'test_references_file': 'test_references_file', 97 | 'test_predictions_file_best': 'test_predictions_file_best', 98 | 'test_predictions_file_last': 'test_predictions_file_last', 99 | 'interrupt_checkpoint_path': 'interrupt_checkpoint', # there are now in the ./runs/XXX/ experiments folders 100 | 'train_datasets': args.train_datasets, 101 | 'train_datasets_cache': args.train_datasets_cache, 102 | 'test_datasets': args.test_datasets, 103 | 'test_datasets_cache': args.test_datasets_cache, 104 | 'valid_datasets': args.valid_datasets, 105 | 'valid_datasets_cache': args.valid_datasets_cache, 106 | 'full_input': args.full_input, 107 | 'single_input': args.single_input, 108 | 'max_history_size': args.max_history_size, 109 | 'model_saving_interval': args.model_saving_interval, 110 | 'patience': args.patience, 111 | 'mixup_cache': args.mixup_cache, 112 | 'data_type': args.data_type, 113 | }) 114 | if curriculum_config: 115 | config.train_datasets = args.curriculum_train_datasets 116 | config.train_datasets_cache = args.curriculum_train_datasets_cache 117 | config.valid_datasets = args.curriculum_valid_datasets 118 | config.valid_datasets_cache = args.curriculum_valid_datasets_cache 119 | config.test_datasets = args.curriculum_valid_datasets 120 | config.test_datasets_cache = args.curriculum_valid_datasets_cache 121 | config.lr = args.curriculum_lr 122 | config.n_epochs = args.curriculum_n_epochs 123 | config.patience = args.patience 124 | config.max_history_size = args.curriculum_max_history_size 125 | config.eval_references_file = 'eval_references_file_stage1' 126 | config.eval_predictions_file = 'eval_predictions_file_stage1' 127 | 
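# Output files for the stage-1 (curriculum) phase carry a '_stage1' suffix so that they do
# not overwrite the reference/prediction files written by the main training stage.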
config.test_references_file = 'test_reference_file_stage1' 128 | config.test_predictions_file_best = 'test_predictions_file_stage1_best' 129 | config.test_predictions_file_last = 'test_predictions_file_stage1_last' 130 | config.data_type = args.curriculum_data_type 131 | 132 | local_config = deepcopy(config) 133 | local_config.train_batch_size = 16 134 | local_config.batch_split = 2 135 | local_config.test_batch_size = 4 136 | local_config.n_jobs = 0 137 | local_config.device = 'cpu' 138 | local_config.risk_weight = 0 139 | local_config.zero_shot = False 140 | local_config.fp16 = False 141 | # local_config.train_datasets_cache = './datasets/train_datasets_cache.bin' 142 | # local_config.test_datasets_cache = './datasets/test_datasets_cache.bin' 143 | 144 | return config if torch.cuda.is_available() else local_config 145 | 146 | class InputConfig(): 147 | def __init__(self): 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('--seed', type=int, default=0) 150 | parser.add_argument('--normalize_embeddings', action='store_true') 151 | parser.add_argument('--beam_size', default=3, type=int) 152 | parser.add_argument('--inference_mode', default='beam', type=str) 153 | parser.add_argument('--response_k', default=1, type=int) 154 | parser.add_argument('--diversity_coef', default=0, type=int) 155 | parser.add_argument('--lr', default=6.25e-5, type=float) 156 | parser.add_argument('--meta_lr', default=1e-4, type=float) 157 | parser.add_argument('--finetune_lr', default=2e-4, type=float) 158 | parser.add_argument('--finetune_epochs', default=3, type=int) 159 | parser.add_argument('--lr_warmup', default=0.002, type=float) 160 | parser.add_argument('--clip_grad', default=None, type=float) 161 | parser.add_argument('--diversity_groups', default=1, type=int) 162 | parser.add_argument('--annealing_topk', default=None, type=int) 163 | parser.add_argument('--annealing', default=0, type=float) 164 | parser.add_argument('--length_penalty', default=0.6, type=float) 165 | parser.add_argument('--bs_temperature', default=1, type=float) 166 | parser.add_argument('--bs_nucleus_p', default=0, type=float) 167 | parser.add_argument('--apex_level', default=None, type=str) 168 | parser.add_argument('--constant_embedding', action='store_true') 169 | parser.add_argument('--multiple_choice_head', action='store_true') 170 | parser.add_argument('--successive_attention', action='store_true') 171 | parser.add_argument('--sparse_embeddings', default=False, type=bool) 172 | parser.add_argument('--dialog_embeddings', default=True, type=bool) 173 | parser.add_argument('--single_input', action='store_true') 174 | parser.add_argument('--no_persona', action='store_true') 175 | parser.add_argument('--mixup', action='store_true') 176 | parser.add_argument('--use_start_end', action='store_true') 177 | parser.add_argument('--zero_shot', action='store_true') 178 | parser.add_argument('--persona_augment', action='store_true') 179 | parser.add_argument('--linear_schedule', default=True, type=bool) 180 | parser.add_argument('--evaluate_full_sequences', default=True, type=bool) 181 | parser.add_argument('--n_epochs', default=3, type=int) 182 | parser.add_argument('--patience', default=-1, type=int, help="the training patience if the dev result " 183 | "does not promote then training ends") 184 | parser.add_argument('--train_batch_size', default=128, type=int) 185 | parser.add_argument('--batch_split', default=32, type=int) 186 | parser.add_argument('--test_batch_size', default=8, type=int) 187 | 
parser.add_argument('--writer_comment', default='', type=str) 188 | parser.add_argument('--s2s_weight', default=2, type=float) 189 | parser.add_argument('--lm_weight', default=1, type=float) 190 | parser.add_argument('--risk_weight', default=0, type=float) 191 | parser.add_argument('--hits_weight', default=0, type=float) 192 | parser.add_argument('--label_smoothing', default=-1, type=float, 193 | help='Config for Seq2Seq model, whether use label smoothing loss, -1 means no smoothing') 194 | parser.add_argument('--negative_samples', default=0, type=int) 195 | parser.add_argument('--persona_aug_syn_proba', default=0, type=float) 196 | parser.add_argument('--apex_loss_scale', default=None, type=str) 197 | parser.add_argument('--limit_eval_size', default=-1, type=int) 198 | parser.add_argument('--limit_train_size', default=-1, type=int) 199 | parser.add_argument('--risk_metric', default='f1', type=str) 200 | parser.add_argument('--load_last', default='', type=str) 201 | parser.add_argument('--load_alpha_last', default='', type=str) 202 | parser.add_argument('--data_type', default='persona', type=str, help='data set types, persona/emoji/daily') 203 | parser.add_argument('--test_data_type', default=None, type=str, help='data set types, persona/emoji/daily') 204 | parser.add_argument('--emb_dim', default=300, type=int, help='Config for Seq2Seq model') 205 | parser.add_argument('--hidden_dim', default=300, type=int, help='Config for Seq2Seq model') 206 | parser.add_argument('--num_layers', default=6, type=int, help='Config for Seq2Seq model') 207 | parser.add_argument('--heads', default=4, type=int, help='Config for Seq2Seq model') 208 | parser.add_argument('--depth_size', default=40, type=int, help='Config for Seq2Seq model') 209 | parser.add_argument('--filter_size', default=50, type=int, help='Config for Seq2Seq model') 210 | parser.add_argument('--pointer_gen', action='store_true', help='Config for Seq2Seq model') 211 | parser.add_argument('--pretrained_emb_file', default='./glove/glove.6B.300d.txt', type=str) 212 | parser.add_argument('--vocab_path', default='./datasets/persona_vocab.bin', type=str) 213 | parser.add_argument('--extend_exist_vocab', default=None, type=str) 214 | parser.add_argument('--train_datasets', default='datasets/ConvAI2/train_self_original.txt', type=str) 215 | parser.add_argument('--valid_datasets', default='datasets/ConvAI2/valid_self_original.txt', type=str) 216 | parser.add_argument('--test_datasets', default='datasets/ConvAI2/test_self_original.txt', type=str) 217 | parser.add_argument('--cache_vocab_path', default='datasets/ConvAI2/cached_vocab.pickle', type=str) 218 | parser.add_argument('--train_datasets_cache', default='datasets/train_cache.bin', type=str) 219 | parser.add_argument('--valid_datasets_cache', default='datasets/valid_cache.bin', type=str) 220 | parser.add_argument('--test_datasets_cache', default='datasets/test_cache.bin', type=str) 221 | parser.add_argument('--extra_train_path', default=None, type=str) 222 | parser.add_argument('--extra_data_type', default='persona', type=str) 223 | parser.add_argument('--extra_cvae_utterances_path', default=None, type=str, help='The path indicates the CVAE augmented utterances') 224 | parser.add_argument('--generate_file_name', default='generation.json', type=str, help='The saved json file name ' 225 | 'when inference on entailment data') 226 | parser.add_argument('--mixup_cache', default='datasets/mixup_cache.bin', type=str) 227 | parser.add_argument('--mixup_mode', default='alternate', type=str, help='The mode 
to execute the mixup operation; ' 228 | 'alternate=no mixup for one batch and mixup for the next batch, alternating; ' 229 | 'all=mixup for all training samples; random=randomly mixup several samples within a batch ' 230 | 'while the rest of the samples remain unmixed') 231 | parser.add_argument('--mixup_model_path', default='./fasttext/persona_50_cbow.bin', type=str) 232 | parser.add_argument('--mixup_candidate_th', default=0.4, type=float) 233 | parser.add_argument('--mixup_ratio', default=0.15, type=float) 234 | parser.add_argument('--mixup_soft_loss_weight', default=0, type=float) 235 | parser.add_argument('--bert_mixup', action='store_true', help='whether to use a BERT model for mixup') 236 | parser.add_argument('--replace', action='store_true', help='whether to directly replace tokens in samples') 237 | parser.add_argument('--few_shot', action='store_true', help='whether to do few-shot learning') 238 | parser.add_argument('--shot_num', type=int, default=5, help='The shot number for training') 239 | parser.add_argument('--train_task_map', type=str, default=None, help='The path of the task map json file between ' 240 | 'persona ids and sample ids') 241 | parser.add_argument('--valid_task_map', type=str, default=None, help='The path of the task map json file between ' 242 | 'persona ids and sample ids') 243 | parser.add_argument('--test_task_map', type=str, default=None, help='The path of the task map json file between ' 244 | 'persona ids and sample ids') 245 | parser.add_argument('--meta_batch_size', type=int, default=16, help='The meta batch size') 246 | parser.add_argument('--meta_batch_ratio', type=float, default=1, help='The training and evaluation sample number' 247 | ' ratio in each batch') 248 | parser.add_argument('--full_input', action='store_true', help='whether to use the concatenated persona, history' 249 | ' and reply as the input ids') 250 | parser.add_argument('--max_history_size', type=int, default=-1, help='max history size in input ids') 251 | parser.add_argument('--same_embedding_lm', type=int, default=1, help='whether the embedding in the transformer and the ' 252 | 'weight in the LM head are the same') 253 | parser.add_argument('--uncertainty_loss', action='store_true', help='whether to use the uncertainty loss') 254 | # parser.add_argument('--attention_weight', action='store_true', help='Whether use the learnable weight for ' 255 | # 'attention obtained from different source in decoder') 256 | parser.add_argument('--model_type', type=str, default='gpt', help='gpt/gpt2/seq2seq/rnn-seq2seq') 257 | parser.add_argument('--model_saving_interval', type=int, default=10, help='model saving interval for seq2seq') 258 | parser.add_argument('--entail_score_refs_file', type=str, default=None, help='The persona and idx json file for each ' 259 | 'utterance; if None no entailment score will be calculated') 260 | parser.add_argument('--entail_model_path', type=str, default='./analysis/roberta_nli', 261 | help='The path of the NLI model used to calculate the entailment score') 262 | parser.add_argument('--bert_score_model_path', type=str, default='../roberta_large', help='The model path ' 263 | 'indicating which model will be used for BERTScore; if None then no BERTScore is calculated') 264 | parser.add_argument('--baseline_path', type=str, default='./bert_score/rescale_baseline/en/roberta-large.tsv', 265 | help='The file path of the BERTScore rescaling baseline') 266 | parser.add_argument('--rescale_with_baseline', action='store_true', 267 | help='Whether to rescale the BERTScore') 268 | 
parser.add_argument('--ignore_sample_indices', type=str, default=None, 269 | help='The json file indicating which samples will be ignored') 270 | parser.add_argument('--shared_module', type=int, default=1) 271 | parser.add_argument('--shared_attention', type=int, default=1) 272 | parser.add_argument('--attention_pooling_type', type=str, default='mean', help='the method to pool the attention ' 273 | 'outputs from different sources (mean/min/max/sw/dw/linear/att): ' 274 | 'sw=source-level weight, dw=dimension-level weight, linear=linear transform for concatenation, ' 275 | 'att=extra transformer attention layer to fuse the attention outputs, ' 276 | 'dys=dynamically determine the scalar weight for each source by a linear layer, ' 277 | 'dyd=dynamically determine the vector weight for each dimension of each source by a linear layer, ' 278 | 'mdys=mutually determine the scalar weight for each source by a linear layer, ' 279 | 'mdyd=mutually and dynamically determine the vector weight for each dimension of each source by a linear layer') 280 | 281 | '''Configurations related to curriculum learning''' 282 | parser.add_argument('--curriculum_learning', action='store_true', help='Whether to do curriculum learning') 283 | parser.add_argument('--curriculum_train_datasets', type=str, default=None) 284 | parser.add_argument('--curriculum_valid_datasets', type=str, default=None) 285 | parser.add_argument('--curriculum_train_datasets_cache', type=str, default=None) 286 | parser.add_argument('--curriculum_valid_datasets_cache', type=str, default=None) 287 | parser.add_argument('--curriculum_lr', type=float, default=2e-4) 288 | parser.add_argument('--curriculum_n_epochs', type=int, default=50) 289 | parser.add_argument('--curriculum_patience', type=int, default=-1) 290 | parser.add_argument('--curriculum_max_history_size', type=int, default=3) 291 | parser.add_argument('--curriculum_data_type', default='entailment', type=str) 292 | parser.add_argument('--curriculum_reverse', action='store_true', help='Whether to use the reverse order of training') 293 | 294 | parser.add_argument('--local_rank', type=int, default=-1, help="Distributed training.") 295 | parser.add_argument('--server_ip', type=str, default='', help="Used for debugging on GPU machine.") 296 | parser.add_argument('--server_port', type=str, default='', help="Used for debugging on GPU machine.") 297 | 298 | self.args = parser.parse_args() 299 | -------------------------------------------------------------------------------- /bert_score/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | from math import log 5 | from itertools import chain 6 | from collections import defaultdict, Counter 7 | from multiprocessing import Pool 8 | from functools import partial 9 | from tqdm.auto import tqdm 10 | from torch.nn.utils.rnn import pad_sequence 11 | from distutils.version import LooseVersion 12 | 13 | from transformers import BertConfig, XLNetConfig, XLMConfig, RobertaConfig 14 | from transformers import AutoModel, GPT2Tokenizer 15 | 16 | from . 
import __version__ 17 | from transformers import __version__ as trans_version 18 | 19 | __all__ = [] 20 | 21 | SCIBERT_URL_DICT = { 22 | "scibert-scivocab-uncased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_scivocab_uncased.tar", # recommend by the SciBERT authors 23 | "scibert-scivocab-cased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_scivocab_cased.tar", 24 | "scibert-basevocab-uncased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_basevocab_uncased.tar", 25 | "scibert-basevocab-cased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_basevocab_cased.tar", 26 | } 27 | 28 | 29 | lang2model = defaultdict(lambda: "bert-base-multilingual-cased") 30 | lang2model.update( 31 | {"en": "roberta-large", "zh": "bert-base-chinese", "en-sci": "scibert-scivocab-uncased",} 32 | ) 33 | 34 | 35 | model2layers = { 36 | "bert-base-uncased": 9, # 0.6925188074454226 37 | "bert-large-uncased": 18, # 0.7210358126642836 38 | "bert-base-cased-finetuned-mrpc": 9, # 0.6721947475618048 39 | "bert-base-multilingual-cased": 9, # 0.6680687802637132 40 | "bert-base-chinese": 8, 41 | "roberta-base": 10, # 0.706288719158983 42 | "roberta-large": 17, # 0.7385974720781534 43 | "roberta-large-mnli": 19, # 0.7535618640417984 44 | "roberta-base-openai-detector": 7, # 0.7048158349432633 45 | "roberta-large-openai-detector": 15, # 0.7462770207355116 46 | "xlnet-base-cased": 5, # 0.6630103662114238 47 | "xlnet-large-cased": 7, # 0.6598800720297179 48 | "xlm-mlm-en-2048": 6, # 0.651262570131464 49 | "xlm-mlm-100-1280": 10, # 0.6475166424401905 50 | "scibert-scivocab-uncased": 8, # 0.6590354319927313 51 | "scibert-scivocab-cased": 9, # 0.6536375053937445 52 | "scibert-basevocab-uncased": 9, # 0.6748944832703548 53 | "scibert-basevocab-cased": 9, # 0.6524624150542374 54 | "distilroberta-base": 5, # 0.6797558139322964 55 | "distilbert-base-uncased": 5, # 0.6756659152782033 56 | "distilbert-base-uncased-distilled-squad": 4, # 0.6718318036382493 57 | "distilbert-base-multilingual-cased": 5, # 0.6178131050889238 58 | "albert-base-v1": 10, # 0.654237567249745 59 | "albert-large-v1": 17, # 0.6755890754323239 60 | "albert-xlarge-v1": 16, # 0.7031844211905911 61 | "albert-xxlarge-v1": 8, # 0.7508642218461096 62 | "albert-base-v2": 9, # 0.6682455591837927 63 | "albert-large-v2": 14, # 0.7008537594374035 64 | "albert-xlarge-v2": 13, # 0.7317228357869254 65 | "albert-xxlarge-v2": 8, # 0.7505160257184014 66 | "xlm-roberta-base": 9, # 0.6506799445871697 67 | "xlm-roberta-large": 17, # 0.6941551437476826 68 | "google/electra-small-generator": 9, # 0.6659421842117754 69 | "google/electra-small-discriminator": 11, # 0.6534639151385759 70 | "google/electra-base-generator": 10, # 0.6730033453857188 71 | "google/electra-base-discriminator": 9, # 0.7032089590812965 72 | "google/electra-large-generator": 18, # 0.6813370013104459 73 | "google/electra-large-discriminator": 14, # 0.6896675824733477 74 | "google/bert_uncased_L-2_H-128_A-2": 1, # 0.5887998733228855 75 | "google/bert_uncased_L-2_H-256_A-4": 1, # 0.6114863547661203 76 | "google/bert_uncased_L-2_H-512_A-8": 1, # 0.6177345529192847 77 | "google/bert_uncased_L-2_H-768_A-12": 2, # 0.6191261237956839 78 | "google/bert_uncased_L-4_H-128_A-2": 3, # 0.6076202863798991 79 | "google/bert_uncased_L-4_H-256_A-4": 3, # 0.6205239036810148 80 | "google/bert_uncased_L-4_H-512_A-8": 3, # 0.6375351621856903 81 | "google/bert_uncased_L-4_H-768_A-12": 3, # 
0.6561849979644787 82 | "google/bert_uncased_L-6_H-128_A-2": 5, # 0.6200458425360283 83 | "google/bert_uncased_L-6_H-256_A-4": 5, # 0.6277501629539081 84 | "google/bert_uncased_L-6_H-512_A-8": 5, # 0.641952305130849 85 | "google/bert_uncased_L-6_H-768_A-12": 5, # 0.6762186226247106 86 | "google/bert_uncased_L-8_H-128_A-2": 7, # 0.6186876506711779 87 | "google/bert_uncased_L-8_H-256_A-4": 7, # 0.6447993208267708 88 | "google/bert_uncased_L-8_H-512_A-8": 6, # 0.6489729408169956 89 | "google/bert_uncased_L-8_H-768_A-12": 7, # 0.6705203359541737 90 | "google/bert_uncased_L-10_H-128_A-2": 8, # 0.6126762064125278 91 | "google/bert_uncased_L-10_H-256_A-4": 8, # 0.6376350032576573 92 | "google/bert_uncased_L-10_H-512_A-8": 9, # 0.6579006292799915 93 | "google/bert_uncased_L-10_H-768_A-12": 8, # 0.6861146692220176 94 | "google/bert_uncased_L-12_H-128_A-2": 10, # 0.6184105693383591 95 | "google/bert_uncased_L-12_H-256_A-4": 11, # 0.6374004994430261 96 | "google/bert_uncased_L-12_H-512_A-8": 10, # 0.65880012149526 97 | "google/bert_uncased_L-12_H-768_A-12": 9, # 0.675911357700092 98 | } 99 | 100 | 101 | def sent_encode(tokenizer, sent): 102 | "Encoding as sentence based on the tokenizer" 103 | sent = sent.strip() 104 | if sent == "": 105 | return tokenizer.build_inputs_with_special_tokens([]) 106 | elif isinstance(tokenizer, GPT2Tokenizer): 107 | # for RoBERTa and GPT-2 108 | import transformers 109 | 110 | if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): 111 | return tokenizer.encode( 112 | sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True 113 | ) 114 | else: 115 | return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len) 116 | else: 117 | import transformers 118 | 119 | if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): 120 | return tokenizer.encode(sent, add_special_tokens=True, truncation=True) 121 | else: 122 | return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len) 123 | 124 | 125 | def get_model(model_type, num_layers, all_layers=None): 126 | print(model_type) 127 | if model_type.startswith("scibert"): 128 | model = AutoModel.from_pretrained(cache_scibert(model_type)) 129 | else: 130 | model = AutoModel.from_pretrained(model_type) 131 | model.eval() 132 | 133 | # drop unused layers 134 | if not all_layers: 135 | if hasattr(model, "n_layers"): # xlm 136 | assert ( 137 | 0 <= num_layers <= model.n_layers 138 | ), f"Invalid num_layers: num_layers should be between 0 and {model.n_layers} for {model_type}" 139 | model.n_layers = num_layers 140 | elif hasattr(model, "layer"): # xlnet 141 | assert ( 142 | 0 <= num_layers <= len(model.layer) 143 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.layer)} for {model_type}" 144 | model.layer = torch.nn.ModuleList([layer for layer in model.layer[:num_layers]]) 145 | elif hasattr(model, "encoder"): # albert 146 | if hasattr(model.encoder, "albert_layer_groups"): 147 | assert ( 148 | 0 <= num_layers <= model.encoder.config.num_hidden_layers 149 | ), f"Invalid num_layers: num_layers should be between 0 and {model.encoder.config.num_hidden_layers} for {model_type}" 150 | model.encoder.config.num_hidden_layers = num_layers 151 | else: # bert, roberta 152 | assert ( 153 | 0 <= num_layers <= len(model.encoder.layer) 154 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.encoder.layer)} for {model_type}" 155 | model.encoder.layer = torch.nn.ModuleList([layer 
for layer in model.encoder.layer[:num_layers]]) 156 | elif hasattr(model, "transformer"): # bert, roberta 157 | assert ( 158 | 0 <= num_layers <= len(model.transformer.layer) 159 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.transformer.layer)} for {model_type}" 160 | model.transformer.layer = torch.nn.ModuleList([layer for layer in model.transformer.layer[:num_layers]]) 161 | else: 162 | raise ValueError("Not supported") 163 | else: 164 | if hasattr(model, "output_hidden_states"): 165 | model.output_hidden_states = True 166 | elif hasattr(model, "encoder"): 167 | model.encoder.output_hidden_states = True 168 | elif hasattr(model, "transformer"): 169 | model.transformer.output_hidden_states = True 170 | else: 171 | raise ValueError(f"Not supported model architecture: {model_type}") 172 | 173 | return model 174 | 175 | 176 | def padding(arr, pad_token, dtype=torch.long): 177 | lens = torch.LongTensor([len(a) for a in arr]) 178 | max_len = lens.max().item() 179 | padded = torch.ones(len(arr), max_len, dtype=dtype) * pad_token 180 | mask = torch.zeros(len(arr), max_len, dtype=torch.long) 181 | for i, a in enumerate(arr): 182 | padded[i, : lens[i]] = torch.tensor(a, dtype=dtype) 183 | mask[i, : lens[i]] = 1 184 | return padded, lens, mask 185 | 186 | 187 | def bert_encode(model, x, attention_mask, all_layers=False): 188 | model.eval() 189 | with torch.no_grad(): 190 | out = model(x, attention_mask=attention_mask) 191 | if all_layers: 192 | emb = torch.stack(out[-1], dim=2) 193 | else: 194 | emb = out[0] 195 | return emb 196 | 197 | 198 | def process(a, tokenizer=None): 199 | if tokenizer is not None: 200 | a = sent_encode(tokenizer, a) 201 | return set(a) 202 | 203 | 204 | def get_idf_dict(arr, tokenizer, nthreads=4): 205 | """ 206 | Returns mapping from word piece index to its inverse document frequency. 207 | 208 | 209 | Args: 210 | - :param: `arr` (list of str) : sentences to process. 211 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 212 | - :param: `nthreads` (int) : number of CPU threads to use 213 | """ 214 | idf_count = Counter() 215 | num_docs = len(arr) 216 | 217 | process_partial = partial(process, tokenizer=tokenizer) 218 | 219 | with Pool(nthreads) as p: 220 | idf_count.update(chain.from_iterable(p.map(process_partial, arr))) 221 | 222 | idf_dict = defaultdict(lambda: log((num_docs + 1) / (1))) 223 | idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()}) 224 | return idf_dict 225 | 226 | 227 | def collate_idf(arr, tokenizer, idf_dict, device="cuda:0"): 228 | """ 229 | Helper function that pads a list of sentences to hvae the same length and 230 | loads idf score for words in the sentences. 231 | 232 | Args: 233 | - :param: `arr` (list of str): sentences to process. 234 | - :param: `tokenize` : a function that takes a string and return list 235 | of tokens. 236 | - :param: `numericalize` : a function that takes a list of tokens and 237 | return list of token indexes. 238 | - :param: `idf_dict` (dict): mapping a word piece index to its 239 | inverse document frequency 240 | - :param: `pad` (str): the padding token. 241 | - :param: `device` (str): device to use, e.g. 
'cpu' or 'cuda' 242 | """ 243 | arr = [sent_encode(tokenizer, a) for a in arr] 244 | 245 | idf_weights = [[idf_dict[i] for i in a] for a in arr] 246 | 247 | pad_token = tokenizer.pad_token_id 248 | 249 | padded, lens, mask = padding(arr, pad_token, dtype=torch.long) 250 | padded_idf, _, _ = padding(idf_weights, 0, dtype=torch.float) 251 | 252 | padded = padded.to(device=device) 253 | mask = mask.to(device=device) 254 | lens = lens.to(device=device) 255 | return padded, padded_idf, lens, mask 256 | 257 | 258 | def get_bert_embedding(all_sens, model, tokenizer, idf_dict, batch_size=-1, device="cuda:0", all_layers=False): 259 | """ 260 | Compute BERT embedding in batches. 261 | 262 | Args: 263 | - :param: `all_sens` (list of str) : sentences to encode. 264 | - :param: `model` : a BERT model from `pytorch_pretrained_bert`. 265 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 266 | - :param: `idf_dict` (dict) : mapping a word piece index to its 267 | inverse document frequency 268 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda' 269 | """ 270 | 271 | padded_sens, padded_idf, lens, mask = collate_idf(all_sens, tokenizer, idf_dict, device=device) 272 | 273 | if batch_size == -1: 274 | batch_size = len(all_sens) 275 | 276 | embeddings = [] 277 | with torch.no_grad(): 278 | for i in range(0, len(all_sens), batch_size): 279 | batch_embedding = bert_encode( 280 | model, padded_sens[i : i + batch_size], attention_mask=mask[i : i + batch_size], all_layers=all_layers 281 | ) 282 | embeddings.append(batch_embedding) 283 | del batch_embedding 284 | 285 | total_embedding = torch.cat(embeddings, dim=0) 286 | 287 | return total_embedding, mask, padded_idf 288 | 289 | 290 | def greedy_cos_idf(ref_embedding, ref_masks, ref_idf, hyp_embedding, hyp_masks, hyp_idf, all_layers=False): 291 | """ 292 | Compute greedy matching based on cosine similarity. 293 | 294 | Args: 295 | - :param: `ref_embedding` (torch.Tensor): 296 | embeddings of reference sentences, BxKxd, 297 | B: batch size, K: longest length, d: bert dimenison 298 | - :param: `ref_lens` (list of int): list of reference sentence length. 299 | - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for 300 | reference sentences. 301 | - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word 302 | piece in the reference setence 303 | - :param: `hyp_embedding` (torch.Tensor): 304 | embeddings of candidate sentences, BxKxd, 305 | B: batch size, K: longest length, d: bert dimenison 306 | - :param: `hyp_lens` (list of int): list of candidate sentence length. 307 | - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for 308 | candidate sentences. 
309 | - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word 310 | piece in the candidate setence 311 | """ 312 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1)) 313 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1)) 314 | 315 | if all_layers: 316 | B, _, L, D = hyp_embedding.size() 317 | hyp_embedding = hyp_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, hyp_embedding.size(1), D) 318 | ref_embedding = ref_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, ref_embedding.size(1), D) 319 | batch_size = ref_embedding.size(0) 320 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 321 | masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float()) 322 | if all_layers: 323 | masks = masks.unsqueeze(0).expand(L, -1, -1, -1).contiguous().view_as(sim) 324 | else: 325 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim) 326 | 327 | masks = masks.float().to(sim.device) 328 | sim = sim * masks 329 | 330 | word_precision = sim.max(dim=2)[0] 331 | word_recall = sim.max(dim=1)[0] 332 | 333 | hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True)) 334 | ref_idf.div_(ref_idf.sum(dim=1, keepdim=True)) 335 | precision_scale = hyp_idf.to(word_precision.device) 336 | recall_scale = ref_idf.to(word_recall.device) 337 | if all_layers: 338 | precision_scale = precision_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_precision) 339 | recall_scale = recall_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_recall) 340 | P = (word_precision * precision_scale).sum(dim=1) 341 | R = (word_recall * recall_scale).sum(dim=1) 342 | F = 2 * P * R / (P + R) 343 | 344 | hyp_zero_mask = hyp_masks.sum(dim=1).eq(2) 345 | ref_zero_mask = ref_masks.sum(dim=1).eq(2) 346 | 347 | if all_layers: 348 | P = P.view(L, B) 349 | R = R.view(L, B) 350 | F = F.view(L, B) 351 | 352 | if torch.any(hyp_zero_mask): 353 | print("Warning: Empty candidate sentence detected; setting precision to be 0.", file=sys.stderr) 354 | P = P.masked_fill(hyp_zero_mask, 0.0) 355 | 356 | if torch.any(ref_zero_mask): 357 | print("Warning: Empty reference sentence detected; setting recall to be 0.", file=sys.stderr) 358 | R = R.masked_fill(ref_zero_mask, 0.0) 359 | 360 | F = F.masked_fill(torch.isnan(F), 0.0) 361 | 362 | return P, R, F 363 | 364 | 365 | def bert_cos_score_idf( 366 | model, refs, hyps, tokenizer, idf_dict, verbose=False, batch_size=64, device="cuda:0", all_layers=False 367 | ): 368 | """ 369 | Compute BERTScore. 370 | 371 | Args: 372 | - :param: `model` : a BERT model in `pytorch_pretrained_bert` 373 | - :param: `refs` (list of str): reference sentences 374 | - :param: `hyps` (list of str): candidate sentences 375 | - :param: `tokenzier` : a BERT tokenizer corresponds to `model` 376 | - :param: `idf_dict` : a dictionary mapping a word piece index to its 377 | inverse document frequency 378 | - :param: `verbose` (bool): turn on intermediate status update 379 | - :param: `batch_size` (int): bert score processing batch size 380 | - :param: `device` (str): device to use, e.g. 
'cpu' or 'cuda' 381 | """ 382 | preds = [] 383 | 384 | def dedup_and_sort(l): 385 | return sorted(list(set(l)), key=lambda x: len(x.split(" ")), reverse=True) 386 | 387 | sentences = dedup_and_sort(refs + hyps) 388 | embs = [] 389 | iter_range = range(0, len(sentences), batch_size) 390 | print(len(sentences)) 391 | if verbose: 392 | print("computing bert embedding.") 393 | iter_range = tqdm(iter_range) 394 | stats_dict = dict() 395 | for batch_start in tqdm(iter_range): 396 | sen_batch = sentences[batch_start : batch_start + batch_size] 397 | embs, masks, padded_idf = get_bert_embedding( 398 | sen_batch, model, tokenizer, idf_dict, device=device, all_layers=all_layers 399 | ) 400 | embs = embs.cpu() 401 | masks = masks.cpu() 402 | padded_idf = padded_idf.cpu() 403 | for i, sen in enumerate(sen_batch): 404 | sequence_len = masks[i].sum().item() 405 | emb = embs[i, :sequence_len] 406 | idf = padded_idf[i, :sequence_len] 407 | stats_dict[sen] = (emb, idf) 408 | 409 | def pad_batch_stats(sen_batch, stats_dict, device): 410 | stats = [stats_dict[s] for s in sen_batch] 411 | emb, idf = zip(*stats) 412 | emb = [e.to(device) for e in emb] 413 | idf = [i.to(device) for i in idf] 414 | lens = [e.size(0) for e in emb] 415 | emb_pad = pad_sequence(emb, batch_first=True, padding_value=2.0) 416 | idf_pad = pad_sequence(idf, batch_first=True) 417 | 418 | def length_to_mask(lens): 419 | lens = torch.tensor(lens, dtype=torch.long) 420 | max_len = max(lens) 421 | base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len) 422 | return base < lens.unsqueeze(1) 423 | 424 | pad_mask = length_to_mask(lens).to(device) 425 | return emb_pad, pad_mask, idf_pad 426 | 427 | device = next(model.parameters()).device 428 | iter_range = range(0, len(refs), batch_size) 429 | if verbose: 430 | print("computing greedy matching.") 431 | iter_range = tqdm(iter_range) 432 | 433 | with torch.no_grad(): 434 | for batch_start in iter_range: 435 | batch_refs = refs[batch_start : batch_start + batch_size] 436 | batch_hyps = hyps[batch_start : batch_start + batch_size] 437 | ref_stats = pad_batch_stats(batch_refs, stats_dict, device) 438 | hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device) 439 | 440 | P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, all_layers) 441 | preds.append(torch.stack((P, R, F1), dim=-1).cpu()) 442 | preds = torch.cat(preds, dim=1 if all_layers else 0) 443 | return preds 444 | 445 | 446 | def get_hash(model, num_layers, idf, rescale_with_baseline, use_custom_baseline): 447 | msg = "{}_L{}{}_version={}(hug_trans={})".format( 448 | model, num_layers, "_idf" if idf else "_no-idf", __version__, trans_version 449 | ) 450 | if rescale_with_baseline: 451 | if use_custom_baseline: 452 | msg += "-custom-rescaled" 453 | else: 454 | msg += "-rescaled" 455 | return msg 456 | 457 | 458 | def cache_scibert(model_type, cache_folder="~/.cache/torch/transformers"): 459 | if not model_type.startswith("scibert"): 460 | return model_type 461 | 462 | underscore_model_type = model_type.replace("-", "_") 463 | cache_folder = os.path.abspath(cache_folder) 464 | filename = os.path.join(cache_folder, underscore_model_type) 465 | 466 | # download SciBERT models 467 | if not os.path.exists(filename): 468 | cmd = f"mkdir -p {cache_folder}; cd {cache_folder};" 469 | cmd += f"wget {SCIBERT_URL_DICT[model_type]}; tar -xvf {underscore_model_type}.tar;" 470 | cmd += ( 471 | f"rm -f {underscore_model_type}.tar ; cd {underscore_model_type}; tar -zxvf weights.tar.gz; mv weights/* .;" 472 | ) 473 | cmd += f"rm -f 
weights.tar.gz; rmdir weights; mv bert_config.json config.json;" 474 | print(cmd) 475 | print(f"downloading {model_type} model") 476 | os.system(cmd) 477 | 478 | # fix the missing files in scibert 479 | json_file = os.path.join(filename, "special_tokens_map.json") 480 | if not os.path.exists(json_file): 481 | with open(json_file, "w") as f: 482 | print( 483 | '{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}', 484 | file=f, 485 | ) 486 | 487 | json_file = os.path.join(filename, "added_tokens.json") 488 | if not os.path.exists(json_file): 489 | with open(json_file, "w") as f: 490 | print("{}", file=f) 491 | 492 | if "uncased" in model_type: 493 | json_file = os.path.join(filename, "tokenizer_config.json") 494 | if not os.path.exists(json_file): 495 | with open(json_file, "w") as f: 496 | print('{"do_lower_case": true, "max_len": 512, "init_inputs": []}', file=f) 497 | 498 | return filename 499 | --------------------------------------------------------------------------------
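The utilities in bert_score/utils.py compose into a standalone BERTScore pipeline: pick a model and its recommended layer from model2layers, build an idf dictionary (or a uniform one to effectively disable idf weighting), and call bert_cos_score_idf. The snippet below is a minimal usage sketch, not a file from this repository; it assumes it is run from the repository root, that the installed transformers version is compatible with this vendored copy (sent_encode still reads tokenizer.max_len), and that the model choice, sentences, and batch size are illustrative only.

import torch
from collections import defaultdict
from transformers import AutoTokenizer
from bert_score.utils import get_model, bert_cos_score_idf, model2layers

model_type = "roberta-large"                    # any key of model2layers
num_layers = model2layers[model_type]           # layer listed above for this model
tokenizer = AutoTokenizer.from_pretrained(model_type)
model = get_model(model_type, num_layers)       # truncates the encoder to num_layers
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

refs = ["a short reference sentence"]           # illustrative references
hyps = ["a short candidate sentence"]           # illustrative candidates

# Uniform idf weights (idf weighting effectively disabled); special tokens weighted zero.
idf_dict = defaultdict(lambda: 1.0)
idf_dict[tokenizer.sep_token_id] = 0
idf_dict[tokenizer.cls_token_id] = 0

scores = bert_cos_score_idf(model, refs, hyps, tokenizer, idf_dict,
                            batch_size=16, device=device)   # (N, 3) tensor of P, R, F1 per pair
P, R, F1 = scores[:, 0], scores[:, 1], scores[:, 2]
print(P.item(), R.item(), F1.item())

For corpus-level idf weighting, get_idf_dict(refs, tokenizer) can replace the uniform dictionary above; baseline rescaling (the --rescale_with_baseline and --baseline_path options in config.py) is handled by the higher-level scorer in bert_score/score.py and scorer.py rather than by these utilities.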