├── model
│   ├── __init__.py
│   ├── postprocessing.py
│   ├── loss.py
│   ├── entailment_score.py
│   ├── utils.py
│   ├── seq2seq_vocab.py
│   └── optim.py
├── metrics
│   ├── 3rdparty
│   │   └── .create
│   ├── pred.txt
│   ├── ref.txt
│   ├── test_metrics.py
│   ├── README.md
│   ├── tokenizers.py
│   ├── multi-bleu.perl
│   └── dstc_example.py
├── bert_score
│   ├── __init__.py
│   ├── rescale_baseline
│   │   └── en
│   │       └── roberta-large.tsv
│   ├── score.py
│   ├── scorer.py
│   └── utils.py
├── Software.zip
├── framework.PNG
├── BT
│   ├── average_model.sh
│   ├── evaluate_back.sh
│   ├── preprocess.sh
│   ├── evaluate.sh
│   ├── train_en-fr.sh
│   ├── train_fr-en.sh
│   ├── bpe_split.py
│   ├── get_bt_input_file.py
│   └── recover.py
├── multi_gpt2
│   └── config.json
├── cal_novelty.py
├── inference_multi_gpt2.sh
├── train_gpt2.sh
├── train_multi_gpt2_distilled.sh
├── train_seq2seq.sh
├── data_manipulation
│   ├── prepare_model
│   │   ├── extract_personas_and_responses.py
│   │   ├── finetune_bert_and_gpt2.py
│   │   ├── train_nli_model.py
│   │   └── train_coherence_nli.py
│   ├── data_diversification
│   │   ├── get_augmented_scores.py
│   │   └── filter_augmented_data.py
│   └── data_distillation
│       ├── get_distilled_dataset.py
│       └── calculate_entailment.py
├── train_gpt2_D3.sh
├── train_seq2seq_D3.sh
├── attention_experiment
│   ├── get_target_samples.py
│   └── attention_exp.py
├── README.md
├── inference.py
└── config.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/metrics/3rdparty/.create:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/metrics/pred.txt:
--------------------------------------------------------------------------------
1 | i lkie this thing
2 |
--------------------------------------------------------------------------------
/metrics/ref.txt:
--------------------------------------------------------------------------------
1 | i do not like this
2 |
--------------------------------------------------------------------------------
/bert_score/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.6"
2 |
--------------------------------------------------------------------------------
/Software.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caoyu-noob/D3/HEAD/Software.zip
--------------------------------------------------------------------------------
/framework.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caoyu-noob/D3/HEAD/framework.PNG
--------------------------------------------------------------------------------
/BT/average_model.sh:
--------------------------------------------------------------------------------
1 | python ../fairseq/scripts/average_checkpoints.py \
2 | --input checkpoints/enfr/ \
3 | --num-epoch-checkpoints 5 \
4 | --output ./checkpoints/enfr/model_avg.pt \
5 |
--------------------------------------------------------------------------------
/BT/evaluate_back.sh:
--------------------------------------------------------------------------------
1 | sub_num=$1
2 | CUDA_VISIBLE_DEVICES=$sub_num python ../fairseq/interactive.py ./corpus/fren/ \
3 | -s fr -t en \
4 | --path ./checkpoints/fren/model_avg.pt \
5 | --beam 5 \
6 | --nbest 5 \
7 | --batch-size 128 \
8 | --buffer-size 8000 \
9 | --input ./en-fr.out$sub_num > fr-en.log$sub_num \
10 |
--------------------------------------------------------------------------------
/BT/preprocess.sh:
--------------------------------------------------------------------------------
1 | ROOT=./
2 | SRC=en
3 | TGT=fr
4 | DICT=./
5 |
6 | python ../fairseq/preprocess.py \
7 | --source-lang en \
8 | --target-lang fr \
9 | --trainpref train \
10 | --validpref valid \
11 | --testpref test \
12 | --destdir ../corpus/enfr \
13 | --thresholdtgt 0 \
14 | --thresholdsrc 0 \
15 | --workers 90
--------------------------------------------------------------------------------
/BT/evaluate.sh:
--------------------------------------------------------------------------------
1 | sub_num=$1
2 | CUDA_VISIBLE_DEVICES=$sub_num python ../fairseq/interactive.py ./corpus/enfr/ \
3 | -s en -t fr \
4 | --path ./checkpoints/enfr/model_avg.pt \
5 | --beam 5 \
6 | --nbest 5 \
7 | --batch-size 128 \
8 | --buffer-size 8000 \
9 | --input ./th0.99_entail_dev_bpe.txt$sub_num > en-fr.log$sub_num \
10 |
--------------------------------------------------------------------------------
/multi_gpt2/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "GPT2LMHeadModel"
4 | ],
5 | "initializer_range": 0.02,
6 | "layer_norm_epsilon": 1e-05,
7 | "n_ctx": 1024,
8 | "n_embd": 768,
9 | "n_head": 12,
10 | "n_layer": 12,
11 | "n_positions": 1024,
12 | "vocab_size": 50257,
13 | "shared_module": false,
14 | "shared_attention": false,
15 | "context_size": 2
16 | }
--------------------------------------------------------------------------------
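Note: `shared_module`, `shared_attention`, and `context_size` are not standard GPT-2 configuration fields; they are presumably read by the project's multi-GPT2 model. A minimal loading sketch, assuming Hugging Face's `GPT2Config` is used (it keeps unknown keys as plain attributes); the repo's own loading code may differ:

```python
# Sketch only: load multi_gpt2/config.json and expose the custom fields as attributes.
import json
from transformers import GPT2Config

with open('multi_gpt2/config.json') as f:
    cfg = GPT2Config(**json.load(f))  # unknown keys (shared_module, context_size, ...) become attributes

print(cfg.n_layer, cfg.n_head, cfg.context_size, cfg.shared_module)
```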
/cal_novelty.py:
--------------------------------------------------------------------------------
1 | import json
2 | from metrics import cal_novelty
3 |
4 | with open('original_utterances.json', 'r') as f:
5 | original_utterances = json.load(f)
6 | with open('new_utterances.json', 'r') as f:
7 | new_utterances = json.load(f)
8 | un1, un2, un3, un4 = cal_novelty(original_utterances, new_utterances[40000:50000])
9 | print(un1)
10 | print(un2)
11 | print(un3)
12 | print(un4)
--------------------------------------------------------------------------------
/inference_multi_gpt2.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python inference.py \
2 | --test_datasets augmented_data/th0.99_self_without_response.json \
3 | --test_datasets_cache augmented_data/th0.99_self_without_response_gpt2_cache \
4 | --generate_file_name augmented_data/th0.99_self_with_generated_response.json \
5 | --data_type entailment \
6 | --load_last runs/multi_gpt2/best_model \
7 | --max_history_size 1 \
8 | --model_type gpt2 \
9 | --shared_module 0 \
10 | --shared_attention 0 \
11 | --attention_pooling_type sw \
12 |
--------------------------------------------------------------------------------
/metrics/test_metrics.py:
--------------------------------------------------------------------------------
1 | from metrics import *
2 | from tokenizers import *
3 |
4 | # evaluation
5 |
6 |
7 | nist, bleu, meteor, entropy, diversity, avg_len = nlp_metrics(
8 | path_refs=['demo/ref0.txt', 'demo/ref1.txt'],
9 | path_hyp='demo/hyp.txt')
10 |
11 | print(nist)
12 | print(bleu)
13 | print(meteor)
14 | print(entropy)
15 | print(diversity)
16 | print(avg_len)
17 |
18 | # tokenization
19 |
20 | s = " I don't know:). how about this?https://github.com/golsun/deep-RL-time-series"
21 | print(clean_str(s))
22 |
--------------------------------------------------------------------------------
/train_gpt2.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train.py \
2 | --train_batch_size 256 \
3 | --batch_split 32 \
4 | --n_epochs 15 \
5 | --lr 6.25e-5 \
6 | --train_datasets datasets/train_self_original.txt \
7 | --valid_datasets datasets/valid_self_original.txt \
8 | --test_datasets datasets/test_self_original.txt \
9 | --train_datasets_cache datasets/th0.99_entail_gpt2_self_cache \
10 | --valid_datasets_cache datasets/valid_gpt2_self_cache \
11 | --test_datasets_cache datasets/test_gpt2_self_cache \
12 | --model_type gpt2 \
13 | --single_input \
14 | --hits_weight 1 \
15 | --s2s_weight 1 \
16 | --negative_samples 1 \
17 | --patience 5 \
18 | --max_history_size 5 \
19 | --entail_score_refs_file ./persona_test_entailment_idx.json \
20 |
--------------------------------------------------------------------------------
/BT/train_en-fr.sh:
--------------------------------------------------------------------------------
1 | python ../fairseq/train.py ./enfr/ \
2 | --arch transformer_wmt_en_de \
3 | --criterion label_smoothed_cross_entropy \
4 | --label-smoothing 0.1 \
5 | --lr 7e-4 \
6 | --warmup-init-lr 1e-7 \
7 | --min-lr 1e-9 \
8 | --lr-scheduler inverse_sqrt \
9 | --warmup-updates 4000 \
10 | --optimizer adam \
11 | --adam-betas '(0.9,0.98)' \
12 | --adam-eps 1e-6 \
13 | -s en \
14 | -t fr \
15 | --max-tokens 8192 \
16 | --max-update 100000 \
17 | --weight-decay 0.01 \
18 | --seed 0 \
19 | --save-dir ./checkpoints/enfr/ \
20 | --ddp-backend=no_c10d \
21 | --fp16 \
22 | --update-freq 8 \
23 | --dropout 0.3 \
24 | --no-progress-bar \
25 | --log-format simple \
26 | --log-interval 50 \
27 | --save-interval-updates 4000 \
28 | --share-decoder-input-output-embed \
29 |
--------------------------------------------------------------------------------
/BT/train_fr-en.sh:
--------------------------------------------------------------------------------
1 | python ../fairseq/train.py ./fren/ \
2 | --arch transformer_wmt_en_de \
3 | --criterion label_smoothed_cross_entropy \
4 | --label-smoothing 0.1 \
5 | --lr 7e-4 \
6 | --warmup-init-lr 1e-7 \
7 | --min-lr 1e-9 \
8 | --lr-scheduler inverse_sqrt \
9 | --warmup-updates 4000 \
10 | --optimizer adam \
11 | --adam-betas '(0.9,0.98)' \
12 | --adam-eps 1e-6 \
13 | -s fr \
14 | -t en \
15 | --max-tokens 8192 \
16 | --max-update 100000 \
17 | --weight-decay 0.01 \
18 | --seed 0 \
19 | --save-dir ./checkpoints/fren/ \
20 | --ddp-backend=no_c10d \
21 | --fp16 \
22 | --update-freq 8 \
23 | --dropout 0.3 \
24 | --no-progress-bar \
25 | --log-format simple \
26 | --log-interval 50 \
27 | --save-interval-updates 4000 \
28 | --share-decoder-input-output-embed \
29 |
--------------------------------------------------------------------------------
/train_multi_gpt2_distilled.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train.py \
2 | --train_batch_size 256 \
3 | --batch_split 32 \
4 | --lr 5e-4 \
5 | --train_datasets datasets/th0.99_model_entail_train_data.json \
6 | --valid_datasets datasets/th0.99_model_entail_dev_data.json \
7 | --test_datasets datasets/th0.99_model_entail_dev_data.json \
8 | --train_datasets_cache datasets/th0.99_model_entail_train_gpt2_cache \
9 | --valid_datasets_cache datasets/th0.99_model_entail_valid_gpt2_cache \
10 | --test_datasets_cache datasets/th0.99_model_entail_valid_gpt2_cache \
11 | --model_type gpt2 \
12 | --data_type entailment \
13 | --shared_module 0 \
14 | --shared_attention 0 \
15 | --n_epochs 5 \
16 | --attention_pooling_type sw \
17 | --extra_module_lr_rate 5.0 \
18 | --max_history_size 1 \
19 |
--------------------------------------------------------------------------------
/BT/bpe_split.py:
--------------------------------------------------------------------------------
1 | import sentencepiece as spm
2 | import sys
3 |
4 | split_num = int(sys.argv[1])
5 | sp=spm.SentencePieceProcessor()
6 | sp.load('sentence.bpe.model')
7 |
8 | with open('th0.99_entail_dev.txt', 'r') as f:
9 | lines = f.readlines()
10 | new_lines = []
11 | for line in lines:
12 | nl = sp.EncodeAsPieces(line.strip())
13 | new_lines.append(' '.join(nl) + '\n')
14 | file_name = 'th0.99_entail_dev_bpe.txt'
15 | avg_size = len(new_lines) // split_num
16 | start = 0
17 | for i in range(split_num):
18 | if i != split_num - 1:
19 | sub_lines = new_lines[start:start + avg_size]
20 | else:
21 | sub_lines = new_lines[start:]
22 | start += avg_size
23 | with open(file_name + str(i), 'w') as f:
24 | f.writelines(sub_lines)
25 |
26 |
--------------------------------------------------------------------------------
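Usage sketch for the splitting step above (it assumes `sentence.bpe.model` and `th0.99_entail_dev.txt` are in the working directory, as hard-coded in the script):

```bash
# Split the BPE-encoded file into 4 shards for parallel decoding.
python bpe_split.py 4
# Produces th0.99_entail_dev_bpe.txt0 ... th0.99_entail_dev_bpe.txt3,
# which evaluate.sh consumes via --input ./th0.99_entail_dev_bpe.txt$sub_num.
```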
/train_seq2seq.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train.py \
2 | --train_batch_size 256 \
3 | --batch_split 8 \
4 | --n_epochs 100 \
5 | --lr 2e-4 \
6 | --train_datasets datasets/train_self_original.txt \
7 | --valid_datasets datasets/valid_self_original.txt \
8 | --test_datasets datasets/test_self_original.txt \
9 | --train_datasets_cache datasets/train_seq2seq_self_cache \
10 | --valid_datasets_cache datasets/valid_seq2seq_self_cache \
11 | --test_datasets_cache datasets/test_seq2seq_self_cache \
12 | --vocab_path datasets/vocab/persona_self_vocab.bin \
13 | --model_type seq2seq \
14 | --single_input \
15 | --pointer_gen \
16 | --label_smoothing 0.1 \
17 | --s2s_weight 1 \
18 | --negative_samples 0 \
19 | --max_history_size 5 \
20 | --patience 15 \
21 | --entail_score_refs_file ./persona_test_entailment_idx.json \
22 |
--------------------------------------------------------------------------------
/bert_score/rescale_baseline/en/roberta-large.tsv:
--------------------------------------------------------------------------------
1 | LAYER,P,R,F
2 | 0,0.3712891,0.37132213,0.36826715
3 | 1,0.67176163,0.6717439,0.6703483
4 | 2,0.70031923,0.7003052,0.69969934
5 | 3,0.7080897,0.7081011,0.707698
6 | 4,0.6976306,0.69762677,0.69710517
7 | 5,0.7187199,0.71873325,0.71828526
8 | 6,0.74678195,0.74678224,0.74642223
9 | 7,0.7772428,0.7772184,0.77691925
10 | 8,0.8021733,0.8021747,0.8019093
11 | 9,0.8067641,0.80678225,0.8065291
12 | 10,0.8366976,0.8367098,0.8364913
13 | 11,0.8163513,0.816369,0.8161064
14 | 12,0.8175406,0.8175611,0.81728977
15 | 13,0.82106245,0.8210674,0.82080233
16 | 14,0.81487834,0.8148861,0.8145652
17 | 15,0.8243552,0.8243522,0.8240494
18 | 16,0.8341641,0.8341684,0.833912
19 | 17,0.83150584,0.8314941,0.83122575
20 | 18,0.8314624,0.83146274,0.8311686
21 | 19,0.82761073,0.8276117,0.8273196
22 | 20,0.799873,0.79988,0.79956234
23 | 21,0.8082163,0.80819315,0.8079286
24 | 22,0.83196104,0.83195347,0.83174026
25 | 23,0.8408042,0.8408027,0.8405716
26 | 24,0.96022236,0.96021587,0.960168
--------------------------------------------------------------------------------
/model/postprocessing.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 |
4 | import nltk
5 | from nltk.corpus import wordnet
6 |
7 |
8 | def augment_replica(seq):
9 | _exceptions = ['your', 'persona']
10 | pos2wn = {'NN': wordnet.NOUN,
11 | 'JJ': wordnet.ADJ,
12 | 'VBP': wordnet.VERB,
13 | 'RB': wordnet.ADV}
14 |
15 | synonyms = defaultdict(list)
16 |
17 | tagged_seq = seq.replace('i ', 'I ')
18 | tagged_seq = nltk.pos_tag(nltk.word_tokenize(tagged_seq))
19 |
20 | for word, pos in tagged_seq:
21 | if pos not in pos2wn or word in _exceptions:
22 | continue
23 |
24 | pos = pos2wn[pos]
25 | synnets = wordnet.synsets(word, pos=pos)
26 |
27 | for synnet in synnets:
28 | for syn in synnet.lemma_names():
29 | if syn != word:
30 | synonyms[word].append(syn.replace('_', ' '))
31 | break
32 | if synonyms:
33 | for key, values in synonyms.items():
34 | seq = seq.replace(key, random.choice(list(values)))
35 |
36 | return seq
--------------------------------------------------------------------------------
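A quick usage sketch for `augment_replica` (it needs the NLTK `punkt`, `averaged_perceptron_tagger`, and `wordnet` resources; the replacement is random, so the commented output is only illustrative):

```python
# Sketch only: synonym-based augmentation of a single utterance.
import nltk
for pkg in ('punkt', 'averaged_perceptron_tagger', 'wordnet'):
    nltk.download(pkg, quiet=True)

from model.postprocessing import augment_replica

print(augment_replica('i like to garden on the weekend'))
# e.g. "i wish to garden on the weekend" -- nouns/adjectives/verbs/adverbs may be
# swapped for a random WordNet synonym; words in _exceptions are left untouched.
```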
/data_manipulation/prepare_model/extract_personas_and_responses.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | INPUT_FILE = '../../datasets/train_self_original.txt'
4 |
5 | with open(INPUT_FILE, 'r') as f:
6 | lines = f.readlines()
7 | personas, utterances = [], []
8 | for line in lines:
9 | line = line.strip()
10 | if len(line) == 0:
11 | continue
12 | space_idx = line.find(' ')
13 | if space_idx == -1:
14 | dialog_idx = int(line)
15 | else:
16 | dialog_idx = int(line[:space_idx])
17 | dialog_line = line[space_idx + 1:].split('\t')
18 | dialog_line = [l.strip() for l in dialog_line]
19 |
20 | if dialog_line[0].startswith('your persona:'):
21 | persona_info = dialog_line[0].replace('your persona: ', '')
22 | if persona_info[-1] == '.' and persona_info[-2] != ' ':
23 | persona_info = persona_info[:-1] + ' .'
24 | personas.append(persona_info)
25 | elif len(dialog_line) > 1:
26 | utterances.append(dialog_line[1])
27 |
28 | with open('personas.json', 'w') as f:
29 | json.dump(list(set(personas)), f)
30 | with open('responses.json', 'w') as f:
31 | json.dump(list(set(utterances)), f)
--------------------------------------------------------------------------------
/train_gpt2_D3.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train.py \
2 | --train_batch_size 256 \
3 | --batch_split 32 \
4 | --n_epochs 15 \
5 | --lr 6.25e-5 \
6 | --train_datasets datasets/train_self_original.txt \
7 | --valid_datasets datasets/valid_self_original.txt \
8 | --test_datasets datasets/test_self_original.txt \
9 | --train_datasets_cache datasets/train_gpt2_self_cache \
10 | --valid_datasets_cache datasets/valid_gpt2_self_cache \
11 | --test_datasets_cache datasets/test_gpt2_self_cache \
12 | --model_type gpt2 \
13 | --single_input \
14 | --hits_weight 1 \
15 | --s2s_weight 1 \
16 | --negative_samples 1 \
17 | --patience 5 \
18 | --max_history_size 5 \
19 | --curriculum_learning \
20 | --curriculum_train_datasets datasets/augmented/th0.99_self_augmented.json \
21 | --curriculum_train_datasets_cache datasets/augmented/th0.99_self_augmented_gpt2_cache \
22 | --curriculum_valid_datasets datasets/augmented/th0.99_dev_augmented.json \
23 | --curriculum_valid_datasets_cache datasets/augmented/th0.99_dev_augmented_gpt2_cache \
24 | --curriculum_lr 6.25e-5 \
25 | --curriculum_n_epochs 15 \
26 | --curriculum_patience 5 \
27 | --curriculum_max_history_size 1 \
28 | --entail_score_refs_file ./persona_test_entailment_idx.json \
29 |
--------------------------------------------------------------------------------
/train_seq2seq_D3.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train.py \
2 | --train_batch_size 256 \
3 | --batch_split 8 \
4 | --n_epochs 100 \
5 | --lr 2e-4 \
6 | --train_datasets datasets/train_self_original.txt \
7 | --valid_datasets datasets/valid_self_original.txt \
8 | --test_datasets datasets/test_self_original.txt \
9 | --train_datasets_cache datasets/train_seq2seq_self_cache \
10 | --valid_datasets_cache datasets/valid_seq2seq_self_cache \
11 | --test_datasets_cache datasets/test_seq2seq_self_cache \
12 | --curriculum_learning \
13 | --curriculum_train_datasets datasets/augmented/th0.99_self_augmented.json \
14 | --curriculum_train_datasets_cache datasets/augmented/th0.99_self_augmented_seq2seq_cache \
15 | --curriculum_valid_datasets datasets/augmented/th0.99_dev_augmented.json \
16 | --curriculum_valid_datasets_cache datasets/augmented/th0.99_dev_augmented_seq2seq_cache \
17 | --curriculum_lr 2e-4 \
18 | --curriculum_n_epochs 50 \
19 | --curriculum_patience 15 \
20 | --curriculum_max_history_size 1 \
21 | --curriculum_data_type entailment \
22 | --extend_exist_vocab datasets/vocab/persona_self_vocab.bin \
23 | --vocab_path datasets/vocab/persona_self_augmented_vocab.bin \
24 | --model_type seq2seq \
25 | --single_input \
26 | --label_smoothing 0.1 \
27 | --pointer_gen \
28 | --s2s_weight 1 \
29 | --patience 15 \
30 | --negative_samples 0 \
31 | --max_history_size 5 \
32 | --entail_score_refs_file ./persona_test_entailment_idx.json \
33 |
--------------------------------------------------------------------------------
/BT/get_bt_input_file.py:
--------------------------------------------------------------------------------
1 | import json
2 | from sys import argv
3 |
4 | # The input file of the distilled data (json file) or the original file
5 | input_file = argv[1]
6 | # The output file that only save the utterances of the input data for back-translation
7 | output_file = argv[2]
8 |
9 | new_lines = []
10 | line_idx = []
11 | if '.txt' in input_file:
12 | with open(input_file, 'r') as f:
13 | lines = f.readlines()
14 | for i, line in enumerate(lines):
15 | persona_idx = line.find('your persona: ')
16 | if persona_idx != -1 and persona_idx < 5:
17 | new_lines.append(line[persona_idx + 14:])
18 | line_idx.append(i)
19 | else:
20 | space_idx = line.find(' ')
21 | history = line[space_idx + 1:].split('\t')
22 | new_lines.append(history[0].strip() + '\n')
23 | new_lines.append(history[1].strip() + '\n')
24 | line_idx.append(i)
25 | line_idx.append(i)
26 | if '.json' in input_file:
27 | with open(input_file, 'r') as f:
28 | data = json.load(f)
29 | for i, d in enumerate(data):
30 | new_lines.append(d[0].strip() + '\n')
31 | line_idx.append([i, 0])
32 | for j, u in enumerate(d[1]):
33 | new_lines.append(u.strip() + '\n')
34 | line_idx.append([i, 3 + j])
35 | new_lines.append(d[2].strip() + '\n')
36 | line_idx.append([i, 2])
37 | with open(output_file, 'w') as f:
38 | f.writelines(new_lines)
39 | with open(output_file + '_idx.json', 'w') as f:
40 | json.dump(line_idx, f)
--------------------------------------------------------------------------------
/metrics/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Adapted from the Automatic evaluation script for DSTC7 Task 2
3 | Original repo: https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling
4 |
5 | # To install 3rd party scripts
6 | Download the following 3rd-party packages and save in folder `3rdparty`:
7 |
8 | * [**mteval-v14c.pl**](ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl) to compute [NIST](http://www.mt-archive.info/HLT-2002-Doddington.pdf). You may need to install the following [perl](https://www.perl.org/get.html) modules (e.g. by `cpan install`): XML::Twig, Sort::Naturally and String::Util.
9 | * [**meteor-1.5**](http://www.cs.cmu.edu/~alavie/METEOR/download/meteor-1.5.tar.gz) to compute [METEOR](http://www.cs.cmu.edu/~alavie/METEOR/index.html). It requires [Java](https://www.java.com/en/download/help/download_options.xml).
10 |
11 |
12 | # What does it do?
13 | (Based on this [repo](https://github.com/golsun/NLP-tools) by [Sean Xiang Gao](https://www.linkedin.com/in/gxiang1228/))
14 |
15 | * **evaluation**: calculate automated NLP metrics (BLEU, NIST, METEOR, entropy, etc...)
16 | ```python
17 | from metrics import nlp_metrics
18 | nist, bleu, meteor, entropy, diversity, avg_len = nlp_metrics(
19 | path_refs=["demo/ref0.txt", "demo/ref1.txt"],
20 | path_hyp="demo/hyp.txt")
21 |
22 | # nist = [1.8338, 2.0838, 2.1949, 2.1949]
23 | # bleu = [0.4667, 0.441, 0.4017, 0.3224]
24 | # meteor = 0.2832
25 | # entropy = [2.5232, 2.4849, 2.1972, 1.7918]
26 | # diversity = [0.8667, 1.000]
27 | # avg_len = 5.0000
28 | ```
29 | * **tokenization**: clean string and deal with punctuation, contractions, urls, mentions, tags, etc
30 | ```python
31 | from tokenizers import clean_str
32 | s = " I don't know:). how about this?https://github.com"
33 | clean_str(s)
34 |
35 | # i do n't know :) . how about this ? __url__
36 | ```
37 |
--------------------------------------------------------------------------------
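The install section of `metrics/README.md` above asks for two third-party tools in `metrics/3rdparty`; a possible fetch script (URLs copied from the README, availability not guaranteed):

```bash
cd metrics/3rdparty
wget ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl
wget http://www.cs.cmu.edu/~alavie/METEOR/download/meteor-1.5.tar.gz
tar -xzf meteor-1.5.tar.gz
```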
/metrics/tokenizers.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao @ Microsoft Research, Oct 2018
2 | # clean and tokenize natural language text
3 |
4 | import re
5 | from util import *
6 | from nltk.tokenize import TweetTokenizer
7 |
8 | def clean_str(txt):
9 | #print("in=[%s]" % txt)
10 | txt = txt.lower()
11 | txt = re.sub('^',' ', txt)
12 | txt = re.sub('$',' ', txt)
13 |
14 | # url and tag
15 | words = []
16 | for word in txt.split():
17 | i = word.find('http')
18 | if i >= 0:
19 | word = word[:i] + ' ' + '__url__'
20 | words.append(word.strip())
21 | txt = ' '.join(words)
22 |
23 | # remove markdown URL
24 | txt = re.sub(r'\[([^\]]*)\] \( *__url__ *\)', r'\1', txt)
25 |
26 | # remove illegal char
27 | txt = re.sub('__url__','URL',txt)
28 | txt = re.sub(r"[^A-Za-z0-9():,.!?\"\']", " ", txt)
29 | txt = re.sub('URL','__url__',txt)
30 |
31 | # contraction
32 | add_space = ["'s", "'m", "'re", "n't", "'ll","'ve","'d","'em"]
33 | tokenizer = TweetTokenizer(preserve_case=False)
34 | txt = ' ' + ' '.join(tokenizer.tokenize(txt)) + ' '
35 | txt = txt.replace(" won't ", " will n't ")
36 | txt = txt.replace(" can't ", " can n't ")
37 | for a in add_space:
38 | txt = txt.replace(a+' ', ' '+a+' ')
39 |
40 | txt = re.sub(r'^\s+', '', txt)
41 | txt = re.sub(r'\s+$', '', txt)
42 | txt = re.sub(r'\s+', ' ', txt) # remove extra spaces
43 |
44 | #print("out=[%s]" % txt)
45 | return txt
46 |
47 |
48 | if __name__ == '__main__':
49 | ss = [
50 | " I don't know:). how about this?https://github.com/golsun/deep-RL-time-series",
51 | "please try [ GitHub ] ( https://github.com )",
52 | ]
53 | for s in ss:
54 | print(s)
55 | print(clean_str(s))
56 | print()
57 |
58 |
--------------------------------------------------------------------------------
/data_manipulation/prepare_model/finetune_bert_and_gpt2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from sys import argv
4 | from tqdm import tqdm
5 |
6 | from transformers.modeling_gpt2 import GPT2LMHeadModel
7 | from transformers.tokenization_gpt2 import GPT2Tokenizer
8 | from transformers.modeling_bert import BertForPreTraining
9 | from transformers.tokenization_bert import BertTokenizer
10 | from transformers import AdamW
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data import TensorDataset
13 | from torch.nn import CrossEntropyLoss
14 |
15 | LR = 1e-5
16 | BATCH_SIZE = 32
17 | STEPS = 100
18 | input_file = argv[1]
19 | output_model = argv[2]
20 | base_model = 'BERT'
21 |
22 | BERT_MODEL_PATH = '../bert_model'
23 | GPT2_MODEL_PATH = '../gpt2-small'
24 |
25 | with open(input_file, 'r') as f:
26 | sentences = json.load(f)
27 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
28 | if base_model == 'GPT2':
29 | tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_PATH)
30 | tokenizer.pad_token = tokenizer.eos_token
31 | model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL_PATH)
32 | elif base_model == 'BERT':
33 | tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_PATH)
34 | tokenizer.pad_token = '[PAD]'
35 | model = BertForPreTraining.from_pretrained(BERT_MODEL_PATH)
36 | optimizer = AdamW(model.parameters(), lr=LR, correct_bias=True)
37 | model.to(device)
38 |
39 | all_inputs = tokenizer(sentences, return_tensors='pt', padding=True).data
40 | dataset = TensorDataset(all_inputs['input_ids'], all_inputs['attention_mask'])
41 | dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
42 | steps = 0
43 | s2s_loss = 0
44 | while True:
45 | tqdm_data = tqdm(dataloader)
46 | for batch in tqdm_data:
47 | optimizer.zero_grad()
48 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
49 | for k, v in inputs.items():
50 | inputs[k] = v.to(device)
51 | if base_model == 'GPT2':
52 | loss = model(**inputs, labels=inputs['input_ids'])[0]
53 | else:
54 | logits = model(**inputs)[0]
55 | loss_fct = CrossEntropyLoss()
56 | labels = inputs['input_ids']
57 |             labels = labels.masked_fill(labels == 0, -100)  # boolean mask; [PAD] (id 0) positions are ignored by the loss
58 | loss = loss_fct(logits.view(-1, model.config.vocab_size), labels.view(-1))
59 | loss.backward()
60 | optimizer.step()
61 | tqdm_data.set_postfix({'s2s_loss': loss.item()})
62 | steps += 1
63 | if steps > STEPS:
64 | break
65 | if steps > STEPS:
66 | break
67 | torch.save(model.state_dict(), output_model)
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | class LabelSmoothingLoss(nn.Module):
23 | def __init__(self, n_labels, smoothing=0.0, ignore_index=-100, size_average=True):
24 | super(LabelSmoothingLoss, self).__init__()
25 | assert 0 <= smoothing <= 1
26 |
27 | self.ignore_index = ignore_index
28 | self.confidence = 1 - smoothing
29 |
30 | if smoothing > 0:
31 | self.criterion = nn.KLDivLoss(size_average=size_average)
32 | n_ignore_idxs = 1 + (ignore_index >= 0)
33 | one_hot = torch.full((1, n_labels), fill_value=(smoothing / (n_labels - n_ignore_idxs)))
34 | if ignore_index >= 0:
35 | one_hot[0, ignore_index] = 0
36 | self.register_buffer('one_hot', one_hot)
37 | else:
38 | self.criterion = nn.NLLLoss(size_average=size_average, ignore_index=ignore_index)
39 |
40 | def forward(self, log_inputs, targets):
41 | if self.confidence < 1:
42 | tdata = targets.data
43 |
44 | tmp = self.one_hot.repeat(targets.shape[0], 1)
45 | tmp.scatter_(1, tdata.unsqueeze(1), self.confidence)
46 |
47 | if self.ignore_index >= 0:
48 | mask = torch.nonzero(tdata.eq(self.ignore_index)).squeeze(-1)
49 | if mask.numel() > 0:
50 | tmp.index_fill_(0, mask, 0)
51 |
52 | targets = tmp
53 |
54 | return self.criterion(log_inputs, targets)
55 |
56 | class SoftCrossEntropyLoss(nn.Module):
57 |     def __init__(self):
58 |         super(SoftCrossEntropyLoss, self).__init__()
59 |
60 | def forward(self, input, soft_targets, lengths):
61 | log_input = -F.log_softmax(input, dim=-1)
62 | loss = torch.mean(torch.sum(torch.sum(torch.mul(log_input, soft_targets), dim=-1), dim=-1) / (lengths + 0.01))
63 | return loss
64 |
--------------------------------------------------------------------------------
/model/entailment_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader, TensorDataset
3 | import json
4 | from tqdm import tqdm
5 | from typing import Union
6 |
7 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelWithLMHead
8 | from transformers.data.processors.utils import InputExample, InputFeatures
9 |
10 | class EntailmentScorer:
11 | def __init__(self, pred_file, entail_idx_file, model_path, device):
12 | with open(pred_file, 'r') as f:
13 | lines = f.readlines()
14 | self.all_preds = [line.strip() for line in lines]
15 | with open(entail_idx_file, 'r') as f:
16 | refs = json.load(f)
17 | self.all_data = []
18 | for ref in refs:
19 | persona = ref['persona']
20 | idx = ref['index']
21 | for i in idx:
22 | self.all_data.append([[p, self.all_preds[i]] for p in persona])
23 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
24 | self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
25 | self.model.to(device)
26 | self.device = device
27 |
28 | def _convert_examples_to_features(self, examples, tokenizer, label_list=None, max_length=128, output_mode=None,):
29 | if max_length is None:
30 | max_length = tokenizer.max_len
31 |
32 | label_map = {label: i for i, label in enumerate(label_list)}
33 |
34 | def label_from_example(example: InputExample) -> Union[int, float, None]:
35 | if example.label is None:
36 | return None
37 | if output_mode == "classification":
38 | return label_map[example.label]
39 | elif output_mode == "regression":
40 | return float(example.label)
41 | raise KeyError(output_mode)
42 |
43 | labels = [label_from_example(example) for example in examples]
44 |
45 | batch_encoding = tokenizer(
46 | [(example.text_a, example.text_b) for example in examples],
47 | max_length=max_length,
48 | padding="max_length",
49 | truncation=True,
50 | )
51 |
52 | features = []
53 | for i in range(len(examples)):
54 | inputs = {k: batch_encoding[k][i] for k in batch_encoding}
55 | feature = InputFeatures(**inputs, label=labels[i])
56 | features.append(feature)
57 |
58 | return features
59 |
60 | def calculate_entailment_score(self):
61 | self.model.eval()
62 | entailed_results = []
63 | with torch.no_grad():
64 | for i in tqdm(range(len(self.all_data))):
65 | cur_data = self.all_data[i]
66 | cnt = 0
67 | input_examples = []
68 | for sample in cur_data:
69 | input_examples.append(InputExample(str(cnt), sample[0], sample[1], '0'))
70 | cnt += 1
71 | features = self._convert_examples_to_features(
72 | input_examples,
73 | self.tokenizer,
74 | label_list=['0', '1'],
75 | max_length=128,
76 | output_mode='classification',
77 | )
78 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(self.device)
79 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(self.device)
80 | dataset = TensorDataset(all_input_ids, all_attention_mask)
81 | train_dataloader = DataLoader(dataset, batch_size=8)
82 | all_logits = None
83 | for batch in train_dataloader:
84 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
85 | outputs = self.model(**inputs)
86 | if all_logits is None:
87 | all_logits = outputs[0].detach()
88 | else:
89 | all_logits = torch.cat((all_logits, outputs[0]), dim=0)
90 | results = torch.argmax(all_logits, dim=1)
91 | entailed_results.append(torch.sum(results - 1).item())
92 | return sum(entailed_results)/len(entailed_results)
93 |
--------------------------------------------------------------------------------
/data_manipulation/data_diversification/get_augmented_scores.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import math
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 | from sys import argv
7 |
8 | from transformers.modeling_gpt2 import GPT2LMHeadModel
9 | from transformers.tokenization_gpt2 import GPT2Tokenizer
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data import TensorDataset
12 | from transformers import AutoModelForSequenceClassification
13 | from transformers import AutoTokenizer
14 | from transformers import glue_convert_examples_to_features as convert_examples_to_features
15 | from transformers.data.processors.utils import InputExample
16 |
17 | input_file = argv[1]
18 | output_file = argv[2]
19 | BATCH_SIZE = 16
20 |
21 | def calculate_ppls(responses, device):
22 | print('calculate PPL scores...')
23 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2_utterance_model')
24 | model = GPT2LMHeadModel.from_pretrained('./gpt2_utterance_model')
25 | model.to(device)
26 | tokenizer.pad_token = tokenizer.eos_token
27 | ppls = []
28 | for r in tqdm(responses):
29 | inputs = tokenizer(r, return_tensors='pt').data
30 | for k, v in inputs.items():
31 | inputs[k] = v.to(device)
32 | loss = model(**inputs, labels=inputs['input_ids'])[0]
33 | ppls.append(math.exp(loss.item()))
34 | return ppls
35 |
36 | def calculate_nli_scores(sentences1, sentences2, model, tokenizer, device):
37 | input_examples = []
38 | cnt = 0
39 | for i in range(len(sentences1)):
40 | input_examples.append(InputExample(str(cnt), sentences1[i], sentences2[i], '0'))
41 | cnt += 1
42 | features = convert_examples_to_features(
43 | input_examples,
44 | tokenizer,
45 | label_list=['0', '1', '2'],
46 | max_length=128,
47 | output_mode='classification',
48 | )
49 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)
50 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(device)
51 | dataset = TensorDataset(all_input_ids, all_attention_mask)
52 | dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
53 | model.eval()
54 | all_probs = None
55 | const = torch.tensor([-1, 0, 1], device=device)
56 | with torch.no_grad():
57 | for batch in tqdm(dataloader):
58 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
59 | outputs = model(**inputs)
60 | cur_scores = torch.sum(nn.Softmax(dim=-1)(outputs[0].detach()) * const, dim=-1)
61 | if all_probs is None:
62 | all_probs = cur_scores.cpu()
63 | else:
64 | all_probs = torch.cat((all_probs, cur_scores.cpu()), dim=0)
65 | scores = all_probs.tolist()
66 | return scores
67 |
68 | def calculate_entailment_scores(personas, responses, device):
69 | print('calculate entailment scores...')
70 | tokenizer = AutoTokenizer.from_pretrained('../roberta_mnli')
71 |     model = AutoModelForSequenceClassification.from_pretrained('../roberta_mnli').to(device)
72 | entailment_scores = calculate_nli_scores(personas, responses, model, tokenizer, device)
73 | return entailment_scores
74 |
75 | def calculate_coherence_scores(history, responses, device):
76 | print('calculate coherence scores...')
77 | tokenizer = AutoTokenizer.from_pretrained('../coherence_nli_model')
78 |     model = AutoModelForSequenceClassification.from_pretrained('../coherence_nli_model').to(device)
79 | coherence_scores = calculate_nli_scores(history, responses, model, tokenizer, device)
80 | return coherence_scores
81 |
82 | if __name__ == '__main__':
83 | with open(input_file, 'r', encoding='utf-8') as f:
84 | input_data = json.load(f)
85 | responses = [d[2] for d in input_data]
86 | personas = [d[0] for d in input_data]
87 | history = [d[1][-1] for d in input_data]
88 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
89 | ppls = calculate_ppls(responses, device)
90 | entailment_scores = calculate_entailment_scores(personas, responses, device)
91 | coherence_scores = calculate_coherence_scores(history, responses, device)
92 | with open(output_file, 'w') as f:
93 | json.dump({'ppls': ppls, 'entailment_scores': entailment_scores, 'coherence_scores': coherence_scores}, f)
--------------------------------------------------------------------------------
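For reference, the indexing above (`d[0]`, `d[1][-1]`, `d[2]`) implies each entry of the input JSON is a `[persona, history_list, response]` triple; an illustrative (made-up) example:

```json
[
  ["i love to garden .", ["what do you do on weekends ?"], "i spend my weekends planting tomatoes ."]
]
```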
/data_manipulation/data_diversification/filter_augmented_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | from sys import argv
3 | from metrics import cal_novelty
4 |
5 | PPL_WEIGHT = 0.2
6 | ENTAILMENT_WEIGHT = 0.6
7 | COHERENCE_WEIGHT = 0.2
8 | PPL_NORMALIZER = 50
9 | TH = 0.15
10 |
11 | def cal_new_token(data):
12 | token_set = set()
13 | for d in data:
14 | for t in d[0].split():
15 | token_set.add(t)
16 | for s in d[1]:
17 | for t in s.split():
18 | token_set.add(t)
19 | for t in d[2].split():
20 | token_set.add(t)
21 | return len(token_set)
22 |
23 | def get_novelty(original_data, new_data):
24 | original_sentence_persona, original_sentence_utterances = set(), set()
25 | new_persona, new_utterances = [], []
26 | for d in original_data:
27 | original_sentence_persona.add(d[0])
28 | original_sentence_utterances.add(d[2])
29 | original_sentence_utterances.add(d[1][-1])
30 | for d in new_data:
31 | new_persona.append(d[0])
32 | new_utterances.append(d[2])
33 | new_utterances.append(d[1][-1])
34 | n1, n2, n3, n4 = cal_novelty(list(original_sentence_persona), new_persona)
35 | un1, un2, un3, un4 = cal_novelty(list(original_sentence_utterances), new_utterances)
36 | print('111')
37 |
38 | PUNCS = [',', '.', ';', '!', '?', ':']
39 | TYPE_MAP = {(0, 0, 0): 'M_P_G', (0, 0, 1): 'M_P_R', (0, 1, 0): 'M_O_G', (0, 1, 1): 'M_O_R',
40 | (1, 0, 0): 'P_P_G', (1, 0, 1): 'P_P_R', (1, 1, 0): 'P_O_G', (1, 1, 1): 'P_O_R'}
41 |
42 |
43 | input_prefix = argv[1]
44 | with open(input_prefix + '_with_replace_response.json', 'r') as f:
45 | replace_data = json.load(f)
46 | with open(input_prefix + '_with_replace_response_type.json', 'r') as f:
47 | replace_types = json.load(f)
48 | replace_raw_idx = [d for d in replace_types['raw_data_idx']]
49 | replace_types = [d + [1] for d in replace_types['data_type']]
50 | with open(input_prefix + '_with_generated_response.json', 'r') as f:
51 | generated_data = json.load(f)
52 | with open(input_prefix + '_without_response_type.json', 'r') as f:
53 | generated_types = json.load(f)
54 | generated_raw_idx = [d for d in generated_types['raw_data_idx']]
55 | generated_types = [d + [0] for d in generated_types['data_type']]
56 | with open(input_prefix + '_with_replace_response_scores.json', 'r') as f:
57 | replace_scores = json.load(f)
58 | with open(input_prefix + '_with_generated_response_scores.json', 'r') as f:
59 | generated_scores = json.load(f)
60 |
61 | # generated_scores = {'ppls':[], 'entailment_scores': [], 'coherence_scores': []}
62 | all_data = replace_data[:len(replace_data)] + generated_data[:len(generated_data)]
63 | all_types = replace_types[:len(replace_data)] + generated_types[:len(generated_data)]
64 | all_raw_idx = replace_raw_idx[:len(replace_data)] + generated_raw_idx[:len(generated_data)]
65 | all_scores= {}
66 | for k, v in replace_scores.items():
67 | all_scores[k] = replace_scores[k] + generated_scores[k]
68 |
69 | weighted_scores = []
70 | for i in range(len(all_data)):
71 | weighted_scores.append(-all_scores['ppls'][i] / PPL_NORMALIZER * PPL_WEIGHT +
72 |                            all_scores['entailment_scores'][i] * ENTAILMENT_WEIGHT +
73 | all_scores['coherence_scores'][i] * COHERENCE_WEIGHT)
74 |
75 | res = []
76 | res_types = []
77 | raw_idx_map = {}
78 | selected_raw_idx = []
79 | for i in range(len(weighted_scores)):
80 | if not raw_idx_map.__contains__(all_raw_idx[i]):
81 | raw_idx_map[all_raw_idx[i]] = [[], []]
82 | if weighted_scores[i] > TH:
83 | res.append(all_data[i])
84 | res_types.append(all_types[i])
85 | if all_types[i] == [1,0,0]:
86 | raw_idx_map[all_raw_idx[i]][0].append(all_data[i])
87 | selected_raw_idx.append(all_raw_idx[i])
88 | else:
89 | raw_idx_map[all_raw_idx[i]][1].append(all_data[i])
90 | type_cnt = {}
91 | for res_type in res_types:
92 | if not type_cnt.__contains__(TYPE_MAP[tuple(res_type)]):
93 | type_cnt[TYPE_MAP[tuple(res_type)]] = 0
94 | type_cnt[TYPE_MAP[tuple(res_type)]] += 1
95 | print('The new augmented data number is ' + str(len(res)))
96 | with open('base_data/th0.99_model_entail_train_self.json', 'r') as f:
97 | original_data = json.load(f)
98 | # get_novelty(original_data, res)
99 | with open('th0.99_self_augmented_no_persona_filter.json', 'w') as f:
100 | json.dump(original_data + res, f)
101 | print('111')
102 |
--------------------------------------------------------------------------------
/data_manipulation/data_distillation/get_distilled_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | from sys import argv
4 |
5 | import numpy as np
6 |
7 | TH = 0.99
8 |
9 | def get_raw_data(original_lines):
10 | raw_data = []
11 | all_personas = set()
12 | for i, line in enumerate(original_lines):
13 | line = line.strip()
14 |
15 | if len(line) == 0:
16 | continue
17 |
18 | space_idx = line.find(' ')
19 | if space_idx == -1:
20 | dialog_idx = int(line)
21 | else:
22 | dialog_idx = int(line[:space_idx])
23 |
24 | if int(dialog_idx) == 1:
25 | raw_data.append({'persona': [], 'revised_persona': [], 'dialog': []})
26 |
27 | dialog_line = line[space_idx + 1:].split('\t')
28 | dialog_line = [l.strip() for l in dialog_line]
29 |
30 | if dialog_line[0].startswith('your persona:'):
31 | persona_info = dialog_line[0].replace('your persona: ', '')
32 | all_personas.add(persona_info[:-1] + ' .')
33 | raw_data[-1]['persona'].append(persona_info[:-1] + ' .')
34 | elif len(dialog_line) > 1:
35 | raw_data[-1]['dialog'].append(dialog_line[0])
36 | raw_data[-1]['dialog'].append(dialog_line[1])
37 | return raw_data, list(all_personas)
38 |
39 | # The original train data file
40 | input_file = argv[1]
41 | # The logits given by NLI model obtained before
42 | logits_file = argv[2]
43 | # The output json file that contains all distilled samples that were determined as entailed by the NLI model
44 | output_file = argv[3]
45 |
46 | if '.json' in input_file:
47 | with open(input_file, 'r') as f:
48 | data = json.load(f)
49 | cnt1, cnt2, cnt3 = 0, 0, 0
50 | with open(logits_file, 'rb') as f:
51 | logits = pickle.load(f)
52 | entail = np.argmax(logits, axis=-1)
53 | entail_result = []
54 | for i, d in enumerate(data):
55 | cur_logit = logits[i]
56 | if entail[i] == 0:
57 | cnt1 += 1
58 | elif entail[i] == 1:
59 | cnt2 += 1
60 | elif entail[i] == 2:
61 | softmax = np.exp(logits[i]) / np.sum(np.exp(logits[i]))
62 | if softmax[2] < TH:
63 | cnt2 += 1
64 | else:
65 | cnt3 += 1
66 | entail_result.append(data[i])
67 | print(cnt1)
68 | print(cnt2)
69 | print(cnt3)
70 | with open(output_file, 'w') as f:
71 | json.dump(entail_result, f)
72 | else:
73 | with open(input_file, 'r', encoding='utf-8') as f:
74 | lines = f.readlines()
75 | raw_data, all_personas = get_raw_data(lines)
76 |
77 | with open(logits_file, 'rb') as f:
78 | d = pickle.load(f)
79 | cnt1, cnt2, cnt3 = 0, 0, 0
80 | entail_result = []
81 | for dialog in d:
82 | cur_entail_result = []
83 | for p in dialog:
84 | res = np.argmax(p, axis=-1)
85 | res = list(res)
86 | for i, r in enumerate(res):
87 | if r == 2:
88 | softmax = np.exp(p[i]) / np.sum(np.exp(p[i]))
89 | if softmax[2] < TH:
90 | res[i] = 1
91 | cur_entail_result.append(list(res))
92 | entail_result.append(cur_entail_result)
93 | contradict_list = []
94 | entail_list = []
95 | neutral_list = []
96 | entail_sample_idx = []
97 | sample_idx = 0
98 | for i, dialog in enumerate(entail_result):
99 | for j, p in enumerate(dialog):
100 | for k, u in enumerate(p):
101 | if u == 0:
102 | cnt1 += 1
103 | contradict_list.append((i, j, k))
104 | if u == 1:
105 | cnt2 += 1
106 | if u == 2:
107 | cnt3 += 1
108 | entail_list.append((i, j, k))
109 | entail_sample_idx.append(sample_idx + k)
110 | sample_idx += len(p)
111 | print(cnt1)
112 | print(cnt2)
113 | print(cnt3)
114 |
115 | entail_data = []
116 | for idx in entail_list:
117 | persona = raw_data[idx[0]]['persona'][idx[1]]
118 | response_idx = idx[2] * 2 + 1
119 | response = raw_data[idx[0]]['dialog'][response_idx]
120 | history = raw_data[idx[0]]['dialog'][max(0, response_idx - 3): response_idx]
121 | if len(history) == 0:
122 | history = ['__SILENCE__']
123 | entail_data.append([persona, history, response])
124 | with open(output_file, 'w', encoding='utf-8') as f:
125 | json.dump(entail_data, f)
126 | with open(output_file + 'raw_idx', 'w', encoding='utf-8') as f:
127 | json.dump(entail_sample_idx, f)
128 |
--------------------------------------------------------------------------------
/data_manipulation/data_distillation/calculate_entailment.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | from sys import argv
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data import TensorDataset
8 | from tqdm import tqdm
9 | from transformers import AutoModelForSequenceClassification
10 | from transformers import AutoTokenizer
11 | from transformers import glue_convert_examples_to_features as convert_examples_to_features
12 | from transformers.data.processors.utils import InputExample
13 |
14 | NLI_MODEL_PATH = './persona_nli'
15 |
16 | # The original train file
17 | input_file = argv[1]
18 | # The output file that saves the NLI logits given the train samples
19 | output_file = argv[2]
20 |
21 | def get_dataloader(input_examples, tokenizer, device):
22 | features = convert_examples_to_features(
23 | input_examples,
24 | tokenizer,
25 | label_list=['0', '1'],
26 | max_length=128,
27 | output_mode='classification',
28 | )
29 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)
30 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long).to(device)
31 | dataset = TensorDataset(all_input_ids, all_attention_mask)
32 | dataloader = DataLoader(dataset, batch_size=6)
33 | return dataloader
34 |
35 |
36 | def read_txt_data(input_file):
37 | with open(input_file, 'r') as f:
38 | lines = f.readlines()
39 | cur_persona1, cur_persona2 = [], []
40 | cur_dialogs1, cur_dialogs2 = [], []
41 | personas, dialogs = [], []
42 | sentence_pairs = []
43 | start = True
44 | for line in lines:
45 | if 'your persona:' in line or 'partner\'s persona' in line:
46 | if start and len(cur_persona1) > 0:
47 | personas.append(cur_persona1)
48 | personas.append(cur_persona2)
49 | dialogs.append(cur_dialogs1)
50 | dialogs.append(cur_dialogs2)
51 | cur_persona1, cur_persona2 = [], []
52 | cur_dialogs1, cur_dialogs2 = [], []
53 | start = False
54 | if 'your persona:' in line:
55 | persona_index = line.find('your persona:')
56 | persona = line[persona_index + 14: -1]
57 | cur_persona1.append(persona)
58 | elif 'partner\'s persona' in line:
59 | persona_index = line.find('partner\'s persona:')
60 | persona = line[persona_index + 19: -1]
61 | cur_persona2.append(persona)
62 | else:
63 | start = True
64 | space_index = line.find(' ')
65 | sents = line[space_index + 1:].split('\t')
66 | cur_dialogs1.append(sents[1])
67 | cur_dialogs2.append(sents[0])
68 |     if len(cur_persona1) > 0:  # flush the final dialog block (the loop only flushes when a new persona starts)
69 |         personas.extend([cur_persona1, cur_persona2]); dialogs.extend([cur_dialogs1, cur_dialogs2])
70 |     return personas, dialogs
71 | def read_json_data(input_file):
72 | with open(input_file, 'r') as f:
73 | data = json.load(f)
74 | examples = []
75 | cnt = 0
76 | for d in data:
77 | examples.append(InputExample(str(cnt), d[0], d[2], '0'))
78 | cnt += 1
79 | return examples
80 |
81 | tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_PATH)
82 | model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL_PATH)
83 |
84 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
85 | model.to(device)
86 | model.eval()
87 | pred_results = []
88 | if '.json' in input_file:
89 | all_logits = None
90 | input_examples = read_json_data(input_file)
91 | train_dataloader = get_dataloader(input_examples, tokenizer, device)
92 | with torch.no_grad():
93 | for batch in tqdm(train_dataloader):
94 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
95 | outputs = model(**inputs)
96 | if all_logits is None:
97 | all_logits = outputs[0].cpu().detach()
98 | else:
99 | all_logits = torch.cat((all_logits, outputs[0].cpu().detach()), dim=0)
100 | all_logits = all_logits.numpy()
101 | with open(output_file, 'wb') as f:
102 | pickle.dump(all_logits, f)
103 | else:
104 | personas, dialogs = read_txt_data(input_file)
105 | entailed_results = []
106 | with torch.no_grad():
107 | for i in tqdm(range(len(personas))):
108 | cur_persona = personas[i]
109 | cur_dialogs = dialogs[i]
110 | cnt = 0
111 | cur_pred_results = []
112 | for persona in cur_persona:
113 | input_examples = []
114 | for dialog in cur_dialogs:
115 | input_examples.append(InputExample(str(cnt), persona, dialog, '0'))
116 | cnt += 1
117 | train_dataloader = get_dataloader(input_examples, tokenizer, device)
118 | all_logits = None
119 | for batch in train_dataloader:
120 | inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
121 | outputs = model(**inputs)
122 | if all_logits is None:
123 | all_logits = outputs[0].detach()
124 | else:
125 | all_logits = torch.cat((all_logits, outputs[0].detach()), dim=0)
126 | results = torch.argmax(all_logits, dim=1)
127 | for j, r in enumerate(results):
128 | if r == 2:
129 | entailed_results.append((persona, cur_dialogs[j]))
130 | cur_pred_results.append(all_logits.cpu())
131 | pred_results.append(cur_pred_results)
132 | with open('entailed_sentences.json', 'w') as f:
133 | json.dump(entailed_results, f)
134 | torch.save(pred_results, 'entailment_scores.bin')
135 |
--------------------------------------------------------------------------------
/data_manipulation/prepare_model/train_nli_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
3 | import json
4 | from tqdm import tqdm
5 | import numpy as np
6 |
7 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
8 | from transformers.data.processors.utils import InputExample
9 | from transformers import glue_convert_examples_to_features as convert_examples_to_features
10 |
11 | EPOCHS = 2
12 | LR = 1e-5
13 | WEIGHT_DECAY = 0.0
14 | WARMUP_RATIO = 0.05
15 | EVAL_INTERVAL = 1000
16 | BATCH_SIZE = 32
17 | MAX_GRAD_NORM = 1.0
18 |
19 | INPUT_MODEL_PATH = './roberta_mnli'
20 | OUTPUT_MODEL_FILE = 'best_model.bin'
21 |
22 | def get_input_examples(data):
23 | input_examples = []
24 | label_dict = {'negative': '0', 'neutral': '1', 'positive': '2'}
25 | for d in data:
26 | input_examples.append(InputExample(d['id'], d['sentence1'], d['sentence2'], label_dict[d['label']]))
27 | return input_examples
28 |
29 | def eval_model(model, dev_dataloader, prev_best, step):
30 | dev_tqdm_data = tqdm(dev_dataloader, desc='Evaluation (step #{})'.format(step))
31 | eval_loss = 0
32 | model.eval()
33 | preds, out_label_ids = None, None
34 | eval_step = 0
35 | with torch.no_grad():
36 | for batch in dev_tqdm_data:
37 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]}
38 | outputs = model(**inputs)
39 | tmp_eval_loss, logits = outputs[:2]
40 | eval_step += 1
41 | eval_loss += tmp_eval_loss.mean().item()
42 | dev_tqdm_data.set_postfix({'loss': eval_loss / eval_step})
43 | if preds is None:
44 | preds = logits.detach().cpu().numpy()
45 | out_label_ids = inputs["labels"].detach().cpu().numpy()
46 | else:
47 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
48 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
49 | preds = np.argmax(preds, axis=1)
50 | accuracy = (preds == out_label_ids).astype(np.float32).mean().item()
51 | if accuracy > prev_best:
52 | print('Current model BEATS the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy))
53 | torch.save(model.state_dict(), OUTPUT_MODEL_FILE)
54 | prev_best = accuracy
55 | else:
56 | print('Current model CANNOT BEAT the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy))
57 | return prev_best
58 |
59 | with open('dialogue_nli_dataset/dialogue_nli_train.jsonl', 'r') as f:
60 | train_data = json.load(f)
61 | with open('dialogue_nli_dataset/dialogue_nli_dev.jsonl', 'r') as f:
62 | dev_data = json.load(f)
63 | train_examples = get_input_examples(train_data)
64 | dev_examples = get_input_examples(dev_data)
65 |
66 | tokenizer = AutoTokenizer.from_pretrained(INPUT_MODEL_PATH)
67 | model = AutoModelForSequenceClassification.from_pretrained(INPUT_MODEL_PATH)
68 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
69 | model.to(device)
70 | if torch.cuda.device_count() > 1:
71 | device = torch.device('cuda:0')
72 | model = model.to(device)
73 | model = torch.nn.parallel.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
74 |
75 | train_features = convert_examples_to_features(
76 | train_examples,
77 | tokenizer,
78 | label_list=['0', '1', '2'],
79 | max_length=128,
80 | output_mode='classification',
81 | )
82 | dev_features = convert_examples_to_features(
83 | dev_examples,
84 | tokenizer,
85 | label_list=['0', '1', '2'],
86 | max_length=128,
87 | output_mode='classification',
88 | )
89 | train_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).to(device)
90 | train_attention_mask = torch.tensor([f.attention_mask for f in train_features], dtype=torch.long).to(device)
91 | train_labels = torch.tensor([f.label for f in train_features], dtype=torch.long).to(device)
92 | train_dataset = TensorDataset(train_input_ids, train_attention_mask, train_labels)
93 | train_sampler = RandomSampler(train_dataset)
94 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
95 | dev_input_ids = torch.tensor([f.input_ids for f in dev_features], dtype=torch.long).to(device)
96 | dev_attention_mask = torch.tensor([f.attention_mask for f in dev_features], dtype=torch.long).to(device)
97 | dev_labels = torch.tensor([f.label for f in dev_features], dtype=torch.long).to(device)
98 | dev_dataset = TensorDataset(dev_input_ids, dev_attention_mask, dev_labels)
99 | dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
100 |
101 | #eval_model(model, dev_dataloader, 0, 0)
102 |
103 | t_total = len(train_dataloader) * EPOCHS
104 | no_decay = ["bias", "LayerNorm.weight"]
105 | optimizer_grouped_parameters = [
106 | {
107 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
108 | "weight_decay": WEIGHT_DECAY,
109 | },
110 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
111 | ]
112 | optimizer = AdamW(optimizer_grouped_parameters, lr=LR, eps=1e-8)
113 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(WARMUP_RATIO * t_total), num_training_steps=t_total)
114 |
115 | prev_best = 0
116 | for epoch in range(EPOCHS):
117 | total_loss = 0.0
118 | tqdm_data = tqdm(train_dataloader, desc='Train (epoch #{})'.format(epoch + 1))
119 | step = 0
120 | for batch in tqdm_data:
121 | model.train()
122 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]}
123 | outputs = model(**inputs)
124 | loss = outputs[0]
125 | loss = loss.mean()
126 | loss.backward()
127 | total_loss += loss.item()
128 | step += 1
129 | tqdm_data.set_postfix({'loss': total_loss / step})
130 | torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
131 | optimizer.step()
132 | scheduler.step()
133 | optimizer.zero_grad()
134 | if step % EVAL_INTERVAL == 0:
135 | prev_best = eval_model(model, dev_dataloader, prev_best, step)
136 | prev_best = eval_model(model, dev_dataloader, prev_best, step)
137 |
138 |
--------------------------------------------------------------------------------
/metrics/multi-bleu.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 |
6 | # $Id$
7 | use warnings;
8 | use strict;
9 |
10 | my $lowercase = 0;
11 | if ($ARGV[0] eq "-lc") {
12 | $lowercase = 1;
13 | shift;
14 | }
15 |
16 | my $stem = $ARGV[0];
17 | if (!defined $stem) {
18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20 | exit(1);
21 | }
22 |
23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24 |
25 | my @REF;
26 | my $ref=0;
27 | while(-e "$stem$ref") {
28 | &add_to_ref("$stem$ref",\@REF);
29 | $ref++;
30 | }
31 | &add_to_ref($stem,\@REF) if -e $stem;
32 | die("ERROR: could not find reference file $stem") unless scalar @REF;
33 |
34 | # add additional references explicitly specified on the command line
35 | shift;
36 | foreach my $stem (@ARGV) {
37 | &add_to_ref($stem,\@REF) if -e $stem;
38 | }
39 |
40 |
41 |
42 | sub add_to_ref {
43 | my ($file,$REF) = @_;
44 | my $s=0;
45 | if ($file =~ /.gz$/) {
46 | open(REF,"gzip -dc $file|") or die "Can't read $file";
47 | } else {
48 | open(REF,$file) or die "Can't read $file";
49 | }
50 | while(<REF>) {
51 | chop;
52 | push @{$$REF[$s++]}, $_;
53 | }
54 | close(REF);
55 | }
56 |
57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
58 | my $s=0;
59 | while(<STDIN>) {
60 | chop;
61 | $_ = lc if $lowercase;
62 | my @WORD = split;
63 | my %REF_NGRAM = ();
64 | my $length_translation_this_sentence = scalar(@WORD);
65 | my ($closest_diff,$closest_length) = (9999,9999);
66 | foreach my $reference (@{$REF[$s]}) {
67 | # print "$s $_ <=> $reference\n";
68 | $reference = lc($reference) if $lowercase;
69 | my @WORD = split(' ',$reference);
70 | my $length = scalar(@WORD);
71 | my $diff = abs($length_translation_this_sentence-$length);
72 | if ($diff < $closest_diff) {
73 | $closest_diff = $diff;
74 | $closest_length = $length;
75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
76 | } elsif ($diff == $closest_diff) {
77 | $closest_length = $length if $length < $closest_length;
78 | # from two references with the same closeness to me
79 | # take the *shorter* into account, not the "first" one.
80 | }
81 | print "$length \n";
82 | for(my $n=1;$n<=4;$n++) {
83 | my %REF_NGRAM_N = ();
84 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
85 | my $ngram = "$n";
86 | for(my $w=0;$w<$n;$w++) {
87 | $ngram .= " ".$WORD[$start+$w];
88 | }
89 | $REF_NGRAM_N{$ngram}++;
90 | }
91 | foreach my $ngram (keys %REF_NGRAM_N) {
92 | if (!defined($REF_NGRAM{$ngram}) ||
93 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
94 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
95 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}]
\n";
96 | }
97 | }
98 | }
99 | }
100 | $length_translation += $length_translation_this_sentence;
101 | $length_reference += $closest_length;
102 | for(my $n=1;$n<=4;$n++) {
103 | my %T_NGRAM = ();
104 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
105 | my $ngram = "$n";
106 | for(my $w=0;$w<$n;$w++) {
107 | $ngram .= " ".$WORD[$start+$w];
108 | }
109 | $T_NGRAM{$ngram}++;
110 | }
111 | foreach my $ngram (keys %T_NGRAM) {
112 | $ngram =~ /^(\d+) /;
113 | # my $n = $1;
114 | # my $corr = 0;
115 | # print "$i e $ngram $T_NGRAM{$ngram}
\n";
116 | $TOTAL[$n] += $T_NGRAM{$ngram};
117 | if (defined($REF_NGRAM{$ngram})) {
118 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
119 | $CORRECT[$n] += $T_NGRAM{$ngram};
120 | # $corr = $T_NGRAM{$ngram};
121 | # print "$i e correct1 $T_NGRAM{$ngram}
\n";
122 | }
123 | else {
124 | $CORRECT[$n] += $REF_NGRAM{$ngram};
125 | # $corr = $REF_NGRAM{$ngram};
126 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n";
127 | }
128 | }
129 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
130 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
131 | }
132 | }
133 | $s++;
134 | }
135 | my $brevity_penalty = 1;
136 | my $bleu = 0;
137 |
138 | my @bleu=();
139 |
140 | for(my $n=1;$n<=4;$n++) {
141 | if (defined ($TOTAL[$n])){
142 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
143 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
144 | }else{
145 | $bleu[$n]=0;
146 | }
147 | }
148 |
149 | if ($length_reference==0){
150 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
151 | exit(1);
152 | }
153 |
154 | if ($length_translation<$length_reference) {
155 | $brevity_penalty = exp(1-$length_reference/$length_translation);
156 | }
157 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
158 | my_log( $bleu[2] ) +
159 | my_log( $bleu[3] ) +
160 | my_log( $bleu[4] ) ) / 4) ;
161 | printf "BLEU = %.5f, %.5f/%.5f/%.5f/%.5f (BP=%.5f, ratio=%.5f, hyp_len=%d, ref_len=%d)\n",
162 | 100*$bleu,
163 | 100*$bleu[1],
164 | 100*$bleu[2],
165 | 100*$bleu[3],
166 | 100*$bleu[4],
167 | $brevity_penalty,
168 | $length_translation / $length_reference,
169 | $length_translation,
170 | $length_reference;
171 |
172 |
173 | print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n";
174 |
175 | sub my_log {
176 | return -9999999999 unless $_[0];
177 | return log($_[0]);
178 | }
--------------------------------------------------------------------------------
/BT/recover.py:
--------------------------------------------------------------------------------
1 | import json
2 | from sys import argv
3 |
4 | from nltk.translate.bleu_score import sentence_bleu
5 | from tqdm import tqdm
6 |
7 | bt_file = argv[1]
8 | original_file = argv[2]
9 | index_file = argv[3]
10 | output_file = argv[4]
11 |
12 | SYMBOLS = ['.', ',', '!', '?', ';', '"']
13 | SPECIAL1 = '’'
14 |
15 | def find_the_most_different_replace(original_sentence, candidates):
16 | min_bleu = 100
17 | res = candidates[0]
18 | for c in candidates:
19 | bleu = sentence_bleu(original_sentence.lower(), c.lower())
20 | if bleu < min_bleu and abs(len(original_sentence) - len(c)) < len(original_sentence) * 0.4:
21 | res = c.lower()
22 | min_bleu = bleu
23 | return res
24 |
25 | def clean_data_core(line):
26 | line_list = []
27 | for c in line:
28 | if c in SYMBOLS or c == SPECIAL1:
29 | if len(line_list) > 0 and line_list[-1] != ' ':
30 | line_list.append(' ')
31 | if c == SPECIAL1:
32 | c = '\''
33 | line_list.append(c)
34 | new_sentence = ''.join(line_list)
35 | end_idx = len(new_sentence)
36 | for i in range(len(new_sentence)):
37 | cur_idx = len(new_sentence) - 1 - i
38 | if new_sentence[cur_idx] not in SYMBOLS and new_sentence[cur_idx] != ' ':
39 | break
40 | else:
41 | if new_sentence[cur_idx] in SYMBOLS:
42 | end_idx = cur_idx + 1
43 | new_line = new_sentence[:end_idx]
44 | return new_line
45 |
46 | def recover_persona_end(cleaned_persona):
47 | if len(cleaned_persona) >= 2 and cleaned_persona[-1] in SYMBOLS and cleaned_persona[-2] == ' ':
48 | j = len(cleaned_persona) - 2
49 | while cleaned_persona[j] == ' ':
50 | j -= 1
51 | cleaned_persona = cleaned_persona[:j + 1] + '.'
52 | elif len(cleaned_persona) > 0 and cleaned_persona[-1] not in SYMBOLS and cleaned_persona[-1] != ' ':
53 | cleaned_persona = cleaned_persona + '.'
54 | return cleaned_persona
55 |
56 | def clean_data(data, is_json=False):
57 | if is_json:
58 | for i, sample in enumerate(data):
59 | persona, history, response = sample[0], sample[1], sample[2]
60 | cleaned_persona = clean_data_core(persona)
61 | cleaned_persona = recover_persona_end(cleaned_persona)
62 | for j, h in enumerate(history):
63 | if h != '__SILENCE__':
64 | history[j] = clean_data_core(h)
65 | cleaned_response = clean_data_core(response)
66 | data[i] = [cleaned_persona, history, cleaned_response]
67 | else:
68 | for i, line in enumerate(data):
69 | line = line.lower().strip()
70 | if 'your persona: ' in line:
71 | cleaned_persona = clean_data_core(line)
72 | if cleaned_persona[-1] in SYMBOLS and cleaned_persona[-2] == ' ':
73 | j = len(cleaned_persona) - 2
74 | while cleaned_persona[j] == ' ':
75 | j -= 1
76 | cleaned_persona = cleaned_persona[:j + 1] + '.'
77 | elif cleaned_persona[-1] not in SYMBOLS and cleaned_persona[-1] != ' ':
78 | cleaned_persona = cleaned_persona + '.'
79 | data[i] = cleaned_persona + '\n'
80 | else:
81 | items = line.split('\t')
82 | for j, s in enumerate(items[:2]):
83 | items[j] = clean_data_core(s)
84 | data[i] = '\t'.join(items) + '\n'
85 | return data
86 |
87 | beam_size = 25
88 | bt_sentences = []
89 | with open(bt_file, 'r', encoding='utf-8') as f:
90 | lines = f.readlines()
91 | i = 0
92 | while i < len(lines):
93 | cur_sentences = [s.strip() for s in lines[i: i + beam_size]]
94 | i += beam_size
95 | bt_sentences.append(cur_sentences)
96 | with open(index_file, 'r') as f:
97 | indices = json.load(f)
98 | if '.txt' in original_file:
99 | with open(original_file, 'r') as f:
100 | original_lines = f.readlines()
101 | prev_line_idx = -1
102 | for i, line_idx in enumerate(tqdm(indices)):
103 | cur_bt_sentences = bt_sentences[i]
104 | line = original_lines[line_idx].strip()
105 | if 'your persona: ' in line:
106 | start_index = line.find('your persona: ')
107 | original_sentence = line[start_index + 14:]
108 | replace_start, replace_end = start_index + 14, len(line)
109 | else:
110 | space_index = line.find(' ')
111 | items = line[space_index + 1:].split('\t')
112 | if line_idx != prev_line_idx:
113 | original_sentence = items[0]
114 | replace_start, replace_end = space_index + 1, space_index + len(original_sentence)
115 | else:
116 | original_sentence = items[1]
117 | t_index = line.find('\t')
118 | replace_start, replace_end = t_index + 1, t_index + len(original_sentence) + 1
119 | replace_sentence = find_the_most_different_replace(original_sentence, cur_bt_sentences)
120 | original_lines[line_idx] = line[:replace_start] + replace_sentence + line[replace_end:] + '\n'
121 | prev_line_idx = line_idx
122 | cleaned_lines = clean_data(original_lines)
123 | with open(output_file, 'w', encoding='utf-8') as f:
124 | f.writelines(cleaned_lines)
125 | if '.json' in original_file:
126 | with open(original_file, 'r', encoding='utf-8') as f:
127 | original_data = json.load(f)
128 | for i, data_index in enumerate(tqdm(indices)):
129 | cur_bt_sentences = bt_sentences[i]
130 | sample_index = data_index[0]
131 | original_sample = original_data[sample_index]
132 | if data_index[1] < 3:
133 | original_sentence = original_sample[data_index[1]]
134 | replace_sentence = find_the_most_different_replace(original_sentence, cur_bt_sentences)
135 | original_sample[data_index[1]] = replace_sentence
136 | else:
137 | original_sentence = original_sample[1][data_index[1] - 3]
138 | if original_sentence != '__SILENCE__':
139 | replace_sentence = find_the_most_different_replace(original_sentence, cur_bt_sentences)
140 | original_sample[1][data_index[1] - 3] = replace_sentence
141 | original_data[sample_index] = original_sample
142 | cleaned_data = clean_data(original_data, is_json=True)
143 | with open(output_file, 'w', encoding='utf-8') as f:
144 | json.dump(cleaned_data, f)
145 |
--------------------------------------------------------------------------------
/data_manipulation/prepare_model/train_coherence_nli.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
3 | import json
4 | from tqdm import tqdm
5 | import numpy as np
6 |
7 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
8 | from transformers.data.processors.utils import InputExample
9 | from transformers import glue_convert_examples_to_features as convert_examples_to_features
10 |
11 | EPOCHS = 3
12 | LR = 1e-5
13 | WEIGHT_DECAY = 0.0
14 | WARMUP_RATIO = 0.05
15 | EVAL_INTERVAL = 2000
16 | BATCH_SIZE = 96
17 | MAX_GRAD_NORM = 1.0
18 |
19 | INPUT_MODEL_PATH = './roberta_mnli'
20 | OUTPUT_MODEL_FILE = 'best_model.bin'
21 |
22 | def get_input_examples(data):
23 | input_examples = []
24 | label_dict = {'contradiction': '0', 'neutral': '1', 'entailment': '2'}
25 | for d in data:
26 | input_examples.append(InputExample(d['pairID'], d['sentence1'], d['sentence2'], label_dict[d['gold_label']]))
27 | return input_examples
28 |
29 | def eval_model(model, dev_dataloader, prev_best, step):
30 | dev_tqdm_data = tqdm(dev_dataloader, desc='Evaluation (step #{})'.format(step))
31 | eval_loss = 0
32 | model.eval()
33 | preds, out_label_ids = None, None
34 | eval_step = 0
35 | with torch.no_grad():
36 | for batch in dev_tqdm_data:
37 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]}
38 | outputs = model(**inputs)
39 | tmp_eval_loss, logits = outputs[:2]
40 | eval_step += 1
41 | eval_loss += tmp_eval_loss.mean().item()
42 | dev_tqdm_data.set_postfix({'loss': eval_loss / eval_step})
43 | if preds is None:
44 | preds = logits.detach().cpu().numpy()
45 | out_label_ids = inputs["labels"].detach().cpu().numpy()
46 | else:
47 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
48 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
49 | preds = np.argmax(preds, axis=1)
50 | accuracy = (preds == out_label_ids).astype(np.float32).mean().item()
51 | if accuracy > prev_best:
52 | print('Current model BEATS the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy))
53 | torch.save(model.state_dict(), OUTPUT_MODEL_FILE)
54 | prev_best = accuracy
55 | else:
56 | print('Current model CANNOT BEAT the previous best model, previous best is {:.3f}, current is {:.3f}'.format(prev_best, accuracy))
57 | return prev_best
58 |
59 | with open('convai_nli_valid.jsonl', 'r', encoding='utf-8') as f:  # note: the same valid split is loaded for both train_data and dev_data below; point this at the training split if one is available
60 | lines = f.readlines()
61 | train_data = []
62 | for line in lines:
63 | train_data.append(json.loads(line.strip()))
64 | with open('convai_nli_valid.jsonl', 'r', encoding='utf-8') as f:
65 | lines = f.readlines()
66 | dev_data = []
67 | for line in lines:
68 | dev_data.append(json.loads(line.strip()))
69 | train_examples = get_input_examples(train_data)
70 | dev_examples = get_input_examples(dev_data)
71 |
72 | tokenizer = AutoTokenizer.from_pretrained(INPUT_MODEL_PATH)
73 | model = AutoModelForSequenceClassification.from_pretrained(INPUT_MODEL_PATH)
74 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
75 | model.to(device)
76 | if torch.cuda.device_count() > 1:
77 | device = torch.device('cuda:0')
78 | model = model.to(device)
79 | model = torch.nn.parallel.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
80 |
81 | train_features = convert_examples_to_features(
82 | train_examples,
83 | tokenizer,
84 | label_list=['0', '1', '2'],
85 | max_length=128,
86 | output_mode='classification',
87 | )
88 | dev_features = convert_examples_to_features(
89 | dev_examples,
90 | tokenizer,
91 | label_list=['0', '1', '2'],
92 | max_length=128,
93 | output_mode='classification',
94 | )
95 | train_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).to(device)
96 | train_attention_mask = torch.tensor([f.attention_mask for f in train_features], dtype=torch.long).to(device)
97 | train_labels = torch.tensor([f.label for f in train_features], dtype=torch.long).to(device)
98 | train_dataset = TensorDataset(train_input_ids, train_attention_mask, train_labels)
99 | train_sampler = RandomSampler(train_dataset)
100 | train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
101 | dev_input_ids = torch.tensor([f.input_ids for f in dev_features], dtype=torch.long).to(device)
102 | dev_attention_mask = torch.tensor([f.attention_mask for f in dev_features], dtype=torch.long).to(device)
103 | dev_labels = torch.tensor([f.label for f in dev_features], dtype=torch.long).to(device)
104 | dev_dataset = TensorDataset(dev_input_ids, dev_attention_mask, dev_labels)
105 | dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
106 |
107 | t_total = len(train_dataloader) * EPOCHS
108 | no_decay = ["bias", "LayerNorm.weight"]
109 | optimizer_grouped_parameters = [
110 | {
111 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
112 | "weight_decay": WEIGHT_DECAY,
113 | },
114 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
115 | ]
116 | optimizer = AdamW(optimizer_grouped_parameters, lr=LR, eps=1e-8)
117 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_RATIO, num_training_steps=t_total)  # note: WARMUP_RATIO (0.05) is passed directly as num_warmup_steps (a step count), so warmup is effectively disabled; use int(WARMUP_RATIO * t_total) to treat it as a ratio
118 |
119 | prev_best = 0
120 | for epoch in range(EPOCHS):
121 | total_loss = 0.0
122 | tqdm_data = tqdm(train_dataloader, desc='Train (epoch #{})'.format(epoch + 1))
123 | step = 0
124 | prev_best = eval_model(model, dev_dataloader, prev_best, step)
125 | for batch in tqdm_data:
126 | model.train()
127 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], 'labels': batch[2]}
128 | outputs = model(**inputs)
129 | loss = outputs[0]
130 | loss = loss.mean()
131 | loss.backward()
132 | total_loss += loss.item()
133 | step += 1
134 | tqdm_data.set_postfix({'loss': total_loss / step})
135 | torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
136 | optimizer.step()
137 | scheduler.step()
138 | optimizer.zero_grad()
139 | if step % EVAL_INTERVAL == 0:
140 | prev_best = eval_model(model, dev_dataloader, prev_best, step)
141 | prev_best = eval_model(model, dev_dataloader, prev_best, step)
142 |
--------------------------------------------------------------------------------
/metrics/dstc_example.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao @ Microsoft Research, Oct 2018
2 | # evaluate DSTC-task2 submissions. https://github.com/DSTC-MSR-NLP/DSTC7-End-to-End-Conversation-Modeling
3 |
4 | from util import *
5 | from metrics import *
6 | from tokenizers import *
7 |
8 | def extract_cells(path_in, path_hash):
9 | keys = [line.strip('\n') for line in open(path_hash)]
10 | cells = dict()
11 | for line in open(path_in, encoding='utf-8'):
12 | c = line.strip('\n').split('\t')
13 | k = c[0]
14 | if k in keys:
15 | cells[k] = c[1:]
16 | return cells
17 |
18 |
19 | def extract_hyp_refs(raw_hyp, raw_ref, path_hash, fld_out, n_refs=6, clean=False, vshuman=-1):
20 | cells_hyp = extract_cells(raw_hyp, path_hash)
21 | cells_ref = extract_cells(raw_ref, path_hash)
22 | if not os.path.exists(fld_out):
23 | os.makedirs(fld_out)
24 |
25 | def _clean(s):
26 | if clean:
27 | return clean_str(s)
28 | else:
29 | return s
30 |
31 | keys = sorted(cells_hyp.keys())
32 | with open(fld_out + '/hash.txt', 'w', encoding='utf-8') as f:
33 | f.write(unicode('\n'.join(keys)))
34 |
35 | lines = [_clean(cells_hyp[k][-1]) for k in keys]
36 | path_hyp = fld_out + '/hyp.txt'
37 | with open(path_hyp, 'w', encoding='utf-8') as f:
38 | f.write(unicode('\n'.join(lines)))
39 |
40 | lines = []
41 | for _ in range(n_refs):
42 | lines.append([])
43 | for k in keys:
44 | refs = cells_ref[k]
45 | for i in range(n_refs):
46 | idx = i % len(refs)
47 | if idx == vshuman:
48 | idx = (idx + 1) % len(refs)
49 | lines[i].append(_clean(refs[idx].split('|')[1]))
50 |
51 | path_refs = []
52 | for i in range(n_refs):
53 | path_ref = fld_out + '/ref%i.txt'%i
54 | with open(path_ref, 'w', encoding='utf-8') as f:
55 | f.write(unicode('\n'.join(lines[i])))
56 | path_refs.append(path_ref)
57 |
58 | return path_hyp, path_refs
59 |
60 |
61 | def eval_one_system(submitted, keys, multi_ref, n_refs=6, n_lines=None, clean=False, vshuman=-1, PRINT=True):
62 |
63 | print('evaluating %s' % submitted)
64 |
65 | fld_out = submitted.replace('.txt','')
66 | if clean:
67 | fld_out += '_cleaned'
68 | path_hyp, path_refs = extract_hyp_refs(submitted, multi_ref, keys, fld_out, n_refs, clean=clean, vshuman=vshuman)
69 | nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(path_refs, path_hyp, fld_out, n_lines=n_lines)
70 |
71 | if n_lines is None:
72 | n_lines = len(open(path_hyp, encoding='utf-8').readlines())
73 |
74 | if PRINT:
75 | print('n_lines = '+str(n_lines))
76 | print('NIST = '+str(nist))
77 | print('BLEU = '+str(bleu))
78 | print('METEOR = '+str(meteor))
79 | print('entropy = '+str(entropy))
80 | print('diversity = ' + str(div))
81 | print('avg_len = '+str(avg_len))
82 |
83 | return [n_lines] + nist + bleu + [meteor] + entropy + div + [avg_len]
84 |
85 |
86 | def eval_all_systems(files, path_report, keys, multi_ref, n_refs=6, n_lines=None, clean=False, vshuman=False):
87 | # evaluate all systems (*.txt) in each folder `files`
88 |
89 | with open(path_report, 'w') as f:
90 | f.write('\t'.join(
91 | ['fname', 'n_lines'] + \
92 | ['nist%i'%i for i in range(1, 4+1)] + \
93 | ['bleu%i'%i for i in range(1, 4+1)] + \
94 | ['meteor'] + \
95 | ['entropy%i'%i for i in range(1, 4+1)] +\
96 | ['div1','div2','avg_len']
97 | ) + '\n')
98 |
99 | for fl in files:
100 | if fl.endswith('.txt'):
101 | submitted = fl
102 | results = eval_one_system(submitted, keys=keys, multi_ref=multi_ref, n_refs=n_refs, clean=clean, n_lines=n_lines, vshuman=vshuman, PRINT=False)
103 | with open(path_report, 'a') as f:
104 | f.write('\t'.join(map(str, [submitted] + results)) + '\n')
105 | else:
106 | for fname in os.listdir(fl):
107 | if fname.endswith('.txt'):
108 | submitted = fl + '/' + fname
109 | results = eval_one_system(submitted, keys=keys, multi_ref=multi_ref, n_refs=n_refs, clean=clean, n_lines=n_lines, vshuman=vshuman, PRINT=False)
110 | with open(path_report, 'a') as f:
111 | f.write('\t'.join(map(str, [submitted] + results)) + '\n')
112 |
113 | print('report saved to: '+path_report, file=sys.stderr)
114 |
115 |
116 | if __name__ == '__main__':
117 |
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument('submitted') # if 'all' or '*', eval all teams listed in dstc/teams.txt
120 | # elif endswith '.txt', eval this single file
121 | # else, eval all *.txt in folder `submitted_fld`
122 |
123 | parser.add_argument('--clean', '-c', action='store_true') # whether to clean ref and hyp before eval
124 | parser.add_argument('--n_lines', '-n', type=int, default=-1) # eval all lines (default) or top n_lines (e.g., for fast debugging)
125 | parser.add_argument('--n_refs', '-r', type=int, default=6) # number of references
126 | parser.add_argument('--vshuman', '-v', type=int, default='1') # when evaluating against human performance (N in refN.txt that should be removed)
127 | # in which case we need to remove human output from refs
128 | parser.add_argument('--refs', '-g', default='dstc/test.refs')
129 | parser.add_argument('--keys', '-k', default='keys/test.2k.txt')
130 | parser.add_argument('--teams', '-i', type=str, default='dstc/teams.txt')
131 | parser.add_argument('--report', '-o', type=str, default=None)
132 | args = parser.parse_args()
133 | print('Args: %s\n' % str(args), file=sys.stderr)
134 |
135 | if args.n_lines < 0:
136 | n_lines = None # eval all lines
137 | else:
138 | n_lines = args.n_lines # just eval top n_lines
139 |
140 | if args.submitted.endswith('.txt'):
141 | eval_one_system(args.submitted, keys=args.keys, multi_ref=args.refs, clean=args.clean, n_lines=n_lines, n_refs=args.n_refs, vshuman=args.vshuman)
142 | else:
143 | fname_report = 'report_ref%i'%args.n_refs
144 | if args.clean:
145 | fname_report += '_cleaned'
146 | fname_report += '.tsv'
147 | if args.submitted == 'all' or args.submitted == '*':
148 | files = ['dstc/' + line.strip('\n') for line in open(args.teams)]
149 | path_report = 'dstc/' + fname_report
150 | else:
151 | files = [args.submitted]
152 | path_report = args.submitted + '/' + fname_report
153 | if args.report != None:
154 | path_report = args.report
155 | eval_all_systems(files, path_report, keys=args.keys, multi_ref=args.refs, clean=args.clean, n_lines=n_lines, n_refs=args.n_refs, vshuman=args.vshuman)
156 |
--------------------------------------------------------------------------------
/attention_experiment/get_target_samples.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 |
4 | from transformers.tokenization_gpt2 import GPT2Tokenizer
5 |
6 | from itertools import chain
7 |
8 | IGNORE_TOKENS = ['i', 'my', 'he', 'she', '.', 'am', 'was', 'is', 'are', 'have', 'has', 'had']
9 |
10 | def read_txt_data(input_file):
11 | with open(input_file, 'r', encoding='utf-8') as f:
12 | lines = f.readlines()
13 | data = []
14 | for line in lines:
15 | line = line.strip()
16 | if len(line) == 0:
17 | continue
18 |
19 | space_idx = line.find(' ')
20 | if space_idx == -1:
21 | dialog_idx = int(line)
22 | else:
23 | dialog_idx = int(line[:space_idx])
24 |
25 | if int(dialog_idx) == 1:
26 | data.append({'persona_info': [], 'dialog': []})
27 |
28 | dialog_line = line[space_idx + 1:].split('\t')
29 | dialog_line = [l.strip() for l in dialog_line]
30 |
31 | if dialog_line[0].startswith('your persona:'):
32 | persona_info = dialog_line[0].replace('your persona: ', '')
33 | if persona_info[-1] == '.' and persona_info[-2] != ' ':
34 | persona_info = persona_info[:-1] + ' .'
35 | data[-1]['persona_info'].append(persona_info)
36 | elif len(dialog_line) > 1:
37 | data[-1]['dialog'].append(dialog_line[0])
38 | data[-1]['dialog'].append(dialog_line[1])
39 | return data
40 |
41 | def get_dataset_and_history_positions(data):
42 | dataset = []
43 | positions = []
44 | for chat in data:
45 | persona_info = [s.split() for s in chat['persona_info']]
46 | dialog = []
47 | for i, replica in enumerate(chat['dialog'], 1):
48 | dialog.append(replica.split())
49 | if not i % 2:
50 | dataset.append((persona_info, dialog[:], []))
51 | persona_len = [len(x) for x in persona_info]
52 | dialog_len = [len(x) for x in dialog]
53 | persona_pos, history_pos = [], []
54 | p = 1
55 | for l in persona_len:
56 | p = p + l
57 | persona_pos.append(p)
58 | for j in range(max(len(dialog_len) - 6, 0), len(dialog_len) - 1):
59 | p = p + 1 + dialog_len[j]
60 | history_pos.append(p)
61 | positions.append([persona_pos, history_pos])
62 | for i, data in enumerate(dataset):
63 | dataset[i] = [[' '.join(p) for p in data[0]], [' '.join(u) for u in data[1]]]
64 | return dataset, positions
65 |
66 | def _get_entail_index_matched_token_positions(entail_data, tokenizer):
67 | all_attention_positions = []
68 | matched_token_positions = []
69 | for i in range(len(entail_data)):
70 | entail_sample = entail_data[i]
71 | raw_idx = raw_sample_idx[i]
72 | if tokenizer is None:
73 | persona = entail_sample[0].split()
74 | response = entail_sample[2].split()
75 | else:
76 | persona = [t[1:] if t[0] == 'Ġ' else t for t in tokenizer.tokenize(entail_sample[0])]
77 | response = [t[1:] if t[0] == 'Ġ' else t for t in tokenizer.tokenize(entail_sample[2])]
78 | target_positions = []
79 | for i in range(len(persona)):
80 | for j in range(len(response)):
81 | if persona[i] not in IGNORE_TOKENS and persona[i] == response[j]:
82 | target_positions.append([i, j])
83 | all_attention_positions.append([raw_idx, entail_sample[0]])
84 | matched_token_positions.append(target_positions)
85 | return all_attention_positions, matched_token_positions
86 |
87 | def get_dataset_and_persona_positions(raw_data, entail_data, tokenizer):
88 | all_attention_positions, all_matched_token_positions = _get_entail_index_matched_token_positions(entail_data, tokenizer)
89 | raw_sample_idx_set = set([s[0] for s in all_attention_positions])
90 | all_dataset = []
91 | index = 0
92 | for chat in raw_data:
93 | persona_info = [s for s in chat['persona_info']]
94 | dialog = []
95 | for i, replica in enumerate(chat['dialog'], 1):
96 | dialog.append(replica)
97 | if not i % 2:
98 | all_dataset.append((persona_info, dialog[:], []))
99 | new_dataset = [all_dataset[i] for i in [x[0] for x in all_attention_positions]]
100 |
101 | all_target_persona_sentence_positions = []
102 | all_persona_positions = []
103 | for i in range(len(all_attention_positions)):
104 | target_persona = all_attention_positions[i][1]
105 | cur_data = new_dataset[i]
106 | target_persona_text_index = 0
107 | while target_persona != cur_data[0][target_persona_text_index]:
108 | target_persona_text_index += 1
109 | if tokenizer is None:
110 | tokenized_personas = [p.split() for p in cur_data[0]]
111 | else:
112 | tokenized_personas = [tokenizer.tokenize(p) for p in cur_data[0]]
113 | target_persona_start_token_index = 1 + sum([len(x) for x in tokenized_personas[:target_persona_text_index]])
114 | target_persona_end_token_index = target_persona_start_token_index + \
115 | len(tokenized_personas[target_persona_text_index])
116 | all_target_persona_sentence_positions.append([target_persona_start_token_index, target_persona_end_token_index])
117 | all_persona_length = 1 + sum([len(x) for x in tokenized_personas])
118 | all_persona_positions.append([1, all_persona_length])
119 | if len(all_matched_token_positions[i]) > 0:
120 | for j, positions in enumerate(all_matched_token_positions[i]):
121 | all_matched_token_positions[i][j][0] = all_matched_token_positions[i][j][0] + \
122 | target_persona_start_token_index
123 | return new_dataset, all_target_persona_sentence_positions, all_persona_positions, all_matched_token_positions
124 |
125 | with open('base_data/th0.99_dev_self.json', 'r') as f:
126 | entail_data = json.load(f)
127 | with open('base_data/th0.99_dev_raw_sample_idx.json', 'r') as f:
128 | raw_sample_idx = json.load(f)
129 | data = read_txt_data('../datasets/ConvAI2/valid_self_original.txt')
130 | gpt2_special_sumbol = 'Ġ'
131 |
132 | MODE = 'persona'
133 | MODEL = 'gpt2'
134 |
135 | if MODE == 'persona':
136 | tokenizer = None
137 | if MODEL == 'gpt2':
138 | tokenizer = GPT2Tokenizer.from_pretrained('../gpt2-small')
139 | new_dataset, target_persona_sentence_position, all_persona_positions, matched_token_positions = \
140 | get_dataset_and_persona_positions(data, entail_data, tokenizer)
141 | with open('th0.99_consistent_dataset.json', 'w') as f:
142 | json.dump(new_dataset, f)
143 | with open('th0.99_consistent_positions.json', 'w') as f:
144 | json.dump({'token_positions': matched_token_positions,
145 | 'target_persona_positions': target_persona_sentence_position,
146 | 'whole_persona_positions': all_persona_positions}, f)
147 | else:
148 | data, positions = get_dataset_and_history_positions(data)
149 | with open('attention_dev_data.json', 'w', encoding='utf-8') as f:
150 | json.dump(data, f)
151 | with open('attention_dev_position.json', 'w', encoding='utf-8') as f:
152 | json.dump(positions, f)
153 |
154 | print('111')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # D3
2 | ### The implementation of the ACL 2022 paper
3 |
4 | ### A Model-Agnostic Data Manipulation Method for Persona-based Dialogue Generation
5 |
6 | 
7 |
8 | ---
9 |
10 | ### File structure
11 | 1. The main entry point for training the model is `train.py` in the root directory. We also provide some example shell
12 | scripts for running under different conditions.
13 | 2. The code related to our data manipulation method D3 is under `./data_manipulation`, where you can obtain the augmented
14 | data.
15 | 3. `./attention_experiment` contains scripts for the attention experiments (e.g., Appendix C.1 and C.4) in our paper.
16 | 4. `./model` contains all other parts necessary to run the experiments, including the models, optimizer, data interface,
17 | and so on.
18 |
19 | ---
20 |
21 | ### Requirements
22 | 1. python == 3.7.0
23 | 2. torch==1.5.0
24 | 3. transformers==3.1.0
25 | 4. spacy==2.2.4
26 | 5. fairseq==0.9.0 (I downloaded the source code into the root directory)
27 | 6. sentencepiece==0.1.94
28 |
29 | For evaluating the generated responses, you need to install `java-1.8.0` and `perl`, as well as the
30 | Perl libraries XML::Twig, Sort::Naturally, and String::Util (I use cpanm to install them on Linux).
31 |
32 | [METEOR](https://www.cs.cmu.edu/~alavie/METEOR/download/meteor-1.5.tar.gz) is also needed for evaluating the quality
33 | of responses; please unzip it and put it under `./metrics/`.
34 |
35 | We also use [BERTScore](https://github.com/Tiiiger/bert_score) as a metric in our experiments, so you may need to download
36 | a proper BERT model for a successful evaluation. Here we use a roberta-large model. To rescale the scores, we have put
37 | the baseline file under `./bert_score/rescale_baseline/en`. If you want to rescale the BERTScore, please add
38 | `--rescale_with_baseline` to the training shell script.
39 |
40 | ---
41 | ## Run the code
42 |
43 | To be honest, just applying step 3 (Data Distillation) can already achieve satisfactory performance. Step 4 (Data Diversification)
44 | contributes less to the final results and is more complex.
45 |
46 | ## 1. Obtain PersonaChat Dataset
47 |
48 | Obtain PersonaChat dataset via [ParlAI](https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/personachat)
49 | or [our zipped version](https://drive.google.com/file/d/1zQVO5MuEy3wBUfZpM39uYmloD3-Rld4T/view) and put them into the `./datasets` directory.
50 |
51 | ## 2. Prepare models
52 | First, we have to obtain all the trained models we need for data manipulation in the experiments.
53 | Go to `./data_manipulation/prepare_model`.
54 |
55 | ##### 1) NLI model for evaluating persona consistency
56 | You need to download [DialogueNLI dataset](https://wellecks.github.io/dialogue_nli/)
57 | and put it under this directory. Also, download large size [RoBERTa MNLI model](https://huggingface.co/roberta-large-mnli)
58 | and put it under this directory, renaming the directory to `roberta_mnli/`.
59 |
60 | Then you can train the NLI model on this dataset using the script `train_nli_model.py`.
61 |
62 | After obtaining the best trained model, you need to rename the file `best_model.bin` to `pytorch_model.bin` for the following
63 | use. Define the path that saves the trained NLI model for persona consistency as `PERSONA_NLI`.
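
A minimal sketch of this sub-step, based on the default constants inside `train_nli_model.py`; the `persona_nli` directory name and the idea of copying the config/tokenizer files (so the folder can later be loaded with `from_pretrained`) are our own choices, not fixed by the code:

```bash
cd data_manipulation/prepare_model
# expects ./dialogue_nli_dataset/dialogue_nli_{train,dev}.jsonl and ./roberta_mnli (see the constants in the script)
python train_nli_model.py
# package the best checkpoint as a loadable model directory; PERSONA_NLI is just a placeholder path
mkdir -p persona_nli
cp roberta_mnli/config.json roberta_mnli/vocab.json roberta_mnli/merges.txt persona_nli/
mv best_model.bin persona_nli/pytorch_model.bin
export PERSONA_NLI=$(pwd)/persona_nli
```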
64 |
65 | We also provide our trained [NLI model](https://drive.google.com/file/d/1QnT8V2Yj4Zl2yW2rnQIi2p56I_wbN3Ee/view?usp=sharing)
66 | for downloading.
67 |
68 | ##### 2) NLI model for evaluating coherence of dialogue history
69 |
70 | Use the same RoBERTa MNLI model we used in 1) together with `train_coherence_nli.py` to train it on the [InferConvAI2 dataset](https://github.com/nouhadziri/DialogEntailment).
71 | It is a dialogue NLI dataset designed for evaluating the coherence of dialogue history.
72 |
73 | Save the obtained model and define the path containing it as `COHERENCE_NLI`.
74 |
75 | ##### 3) BERT and GPT2 model used in data diversification
76 |
77 | First use `extract_personas_and_responses.py` to extract persona and response texts into two json files.
78 |
79 | Download the [bert-base-uncased model](https://huggingface.co/bert-base-uncased) and [gpt2-small model](https://huggingface.co/gpt2),
80 | and put them under directories of your choice.
81 | Then use `finetune_bert_and_gpt2.py` to fine-tune the BERT and GPT2 models on `personas.json`, obtaining BERTper and
82 | GPT2per, and fine-tune GPT2 on `responses.json` to obtain GPT2res, editing the code to point to the BERT and GPT2 model paths
83 | you just defined.
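
A minimal sketch of this sub-step (as far as the provided sources show, both scripts are configured by editing paths inside the code rather than by command-line flags, so the bare invocations below are only illustrative):

```bash
cd data_manipulation/prepare_model
python extract_personas_and_responses.py   # writes personas.json and responses.json
# edit the BERT/GPT2 model paths and the target data file inside finetune_bert_and_gpt2.py, then run it
python finetune_bert_and_gpt2.py           # fine-tune on personas.json -> BERTper / GPT2per, on responses.json -> GPT2res
```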
84 |
85 | ##### 4) Back translation model for dialogue history diversification
86 |
87 | Go to the directory `./BT`.
88 |
89 | Download the [WMT14 en-fr corpus](http://statmt.org/wmt14/translation-task.html#Download), and pre-process it with
90 | BPE from sentencepiece using `preprocess.sh`, obtaining `sentence.bpe.model`.
91 |
92 | Train the en-fr and fr-en translation models using `train_en-fr.sh` and `train_fr-en.sh` under this directory, and then average the last 5 checkpoints using
93 | `average_model.sh`. Define the obtained model checkpoints as `BT_EN-FR` and `BT-FR-EN`.
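
A rough sketch of the whole preparation, assuming fairseq 0.9.0 is available as described in the requirements:

```bash
cd BT
bash preprocess.sh        # BPE with sentencepiece -> sentence.bpe.model and the binarized corpus
bash train_en-fr.sh       # train the en->fr model
bash train_fr-en.sh       # train the fr->en model
bash average_model.sh     # average the last 5 epoch checkpoints -> model_avg.pt (used as BT_EN-FR / BT-FR-EN)
```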
94 |
95 | ## 3. Data Distillation
96 |
97 | Go to `./data_manipulation/data_distillation`.
98 |
99 | Use `calculate_entailment.py` to obtain the prediction results given by the NLI model under `PERSONA_NLI`
100 | that you obtained before.
101 |
102 | Then use `get_distilled_dataset.py` to obtain the distilled dataset from the logits previously produced by the NLI model.
103 | Assume that the obtained distilled data file is `DISTILL_DATA`.
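
A minimal sketch (both scripts read their input/output paths and the NLI model path from constants in the code, so adjust those first; `DISTILL_DATA` is whatever output path you pick):

```bash
cd data_manipulation/data_distillation
python calculate_entailment.py     # score persona/response pairs with the model under PERSONA_NLI
python get_distilled_dataset.py    # keep the samples whose responses entail their personas -> DISTILL_DATA
```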
104 |
105 | ## 4. Data diversification
106 |
107 | ##### 1) Obtain the Multi-GPT2 model for response alignment under new personas
108 | First, you need to obtain a Multi-GPT2 model trained on the distilled samples. You can use the shell script
109 | `train_multi_gpt2_distilled.sh` under the root directory. Set the training data to `DISTILL_DATA`
110 | according to the definitions in `config.py`. Note that you should use the `config.json` under `multi_gpt2` to replace the
111 | original `config.json` in the initial model weight path before training this model.
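
For example (the path to the initial GPT2 weights is a placeholder):

```bash
cp multi_gpt2/config.json /path/to/gpt2-small/config.json   # overwrite the original config of the initial weights
bash train_multi_gpt2_distilled.sh                          # training data should point to DISTILL_DATA (see config.py)
```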
112 |
113 | ##### 2) Augment dialogue history
114 | Then you need to augment the dialogue history. Go to `./BT` and use `get_bt_input_file.py` to transform the distilled data
115 | `DISTILL_DATA` into the format for back translation. Then use `bpe_split.py` to pre-process the newly obtained txt file with BPE.
116 |
117 | Using `evaluate.sh` and `evaluate_back.sh`, you can translate all utterances into French and then back to English.
118 |
119 | Finally, using `recover.py`, you can recover the txt file back into its original distilled data format as a json file.
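
A rough sketch of this sub-step; the file names below are placeholders, but the argument order of `recover.py` (back-translated file, original distilled file, index file, output file) matches the script in `./BT`:

```bash
cd BT
python get_bt_input_file.py        # DISTILL_DATA -> plain-text utterances plus an index json
python bpe_split.py                # apply sentence.bpe.model to the utterance file
bash evaluate.sh 0                 # en -> fr (the numeric shard argument is an assumption)
bash evaluate_back.sh 0            # fr -> en
python recover.py fr-en.out DISTILL_DATA bt_index.json diversified.json
```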
120 |
121 | ##### 3) Edit personas and align responses
122 | Go to `./data_manipulation/data_diversification`. Using `generate_new_personas_and_edit_responses.py`, you can obtain
123 | new personas as well as some samples with edited new responses where applicable.
124 |
125 | Using `inference_multi_gpt2.sh` in the root directory, you can get the predicted responses for the remaining samples.
126 |
127 | Using `get_augmented_scores.py`, you can get the filtering scores for each new sample.
128 |
129 | Using `filter_augmented_data.py`, you can get the filtered diversified samples along with the distilled ones. Together they form
130 | the augmented dataset used as an easy curriculum for training.
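
A minimal sketch (as with the other data-manipulation scripts, the input/output paths are set inside the code):

```bash
cd data_manipulation/data_diversification
python get_augmented_scores.py     # compute the filter scores for every diversified sample
python filter_augmented_data.py    # keep the high-scoring samples and merge them with the distilled data
```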
131 |
132 | ## 5. Train model
133 |
134 | Put the obtained augmented dataset into `./datasets/augmented/` and then you can train two models using
135 | `train_seq2seq_D3.sh` and `train_gpt2_D3.sh`.
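
For example, from the root directory (paths inside the shell scripts may need to be adapted to your setup):

```bash
bash train_seq2seq_D3.sh   # train the seq2seq model on the augmented (D3) data
bash train_gpt2_D3.sh      # train the GPT2 model on the augmented (D3) data
```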
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import copy
18 | import io
19 | import json
20 | import os
21 | import random
22 | import re
23 | import sys
24 | import logging
25 | from collections import Counter, namedtuple
26 |
27 | import numpy as np
28 | import torch
29 | import torch.nn as nn
30 | from scipy.interpolate import RectBivariateSpline
31 | from torch.utils.checkpoint import checkpoint
32 |
33 | py_version = sys.version.split('.')[0]
34 | if py_version == '2':
35 | open = io.open
36 | unicode = unicode
37 | else:
38 | unicode = str
39 | open = open
40 |
41 |
42 | def set_seed(seed):
43 | torch.manual_seed(seed)
44 | torch.cuda.manual_seed(seed)
45 | random.seed(seed)
46 |
47 |
48 | def repeat_along_dim1(obj, repetitions):
49 | """ repeat (a possibly nested object of) tensors from (batch, ...) to (batch * repetitions, ...) """
50 | if isinstance(obj, tuple):
51 | return tuple(repeat_along_dim1(o, repetitions) for o in obj)
52 | if isinstance(obj, list):
53 | return list(repeat_along_dim1(o, repetitions) for o in obj)
54 |
55 | obj = obj.unsqueeze(1).repeat([1, repetitions] + [1] * len(obj.size()[1:]))
56 | return obj.view(-1, *obj.size()[2:])
57 |
58 |
59 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False):
60 | # assuming trailing dimensions and type of all the Tensors
61 | # in sequences are same and fetching those from sequences[0]
62 | if not len(sequences):
63 | return torch.empty(0)
64 | trailing_dims = sequences[0].size()[1:]
65 | max_len = max([s.size(0) for s in sequences])
66 | if batch_first:
67 | out_dims = (len(sequences), max_len) + trailing_dims
68 | else:
69 | out_dims = (max_len, len(sequences)) + trailing_dims
70 |
71 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
72 | for i, tensor in enumerate(sequences):
73 | length = tensor.size(0)
74 | s_slice = slice(-length, None) if left else slice(None, length)
75 | s_slice = (i, s_slice) if batch_first else (s_slice, i)
76 | out_tensor[s_slice] = tensor
77 |
78 | return out_tensor
79 |
80 |
81 | def checkpoint_sequential(functions, segments, *inputs):
82 | def run_function(start, end, functions):
83 | def forward(*inputs):
84 | for j in range(start, end + 1):
85 | inputs = functions[j](*inputs)
86 | return inputs
87 | return forward
88 |
89 | if isinstance(functions, torch.nn.Sequential):
90 | functions = list(functions.children())
91 |
92 | segment_size = len(functions) // segments
93 | # the last chunk has to be non-volatile
94 | end = -1
95 | for start in range(0, segment_size * (segments - 1), segment_size):
96 | end = start + segment_size - 1
97 | inputs = checkpoint(run_function(start, end, functions), *inputs)
98 | if not isinstance(inputs, tuple):
99 | inputs = (inputs,)
100 | return run_function(end + 1, len(functions) - 1, functions)(*inputs)
101 |
102 |
103 | def f1_score(predictions, targets, average=True):
104 | def f1_score_items(pred_items, gold_items):
105 | common = Counter(gold_items) & Counter(pred_items)
106 | num_same = sum(common.values())
107 |
108 | if num_same == 0:
109 | return 0
110 |
111 | precision = num_same / len(pred_items)
112 | recall = num_same / len(gold_items)
113 | f1 = (2 * precision * recall) / (precision + recall)
114 |
115 | return f1
116 |
117 | scores = [f1_score_items(p, t) for p, t in zip(predictions, targets)]
118 |
119 | if average:
120 | return sum(scores) / len(scores)
121 |
122 | return scores
123 |
124 |
125 | def openai_transformer_config():
126 | class dotdict(dict):
127 | __getattr__ = dict.get
128 | __setattr__ = dict.__setitem__
129 | __delattr__ = dict.__delitem__
130 |
131 | cfg = dotdict({'n_layers': 12, 'n_embeddings': 40477, 'n_pos_embeddings': 512,
132 | 'embeddings_size': 768, 'n_heads': 12, 'dropout': 0.1,
133 | 'embed_dropout': 0.1, 'attn_dropout': 0.1, 'ff_dropout': 0.1})
134 |
135 | return cfg
136 |
137 |
138 | def load_openai_weights(model, directory, n_special_tokens=0, use_tokenizer=False):
139 | # TODO: add check of shapes
140 |
141 | parameters_names_path = os.path.join(directory, 'parameters_names.json')
142 | parameters_shapes_path = os.path.join(directory, 'parameters_shapes.json')
143 | parameters_weights_paths = [os.path.join(directory, 'params_{}.npy'.format(n)) for n in range(10)]
144 |
145 | with open(parameters_names_path, 'r') as parameters_names_file:
146 | parameters_names = json.load(parameters_names_file)
147 |
148 | with open(parameters_shapes_path, 'r') as parameters_shapes_file:
149 | parameters_shapes = json.load(parameters_shapes_file)
150 |
151 | parameters_weights = [np.load(path) for path in parameters_weights_paths]
152 | parameters_offsets = np.cumsum([np.prod(shape) for shape in parameters_shapes])
153 | parameters_weights = np.split(np.concatenate(parameters_weights, 0), parameters_offsets)[:-1]
154 | parameters_weights = [p.reshape(s) for p, s in zip(parameters_weights, parameters_shapes)]
155 |
156 | if not use_tokenizer:
157 | parameters_weights[1] = parameters_weights[1][1:] # skip 0 -
158 |
159 | if isinstance(model.pos_embeddings, nn.Embedding):
160 | if model.pos_embeddings.num_embeddings - 1 > parameters_weights[0].shape[0]:
161 | xx = np.linspace(0, parameters_weights[0].shape[0], model.pos_embeddings.num_embeddings - 1)
162 | new_kernel = RectBivariateSpline(np.arange(parameters_weights[0].shape[0]),
163 | np.arange(parameters_weights[0].shape[1]),
164 | parameters_weights[0])
165 | parameters_weights[0] = new_kernel(xx, np.arange(parameters_weights[0].shape[1]))
166 |
167 | # parameters_weights[0] = parameters_weights[0][:model.pos_embeddings.num_embeddings - 1]
168 | # model.pos_embeddings.weight.data[1:] = torch.from_numpy(parameters_weights[0])
169 | model.pos_embeddings.weight.data = torch.from_numpy(parameters_weights[0])
170 |
171 |
172 | if use_tokenizer:
173 | model.embeddings.weight.data[-n_special_tokens + 1:] = 0
174 | model.embeddings.weight.data[: -n_special_tokens + 1] = torch.from_numpy(parameters_weights[1])
175 | else:
176 | parameters_weights[1] = parameters_weights[1][:model.embeddings.num_embeddings - n_special_tokens]
177 | model.embeddings.weight.data[:n_special_tokens] = 0
178 | model.embeddings.weight.data[n_special_tokens:] = torch.from_numpy(parameters_weights[1])
179 |
180 | parameters_weights = parameters_weights[2:]
181 |
182 | for name, weights in zip(parameters_names, parameters_weights):
183 | name = name[6:] # skip "model/"
184 | assert name[-2:] == ':0'
185 | name = name[:-2]
186 | name = name.split('/')
187 |
188 | pointer = model
189 | for m_name in name:
190 | if re.fullmatch(r'[A-Za-z]+\d+', m_name):
191 | l = re.split(r'(\d+)', m_name)
192 | else:
193 | l = [m_name]
194 |
195 | pointer = getattr(pointer, l[0])
196 |
197 | if len(l) >= 2:
198 | num = int(l[1])
199 | pointer = pointer[num]
200 |
201 | if len(weights.shape) == 3: # conv1d to linear
202 | weights = weights[0].transpose((1, 0))
203 |
204 | pointer.data[...] = torch.from_numpy(weights)
205 |
206 | # Initialize shared attention layers if necessary
207 | for layer in model.layers:
208 | attn_state = layer.attn.state_dict()
209 | for context_attn in layer.context_attns:
210 | context_attn.load_state_dict(copy.deepcopy(attn_state), strict=False)
211 |
212 | def config_logger(log_path):
213 | logger = logging.getLogger()
214 | logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
215 | level=logging.INFO)
216 | file_handler = logging.FileHandler(log_path, mode='w')
217 | file_handler.setLevel(logging.INFO)
218 | file_handler.setFormatter(
219 | logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'))
220 | logger.addHandler(file_handler)
221 | return logger
222 |
--------------------------------------------------------------------------------
/model/seq2seq_vocab.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import json
4 |
5 | class Seq2seqTokenizer:
6 | def __init__(self):
7 | self.word2idx = {"<unk>": 0, "<pad>": 1, "<bos>": 2, "<eos>": 3, '<talker1_bos>': 4, '<talker2_bos>': 5}
8 | self.idx2word = {0: "<unk>", 1: "<pad>", 2: "<bos>", 3: "<eos>", 4: "<talker1_bos>", 5: "<talker2_bos>"}
9 | self.n_words = 6
10 | self.all_special_ids = [0, 1, 2, 3, 4, 5]
11 | self.pad_id = 1
12 | self.bos_id = 2
13 | self.eos_id = 3
14 | self.talker1_bos_id = 4
15 | self.talker2_bos_id = 5
16 |
17 | def tokenize(self, str):
18 | res = str.strip().split(' ')
19 | res = [x.lower() for x in res]
20 | return res
21 |
22 | def encode(self, tokenized_str):
23 | res = []
24 | for token in tokenized_str:
25 | if self.word2idx.__contains__(token):
26 | res.append(self.word2idx[token])
27 | return res
28 |
29 | def decode(self, ids, skip_special_tokens=True, clean_up_tokenization_spaces=False):
30 | res = []
31 | for id in ids:
32 | if skip_special_tokens and id in self.all_special_ids:
33 | continue
34 | res.append(self.idx2word[id])
35 | text = ' '.join(res)
36 | return text
37 |
38 | def index_words(self, sentence):
39 | for word in sentence.split(' '):
40 | self.index_word(word)
41 |
42 | def index_word(self, word):
43 | if not self.word2idx.__contains__(word):
44 | self.word2idx[word] = self.n_words
45 | self.idx2word[self.n_words] = word
46 | self.n_words += 1
47 |
48 | class Seq2seqVocab:
49 | def __init__(self, train_dataset_path, valid_dataset_path, test_dataset_path, vocab_path, data_type='persona',
50 | extra_train_data_path=None, extra_data_type='persona', extend_exist_vocab=None):
51 | if (os.path.exists(vocab_path)):
52 | with open(vocab_path, 'rb') as f:
53 | cached_data = pickle.load(f)
54 | self.vocab = cached_data[0]
55 | self.all_data = cached_data[1]
56 | else:
57 | if extend_exist_vocab:
58 | with open(extend_exist_vocab, 'rb') as f:
59 | cached_data = pickle.load(f)
60 | self.vocab = cached_data[0]
61 | print('loaded vocab size ' + str(self.vocab.n_words))
62 | else:
63 | self.vocab = Seq2seqTokenizer()
64 | self.all_data = self._parse_data(train_dataset_path, valid_dataset_path, test_dataset_path, data_type)
65 | if extra_train_data_path:
66 | extra_data = self._parse_data(extra_train_data_path, None, None, extra_data_type)
67 | self.all_data.extend(extra_data)
68 | self.parse_vocab(self.all_data, self.vocab)
69 | with open(vocab_path, 'wb') as f:
70 | pickle.dump([self.vocab, []], f)
71 |
72 | def _parse_data(self, train_dataset_path, valid_dataset_path, test_dataset_path, data_type):
73 | data = None
74 | if data_type == 'persona':
75 | data = self.parse_data_persona(train_dataset_path, valid_dataset_path, test_dataset_path)
76 | elif data_type == 'emoji':
77 | data = self.parse_data_emoji(train_dataset_path, valid_dataset_path, test_dataset_path)
78 | elif data_type == 'daily':
79 | data = self.parse_data_daily(train_dataset_path, valid_dataset_path, test_dataset_path)
80 | elif data_type == 'entailment':
81 | data = self.parse_data_entailment(train_dataset_path, valid_dataset_path, test_dataset_path)
82 | return data
83 |
84 | def parse_data_persona(self, train_dataset_path, valid_dataset_path, test_dataset_path):
85 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path]
86 | all_data = []
87 | for subset in subsets:
88 | data = []
89 | if subset is None or len(subset) == 0:
90 | all_data.append(data)
91 | continue
92 | with open(subset, 'r', encoding='utf-8') as f:
93 | for line in f.readlines():
94 | line = line.strip()
95 | space_idx = line.find(' ')
96 | if space_idx == -1:
97 | dialog_idx = int(line)
98 | else:
99 | dialog_idx = int(line[:space_idx])
100 |
101 | if int(dialog_idx) == 1:
102 | data.append({'persona_info': [], 'dialog': [], 'candidates': []})
103 |
104 | dialog_line = line[space_idx + 1:].split('\t')
105 | dialog_line = [l.strip() for l in dialog_line]
106 |
107 | if dialog_line[0].startswith('your persona:'):
108 | persona_info = dialog_line[0].replace('your persona: ', '')
109 | if persona_info[-1] == '.' and persona_info[-2] != ' ':
110 | persona_info = persona_info[:-1] + ' .'
111 | data[-1]['persona_info'].append(persona_info)
112 | if dialog_line[0].startswith('partner\'s person'):
113 | if not data[-1].__contains__('partner_persona_info'):
114 | data[-1]['partner_persona_info'] = []
115 | persona_info = dialog_line[0].replace('partner\'s persona: ', '')
116 | if persona_info[-1] == '.' and persona_info[-2] != ' ':
117 | persona_info = persona_info[:-1] + ' .'
118 | data[-1]['partner_persona_info'].append(persona_info)
119 | elif len(dialog_line) > 1:
120 | data[-1]['dialog'].append(dialog_line[0])
121 | data[-1]['dialog'].append(dialog_line[1])
122 | if len(dialog_line) == 4:
123 | data[-1]['candidates'].append(dialog_line[3].split('|')[:-1]) # the last candidate is a duplicate of the good answer (dialog_line[1])
124 |
125 | all_data.append(data)
126 | return all_data
127 |
128 | def parse_data_emoji(self, train_dataset_path, valid_dataset_path, test_dataset_path):
129 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path]
130 | all_data = []
131 | for subset in subsets:
132 | data = []
133 | if subset is None or len(subset) == 0:
134 | all_data.append(data)
135 | continue
136 | with open(subset, 'r', encoding='utf-8') as f:
137 | for line in f.readlines():
138 | line = line.strip()
139 | items = line.split('\t')
140 | data.append({'persona_info': [], 'dialog': [], 'candidates': []})
141 | data[-1]['persona_info'].append(items[0])
142 | data[-1]['dialog'].append(items[1])
143 | data[-1]['dialog'].append(items[2])
144 | all_data.append(data)
145 | return all_data
146 |
147 | def parse_data_daily(self, train_dataset_path, valid_dataset_path, test_dataset_path):
148 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path]
149 | all_data = []
150 | for subset in subsets:
151 | data = []
152 | if subset is None or len(subset) == 0:
153 | all_data.append(data)
154 | continue
155 | with open(subset, 'r', encoding='utf-8') as f:
156 | for line in f.readlines():
157 | line = line.strip()
158 | items = line.split('\t')
159 | items = [item.strip().lower() for item in items]
160 | data.append({'persona_info': [], 'dialog': [], 'candidates': []})
161 | data[-1]['persona_info'].append(items[0])
162 | for i in range(1, len(items)):
163 | data[-1]['dialog'].append(items[i])
164 | all_data.append(data)
165 | return all_data
166 |
167 | def parse_data_entailment(self, train_dataset_path, valid_dataset_path, test_dataset_path):
168 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path]
169 | all_data = []
170 | for subset in subsets:
171 | data = []
172 | if subset is None or len(subset) == 0:
173 | all_data.append(data)
174 | continue
175 | try:
176 | with open(subset, 'r', encoding='utf-8') as f:
177 | data =[]
178 | list = json.load(f)
179 | for item in list:
180 | data.append({'persona_info': [], 'dialog': [], 'candidates': []})
181 | data[-1]['persona_info'].append(item[0])
182 | data[-1]['dialog'].extend(item[1])
183 | data[-1]['dialog'].append(item[2])
184 | except:
185 | print('Incorrect data format ' + subset)
186 | all_data.append(data)
187 | return all_data
188 |
189 | def parse_vocab(self, all_data, vocab):
190 | for data in all_data:
191 | for p in data:
192 | for s in p['persona_info']:
193 | vocab.index_words(s)
194 | for s in p['dialog']:
195 | vocab.index_words(s)
196 | for c in p['candidates']:
197 | for s in c:
198 | vocab.index_words(s)
--------------------------------------------------------------------------------
/model/optim.py:
--------------------------------------------------------------------------------
1 | # transformer_chatbot
2 | # Copyright (C) 2018 Golovanov, Tselousov
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU Affero General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU Affero General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU Affero General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import logging
18 | import math
19 |
20 | import torch
21 |
22 | logger = logging.getLogger(__file__)
23 |
24 |
25 | class Adam(torch.optim.Optimizer):
26 | """Implements Adam algorithm.
27 | This implementation is modified from torch.optim.Adam based on:
28 |     `Fixing Weight Decay Regularization in Adam`
29 | (see https://arxiv.org/abs/1711.05101)
30 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
31 | Arguments:
32 | params (iterable): iterable of parameters to optimize or dicts defining
33 | parameter groups
34 | lr (float, optional): learning rate (default: 1e-3)
35 | betas (Tuple[float, float], optional): coefficients used for computing
36 | running averages of gradient and its square (default: (0.9, 0.999))
37 | eps (float, optional): term added to the denominator to improve
38 | numerical stability (default: 1e-8)
39 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
40 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
41 | algorithm from the paper `On the Convergence of Adam and Beyond`_
42 | .. _Adam\: A Method for Stochastic Optimization:
43 | https://arxiv.org/abs/1412.6980
44 | .. _On the Convergence of Adam and Beyond:
45 | https://openreview.net/forum?id=ryQu7f-RZ
46 | """
47 |
48 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
49 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
50 | super(Adam, self).__init__(params, defaults)
51 |
52 | def step(self, closure=None):
53 | """Performs a single optimization step.
54 | Arguments:
55 | closure (callable, optional): A closure that reevaluates the model
56 | and returns the loss.
57 | """
58 | loss = None
59 | if closure is not None:
60 | loss = closure()
61 |
62 | for group in self.param_groups:
63 | for p in group['params']:
64 | if p.grad is None:
65 | continue
66 | grad = p.grad.data
67 |
68 | amsgrad = group['amsgrad']
69 |
70 | state = self.state[p]
71 |
72 | # State initialization
73 | if len(state) == 0:
74 | state['step'] = 0
75 | # Exponential moving average of gradient values
76 | state['exp_avg'] = torch.zeros_like(p.data)
77 | # Exponential moving average of squared gradient values
78 | state['exp_avg_sq'] = torch.zeros_like(p.data)
79 | if amsgrad:
80 | # Maintains max of all exp. moving avg. of sq. grad. values
81 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
82 |
83 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
84 | if amsgrad:
85 | max_exp_avg_sq = state['max_exp_avg_sq']
86 | beta1, beta2 = group['betas']
87 |
88 | state['step'] += 1
89 |
90 | if grad.is_sparse:
91 | grad = grad.coalesce() # the update is non-linear so indices must be unique
92 | grad_indices = grad._indices()
93 | grad_values = grad._values()
94 | size = grad.size()
95 |
96 | def make_sparse(values):
97 | constructor = grad.new
98 | if grad_indices.dim() == 0 or values.dim() == 0:
99 | return constructor().resize_as_(grad)
100 | return constructor(grad_indices, values, size)
101 |
102 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
103 | beta1, beta2 = group['betas']
104 |
105 | # Decay the first and second moment running average coefficient
106 | # old <- b * old + (1 - b) * new
107 | # <==> old += (1 - b) * (new - old)
108 | old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
109 | exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
110 | exp_avg.add_(make_sparse(exp_avg_update_values))
111 | old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
112 | exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
113 | exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
114 |
115 | # Dense addition again is intended, avoiding another sparse_mask
116 | numer = exp_avg_update_values.add_(old_exp_avg_values)
117 | exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
118 | denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])
119 | del exp_avg_update_values, exp_avg_sq_update_values
120 |
121 | bias_correction1 = 1 - beta1 ** state['step']
122 | bias_correction2 = 1 - beta2 ** state['step']
123 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
124 |
125 | p.data.add_(make_sparse(-step_size * numer.div_(denom)))
126 | else:
127 | # Decay the first and second moment running average coefficient
128 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
129 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
130 | if amsgrad:
131 | # Maintains the maximum of all 2nd moment running avg. till now
132 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
133 | # Use the max. for normalizing running avg. of gradient
134 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
135 | else:
136 | denom = exp_avg_sq.sqrt().add_(group['eps'])
137 |
138 | bias_correction1 = 1 - beta1 ** state['step']
139 | bias_correction2 = 1 - beta2 ** state['step']
140 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
141 |
142 | if group['weight_decay'] != 0:
143 | p.data.add_(-group['weight_decay'] * group['lr'], p.data)
144 |
145 | p.data.addcdiv_(-step_size, exp_avg, denom)
146 |
147 | return loss
148 |
149 | def backward(self, losses):
150 | with torch.autograd.set_detect_anomaly(True):
151 | if not isinstance(losses, (tuple, list)):
152 | losses = [losses]
153 | full_loss = sum(losses, 0)
154 | full_loss.backward()
155 | return full_loss
156 |
157 | class NoamOpt:
158 | def __init__(self, embeddings_size, warmup, optimizer, linear_schedule=False, lr=None, total_steps=None,
159 | apex_level=None, loss_weight=None, extra_module_lr_rate=1.0):
160 | self.embeddings_size = embeddings_size
161 | self.warmup = warmup
162 | self.optimizer = optimizer
163 | self.linear_schedule = linear_schedule
164 | self.apex_level = apex_level
165 | self.lr = lr
166 | self.total_steps = total_steps
167 | self.loss_weight = loss_weight
168 | self.extra_module_lr_rate = extra_module_lr_rate
169 |
170 | self._step = 0
171 |
172 | def state_dict(self):
173 | return {'step': self._step,
174 | 'optimizer': self.optimizer.state_dict()}
175 |
176 | def load_state_dict(self, state_dict):
177 | self._step = state_dict['step']
178 | try:
179 | self.optimizer.load_state_dict(state_dict['optimizer'])
180 | except ValueError as e:
181 | logger.info("Optimizer cannot be loaded from checkpoint: {}".format(e))
182 | except KeyError as e:
183 | logger.info("Optimizer cannot be loaded from checkpoint: {}".format(e))
184 |
185 | def backward(self, losses):
186 | if not isinstance(losses, (tuple, list)):
187 | losses = [losses]
188 | if self.loss_weight is None:
189 | full_loss = sum(losses, 0)
190 | else:
191 | full_loss = torch.sum(torch.stack(losses, 0) * torch.exp(self.loss_weight[1])) + torch.sum(self.loss_weight[1])
192 |
193 | if self.apex_level is not None:
194 | try:
195 | from apex.amp import scale_loss
196 | except ImportError:
197 | raise ImportError("Please install apex.")
198 |
199 | for loss_id, loss in enumerate(losses):
200 | with scale_loss(loss, self.optimizer, loss_id=loss_id) as scaled_loss:
201 | scaled_loss.backward()
202 | else:
203 | full_loss.backward()
204 | return full_loss
205 |
206 | def zero_grad(self):
207 | return self.optimizer.zero_grad()
208 |
209 | def get_lr(self):
210 | return self.optimizer.param_groups[0]['lr']
211 |
212 | @property
213 | def param_groups(self):
214 | return self.optimizer.param_groups
215 |
216 | def step(self):
217 | self._step += 1
218 | rate = self.rate_linear() if self.linear_schedule else self.rate()
219 | for p in self.optimizer.param_groups:
220 |             if 'extra' in p:
221 | p['lr'] = rate * self.extra_module_lr_rate
222 | else:
223 | p['lr'] = rate
224 | self.optimizer.step()
225 |
226 | def rate(self, step=None):
227 | if step is None:
228 | step = self._step
229 |
230 | return self.lr * (self.embeddings_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
231 |
232 | @staticmethod
233 | def warmup_linear(x, warmup=0.002):
234 | if x < warmup:
235 | return x/warmup
236 | return 1.0 - x
237 |
238 | def rate_linear(self, step=None):
239 | if step is None:
240 | step = self._step
241 | assert self.lr is not None and self.total_steps is not None
242 |
243 | return self.lr * self.warmup_linear(step/self.total_steps, self.warmup)
244 |
--------------------------------------------------------------------------------
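A minimal usage sketch for the Adam/NoamOpt pair defined above; the toy module and the hyper-parameter values are illustrative assumptions only.

    import torch
    import torch.nn as nn

    net = nn.Linear(16, 16)
    base = Adam(net.parameters(), lr=1.0)  # the inner lr is overwritten on every NoamOpt.step()
    opt = NoamOpt(embeddings_size=768, warmup=4000, optimizer=base, lr=6.25e-5)

    loss = net(torch.randn(4, 16)).pow(2).mean()
    opt.zero_grad()
    opt.backward(loss)   # accepts a single loss or a list of losses
    opt.step()           # sets the Noam learning rate via rate(), then calls the wrapped Adam.step()
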
/bert_score/score.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import pathlib
5 | import torch
6 | import matplotlib.pyplot as plt
7 | from mpl_toolkits.axes_grid1 import make_axes_locatable
8 | import numpy as np
9 | import pandas as pd
10 |
11 | from collections import defaultdict
12 | from transformers import AutoTokenizer
13 |
14 | from .utils import (
15 | get_model,
16 | get_idf_dict,
17 | bert_cos_score_idf,
18 | get_bert_embedding,
19 | lang2model,
20 | model2layers,
21 | get_hash,
22 | cache_scibert,
23 | sent_encode,
24 | )
25 |
26 |
27 | __all__ = ["get_bert_score", "plot_example"]
28 |
29 |
30 | def get_bert_score(
31 | cands,
32 | refs,
33 | model_type=None,
34 | num_layers=None,
35 | verbose=False,
36 | idf=False,
37 | device=None,
38 | batch_size=64,
39 | nthreads=4,
40 | all_layers=False,
41 | lang=None,
42 | return_hash=False,
43 | rescale_with_baseline=False,
44 | baseline_path=None,
45 | ):
46 | """
47 | BERTScore metric.
48 |
49 | Args:
50 | - :param: `cands` (list of str): candidate sentences
51 | - :param: `refs` (list of str or list of list of str): reference sentences
52 | - :param: `model_type` (str): bert specification, default using the suggested
53 |                   model for the target language; has to specify at least one of
54 | `model_type` or `lang`
55 |         - :param: `num_layers` (int): the layer of representation to use;
56 |                   defaults to the number of layers tuned on WMT16 correlation data
57 | - :param: `verbose` (bool): turn on intermediate status update
58 | - :param: `idf` (bool or dict): use idf weighting, can also be a precomputed idf_dict
59 |         - :param: `device` (str): the device on which the contextual embedding model is allocated.
60 | If this argument is None, the model lives on cuda:0 if cuda is available.
61 | - :param: `nthreads` (int): number of threads
62 | - :param: `batch_size` (int): bert score processing batch size
63 | - :param: `lang` (str): language of the sentences; has to specify
64 | at least one of `model_type` or `lang`. `lang` needs to be
65 | specified when `rescale_with_baseline` is True.
66 | - :param: `return_hash` (bool): return hash code of the setting
67 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
68 | - :param: `baseline_path` (str): customized baseline file
69 |
70 | Return:
71 | - :param: `(P, R, F)`: each is of shape (N); N = number of input
72 |                   candidate-reference pairs. If returning the hash code, the
73 |                   output will be ((P, R, F), hashcode). If a candidate has
74 | multiple references, the returned score of this candidate is
75 | the *best* score among all references.
76 | """
77 | assert len(cands) == len(refs), "Different number of candidates and references"
78 |
79 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
80 |
81 | ref_group_boundaries = None
82 | if not isinstance(refs[0], str):
83 | ref_group_boundaries = []
84 | ori_cands, ori_refs = cands, refs
85 | cands, refs = [], []
86 | count = 0
87 | for cand, ref_group in zip(ori_cands, ori_refs):
88 | cands += [cand] * len(ref_group)
89 | refs += ref_group
90 | ref_group_boundaries.append((count, count + len(ref_group)))
91 | count += len(ref_group)
92 |
93 | if rescale_with_baseline:
94 | assert lang is not None, "Need to specify Language when rescaling with baseline"
95 |
96 | if model_type is None:
97 | lang = lang.lower()
98 | model_type = lang2model[lang]
99 | if num_layers is None:
100 | num_layers = model2layers[model_type]
101 |
102 | if model_type.startswith("scibert"):
103 | tokenizer = AutoTokenizer.from_pretrained(cache_scibert(model_type))
104 | else:
105 | tokenizer = AutoTokenizer.from_pretrained(model_type)
106 |
107 | model = get_model(model_type, num_layers, all_layers)
108 | if device is None:
109 | device = "cuda" if torch.cuda.is_available() else "cpu"
110 | model.to(device)
111 |
112 | if not idf:
113 | idf_dict = defaultdict(lambda: 1.0)
114 | # set idf for [SEP] and [CLS] to 0
115 | idf_dict[tokenizer.sep_token_id] = 0
116 | idf_dict[tokenizer.cls_token_id] = 0
117 | elif isinstance(idf, dict):
118 | if verbose:
119 | print("using predefined IDF dict...")
120 | idf_dict = idf
121 | else:
122 | if verbose:
123 | print("preparing IDF dict...")
124 | start = time.perf_counter()
125 | idf_dict = get_idf_dict(refs, tokenizer, nthreads=nthreads)
126 | if verbose:
127 | print("done in {:.2f} seconds".format(time.perf_counter() - start))
128 |
129 | if verbose:
130 | print("calculating scores...")
131 | start = time.perf_counter()
132 | all_preds = bert_cos_score_idf(
133 | model,
134 | refs,
135 | cands,
136 | tokenizer,
137 | idf_dict,
138 | verbose=verbose,
139 | device=device,
140 | batch_size=batch_size,
141 | all_layers=all_layers,
142 | ).cpu()
143 |
144 | if ref_group_boundaries is not None:
145 | max_preds = []
146 | for beg, end in ref_group_boundaries:
147 | max_preds.append(all_preds[beg:end].max(dim=0)[0])
148 | all_preds = torch.stack(max_preds, dim=0)
149 |
150 | use_custom_baseline = baseline_path is not None
151 | if rescale_with_baseline:
152 | if baseline_path is None:
153 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
154 | if os.path.isfile(baseline_path):
155 | if not all_layers:
156 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
157 | else:
158 | baselines = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
159 |
160 | all_preds = (all_preds - baselines) / (1 - baselines)
161 | else:
162 | print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
163 |
164 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F
165 |
166 | if verbose:
167 | time_diff = time.perf_counter() - start
168 | print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")
169 |
170 | if return_hash:
171 | return tuple(
172 | [out, get_hash(model_type, num_layers, idf, rescale_with_baseline, use_custom_baseline=use_custom_baseline)]
173 | )
174 |
175 | return out
176 |
177 |
178 | def plot_example(
179 | candidate,
180 | reference,
181 | model_type=None,
182 | num_layers=None,
183 | lang=None,
184 | rescale_with_baseline=False,
185 | baseline_path=None,
186 | fname="",
187 | ):
188 | """
189 |     Plot the pairwise token similarity matrix for a candidate-reference pair.
190 |
191 | Args:
192 | - :param: `candidate` (str): a candidate sentence
193 | - :param: `reference` (str): a reference sentence
194 | - :param: `verbose` (bool): turn on intermediate status update
195 | - :param: `model_type` (str): bert specification, default using the suggested
196 |                   model for the target language; has to specify at least one of
197 | `model_type` or `lang`
198 | - :param: `num_layers` (int): the layer of representation to use
199 | - :param: `lang` (str): language of the sentences; has to specify
200 | at least one of `model_type` or `lang`. `lang` needs to be
201 | specified when `rescale_with_baseline` is True.
202 | - :param: `return_hash` (bool): return hash code of the setting
203 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
204 | - :param: `fname` (str): path to save the output plot
205 | """
206 | assert isinstance(candidate, str)
207 | assert isinstance(reference, str)
208 |
209 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
210 |
211 | if rescale_with_baseline:
212 | assert lang is not None, "Need to specify Language when rescaling with baseline"
213 |
214 | if model_type is None:
215 | lang = lang.lower()
216 | model_type = lang2model[lang]
217 | if num_layers is None:
218 | num_layers = model2layers[model_type]
219 |
220 | if model_type.startswith("scibert"):
221 | tokenizer = AutoTokenizer.from_pretrained(cache_scibert(model_type))
222 | else:
223 | tokenizer = AutoTokenizer.from_pretrained(model_type)
224 | model = get_model(model_type, num_layers)
225 | device = "cuda" if torch.cuda.is_available() else "cpu"
226 | model.to(device)
227 |
228 | idf_dict = defaultdict(lambda: 1.0)
229 | # set idf for [SEP] and [CLS] to 0
230 | idf_dict[tokenizer.sep_token_id] = 0
231 | idf_dict[tokenizer.cls_token_id] = 0
232 |
233 | hyp_embedding, masks, padded_idf = get_bert_embedding(
234 | [candidate], model, tokenizer, idf_dict, device=device, all_layers=False
235 | )
236 | ref_embedding, masks, padded_idf = get_bert_embedding(
237 | [reference], model, tokenizer, idf_dict, device=device, all_layers=False
238 | )
239 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
240 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
241 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
242 | sim = sim.squeeze(0).cpu()
243 |
244 | # remove [CLS] and [SEP] tokens
245 | r_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, reference)][1:-1]
246 | h_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, candidate)][1:-1]
247 | sim = sim[1:-1, 1:-1]
248 |
249 | if rescale_with_baseline:
250 | if baseline_path is None:
251 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
252 | if os.path.isfile(baseline_path):
253 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
254 | sim = (sim - baselines[2].item()) / (1 - baselines[2].item())
255 | else:
256 | print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
257 |
258 | fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens)))
259 | im = ax.imshow(sim, cmap="Blues", vmin=0, vmax=1)
260 |
261 | # We want to show all ticks...
262 | ax.set_xticks(np.arange(len(r_tokens)))
263 | ax.set_yticks(np.arange(len(h_tokens)))
264 | # ... and label them with the respective list entries
265 | ax.set_xticklabels(r_tokens, fontsize=10)
266 | ax.set_yticklabels(h_tokens, fontsize=10)
267 | ax.grid(False)
268 | plt.xlabel("Reference (tokenized)", fontsize=14)
269 | plt.ylabel("Candidate (tokenized)", fontsize=14)
270 | title = "Similarity Matrix"
271 | if rescale_with_baseline:
272 | title += " (after Rescaling)"
273 | plt.title(title, fontsize=14)
274 |
275 | divider = make_axes_locatable(ax)
276 | cax = divider.append_axes("right", size="2%", pad=0.2)
277 | fig.colorbar(im, cax=cax)
278 |
279 | # Rotate the tick labels and set their alignment.
280 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
281 |
282 | # Loop over data dimensions and create text annotations.
283 | for i in range(len(h_tokens)):
284 | for j in range(len(r_tokens)):
285 | text = ax.text(
286 | j,
287 | i,
288 | "{:.3f}".format(sim[i, j].item()),
289 | ha="center",
290 | va="center",
291 | color="k" if sim[i, j].item() < 0.5 else "w",
292 | )
293 |
294 | fig.tight_layout()
295 | if fname != "":
296 | plt.savefig(fname, dpi=100)
297 | print("Saved figure to file: ", fname)
298 | plt.show()
299 |
--------------------------------------------------------------------------------
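A minimal usage sketch for get_bert_score defined above; the sentences are toy examples, and running it assumes the corresponding transformers weights can be downloaded or loaded locally.

    from bert_score.score import get_bert_score

    cands = ["the cat sat on the mat"]
    refs = ["a cat was sitting on the mat"]
    P, R, F = get_bert_score(cands, refs, lang="en", rescale_with_baseline=True)
    print(F.mean().item())  # single rescaled F1 value for the pair
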
/bert_score/scorer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import pathlib
5 | import torch
6 | import matplotlib.pyplot as plt
7 | from mpl_toolkits.axes_grid1 import make_axes_locatable
8 | import numpy as np
9 | import pandas as pd
10 | import warnings
11 |
12 | from collections import defaultdict
13 | from transformers import AutoTokenizer
14 |
15 | from .utils import (
16 | get_model,
17 | get_idf_dict,
18 | bert_cos_score_idf,
19 | get_bert_embedding,
20 | lang2model,
21 | model2layers,
22 | get_hash,
23 | cache_scibert,
24 | sent_encode,
25 | )
26 |
27 |
28 | class BERTScorer:
29 | """
30 | BERTScore Scorer Object.
31 | """
32 |
33 | def __init__(
34 | self,
35 | model_type=None,
36 | num_layers=None,
37 | batch_size=64,
38 | nthreads=4,
39 | all_layers=False,
40 | idf=False,
41 | idf_sents=None,
42 | device=None,
43 | lang=None,
44 | rescale_with_baseline=False,
45 | baseline_path=None,
46 | ):
47 | """
48 | Args:
49 |         - :param: `model_type` (str): contextual embedding model specification, default using the suggested
50 |                   model for the target language; has to specify at least one of
51 | `model_type` or `lang`
52 |         - :param: `num_layers` (int): the layer of representation to use;
53 |                   defaults to the number of layers tuned on WMT16 correlation data
54 | - :param: `verbose` (bool): turn on intermediate status update
55 |         - :param: `idf` (bool): a boolean specifying whether to use idf weighting (this should be True even if `idf_sents` is given)
56 | - :param: `idf_sents` (List of str): list of sentences used to compute the idf weights
57 |         - :param: `device` (str): the device on which the contextual embedding model is allocated.
58 | If this argument is None, the model lives on cuda:0 if cuda is available.
59 | - :param: `batch_size` (int): bert score processing batch size
60 | - :param: `nthreads` (int): number of threads
61 | - :param: `lang` (str): language of the sentences; has to specify
62 | at least one of `model_type` or `lang`. `lang` needs to be
63 | specified when `rescale_with_baseline` is True.
64 | - :param: `return_hash` (bool): return hash code of the setting
65 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
66 | - :param: `baseline_path` (str): customized baseline file
67 | """
68 |
69 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
70 |
71 | if rescale_with_baseline:
72 | assert lang is not None, "Need to specify Language when rescaling with baseline"
73 |
74 | if device is None:
75 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
76 | else:
77 | self.device = device
78 |
79 | self._lang = lang
80 | self._rescale_with_baseline = rescale_with_baseline
81 | self._idf = idf
82 | self.batch_size = batch_size
83 | self.nthreads = nthreads
84 | self.all_layers = all_layers
85 |
86 | if model_type is None:
87 | lang = lang.lower()
88 | self._model_type = lang2model[lang]
89 | else:
90 | self._model_type = model_type
91 |
92 | if num_layers is None:
93 | self._num_layers = model2layers[self.model_type]
94 | else:
95 | self._num_layers = num_layers
96 |
97 | # Building model and tokenizer
98 |
99 | if self.model_type.startswith("scibert"):
100 | self._tokenizer = AutoTokenizer.from_pretrained(cache_scibert(self.model_type))
101 | else:
102 | self._tokenizer = AutoTokenizer.from_pretrained(self.model_type)
103 |
104 | self._model = get_model(self.model_type, self.num_layers, self.all_layers)
105 | self._model.to(self.device)
106 |
107 | self._idf_dict = None
108 | if idf_sents is not None:
109 | self.compute_idf(idf_sents)
110 |
111 | self._baseline_vals = None
112 | self.baseline_path = baseline_path
113 | self.use_custom_baseline = self.baseline_path is not None
114 | if self.baseline_path is None:
115 | self.baseline_path = os.path.join(
116 | os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv"
117 | )
118 |
119 | @property
120 | def lang(self):
121 | return self._lang
122 |
123 | @property
124 | def idf(self):
125 | return self._idf
126 |
127 | @property
128 | def model_type(self):
129 | return self._model_type
130 |
131 | @property
132 | def num_layers(self):
133 | return self._num_layers
134 |
135 | @property
136 | def rescale_with_baseline(self):
137 | return self._rescale_with_baseline
138 |
139 | @property
140 | def baseline_vals(self):
141 | if self._baseline_vals is None:
142 | if os.path.isfile(self.baseline_path):
143 | if not self.all_layers:
144 | self._baseline_vals = torch.from_numpy(
145 | pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy()
146 | )[1:].float()
147 | else:
148 | self._baseline_vals = (
149 | torch.from_numpy(pd.read_csv(self.baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
150 | )
151 | else:
152 | raise ValueError(f"Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}")
153 |
154 | return self._baseline_vals
155 |
156 | @property
157 | def hash(self):
158 | return get_hash(
159 | self.model_type, self.num_layers, self.idf, self.rescale_with_baseline, self.use_custom_baseline
160 | )
161 |
162 | def compute_idf(self, sents):
163 | """
164 | Args:
165 |             - :param: `sents` (list of str): sentences used to compute the idf weights
166 | """
167 | if self._idf_dict is not None:
168 | warnings.warn("Overwriting the previous importance weights.")
169 |
170 | self._idf_dict = get_idf_dict(sents, self._tokenizer, nthreads=self.nthreads)
171 |
172 | def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
173 | """
174 | Args:
175 | - :param: `cands` (list of str): candidate sentences
176 | - :param: `refs` (list of str or list of list of str): reference sentences
177 |
178 | Return:
179 | - :param: `(P, R, F)`: each is of shape (N); N = number of input
180 |                   candidate-reference pairs. If returning the hash code, the
181 |                   output will be ((P, R, F), hashcode). If a candidate has
182 | multiple references, the returned score of this candidate is
183 | the *best* score among all references.
184 | """
185 |
186 | ref_group_boundaries = None
187 | if not isinstance(refs[0], str):
188 | ref_group_boundaries = []
189 | ori_cands, ori_refs = cands, refs
190 | cands, refs = [], []
191 | count = 0
192 | for cand, ref_group in zip(ori_cands, ori_refs):
193 | cands += [cand] * len(ref_group)
194 | refs += ref_group
195 | ref_group_boundaries.append((count, count + len(ref_group)))
196 | count += len(ref_group)
197 |
198 | if verbose:
199 | print("calculating scores...")
200 | start = time.perf_counter()
201 |
202 | if self.idf:
203 | assert self._idf_dict, "IDF weights are not computed"
204 | idf_dict = self._idf_dict
205 | else:
206 | idf_dict = defaultdict(lambda: 1.0)
207 | idf_dict[self._tokenizer.sep_token_id] = 0
208 | idf_dict[self._tokenizer.cls_token_id] = 0
209 |
210 | all_preds = bert_cos_score_idf(
211 | self._model,
212 | refs,
213 | cands,
214 | self._tokenizer,
215 | idf_dict,
216 | verbose=verbose,
217 | device=self.device,
218 | batch_size=batch_size,
219 | all_layers=self.all_layers,
220 | ).cpu()
221 |
222 | if ref_group_boundaries is not None:
223 | max_preds = []
224 |             for beg, end in ref_group_boundaries:  # avoid clobbering the timing variable `start`
225 |                 max_preds.append(all_preds[beg:end].max(dim=0)[0])
226 | all_preds = torch.stack(max_preds, dim=0)
227 |
228 | if self.rescale_with_baseline:
229 | all_preds = (all_preds - self.baseline_vals) / (1 - self.baseline_vals)
230 |
231 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F
232 |
233 | if verbose:
234 | time_diff = time.perf_counter() - start
235 | print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")
236 |
237 | if return_hash:
238 | out = tuple([out, self.hash])
239 |
240 | return out
241 |
242 | def plot_example(self, candidate, reference, fname=""):
243 | """
244 | Args:
245 | - :param: `candidate` (str): a candidate sentence
246 | - :param: `reference` (str): a reference sentence
247 | - :param: `fname` (str): path to save the output plot
248 | """
249 |
250 | assert isinstance(candidate, str)
251 | assert isinstance(reference, str)
252 |
253 | idf_dict = defaultdict(lambda: 1.0)
254 | idf_dict[self._tokenizer.sep_token_id] = 0
255 | idf_dict[self._tokenizer.cls_token_id] = 0
256 |
257 | hyp_embedding, masks, padded_idf = get_bert_embedding(
258 | [candidate], self._model, self._tokenizer, idf_dict, device=self.device, all_layers=False
259 | )
260 | ref_embedding, masks, padded_idf = get_bert_embedding(
261 | [reference], self._model, self._tokenizer, idf_dict, device=self.device, all_layers=False
262 | )
263 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
264 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
265 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
266 | sim = sim.squeeze(0).cpu()
267 |
268 | r_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, reference)][1:-1]
269 | h_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, candidate)][1:-1]
270 | sim = sim[1:-1, 1:-1]
271 |
272 | if self.rescale_with_baseline:
273 | sim = (sim - self.baseline_vals[2].item()) / (1 - self.baseline_vals[2].item())
274 |
275 | fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens)))
276 | im = ax.imshow(sim, cmap="Blues", vmin=0, vmax=1)
277 |
278 | # We want to show all ticks...
279 | ax.set_xticks(np.arange(len(r_tokens)))
280 | ax.set_yticks(np.arange(len(h_tokens)))
281 | # ... and label them with the respective list entries
282 | ax.set_xticklabels(r_tokens, fontsize=10)
283 | ax.set_yticklabels(h_tokens, fontsize=10)
284 | ax.grid(False)
285 | plt.xlabel("Reference (tokenized)", fontsize=14)
286 | plt.ylabel("Candidate (tokenized)", fontsize=14)
287 | title = "Similarity Matrix"
288 | if self.rescale_with_baseline:
289 | title += " (after Rescaling)"
290 | plt.title(title, fontsize=14)
291 |
292 | divider = make_axes_locatable(ax)
293 | cax = divider.append_axes("right", size="2%", pad=0.2)
294 | fig.colorbar(im, cax=cax)
295 |
296 | # Rotate the tick labels and set their alignment.
297 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
298 |
299 | # Loop over data dimensions and create text annotations.
300 | for i in range(len(h_tokens)):
301 | for j in range(len(r_tokens)):
302 | text = ax.text(
303 | j,
304 | i,
305 | "{:.3f}".format(sim[i, j].item()),
306 | ha="center",
307 | va="center",
308 | color="k" if sim[i, j].item() < 0.5 else "w",
309 | )
310 |
311 | fig.tight_layout()
312 | if fname != "":
313 | plt.savefig(fname, dpi=100)
314 | print("Saved figure to file: ", fname)
315 | plt.show()
316 |
317 | def __repr__(self):
318 | return f"{self.__class__.__name__}(hash={self.hash}, batch_size={self.batch_size}, nthreads={self.nthreads})"
319 |
320 | def __str__(self):
321 | return self.__repr__()
322 |
--------------------------------------------------------------------------------
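A minimal usage sketch for the BERTScorer class above, under the same assumptions as the functional interface (toy sentences, downloadable model weights).

    from bert_score.scorer import BERTScorer

    scorer = BERTScorer(lang="en", rescale_with_baseline=True)
    P, R, F = scorer.score(["the cat sat on the mat"], ["a cat was sitting on the mat"])
    print(scorer.hash, F.mean().item())
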
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import json
4 | import torch
5 | import torch.nn as nn
6 |
7 | from config import get_trainer_config, InputConfig
8 | from model.dataset import FacebookDataset
9 | from model.trainer import Trainer
10 | from model.gpt2_model import GPT2EncoderDecoderModel, GPT2DoubleHeadsModel
11 | from model.utils import f1_score, open, set_seed, config_logger
12 | from transformers.tokenization_gpt2 import GPT2Tokenizer
13 | from metrics import nlp_metrics
14 | from model.seq2seq import TransformerSeq2Seq
15 | from model.seq2seq_vocab import Seq2seqVocab
16 | from model.entailment_score import EntailmentScorer
17 | from bert_score.score import get_bert_score
18 |
19 | PADDING_IDX = 0
20 |
21 | def modify_tokenizer(tokenizer, data_type):
22 | additional_special_tokens = ['', '', '', '', '',
23 | '']
24 | if data_type == 'emoji':
25 | with open('datasets/emoji_talk/emojis.json', 'r') as f:
26 | emojis = json.load(f)['emojis']
27 | additional_special_tokens.extend(emojis)
28 | tokenizer.add_special_tokens({'pad_token': '', 'bos_token': '', 'eos_token': '',
29 | 'additional_special_tokens': additional_special_tokens})
30 | tokenizer.eos_id, tokenizer.bos_id, tokenizer.pad_id = tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.pad_token_id
31 | tokenizer.sent_dialog_id = tokenizer.bos_token_id
32 | tokenizer.info_dialog_id, tokenizer.info_bos_id = tokenizer.added_tokens_encoder[''], \
33 | tokenizer.added_tokens_encoder[
34 | '']
35 | tokenizer.info_eos_id = tokenizer.added_tokens_encoder['']
36 | tokenizer.talker1_dialog_id, tokenizer.talker1_bos_id = tokenizer.added_tokens_encoder[''], \
37 | tokenizer.added_tokens_encoder['']
38 | tokenizer.talker1_eos_id = tokenizer.added_tokens_encoder['']
39 | tokenizer.talker2_dialog_id, tokenizer.talker2_bos_id = tokenizer.added_tokens_encoder[''], \
40 | tokenizer.added_tokens_encoder['']
41 | tokenizer.talker2_eos_id = tokenizer.added_tokens_encoder['']
42 | return tokenizer, len(additional_special_tokens) + 3
43 |
44 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False):
45 | # assuming trailing dimensions and type of all the Tensors
46 | # in sequences are same and fetching those from sequences[0]
47 | if not len(sequences):
48 | return torch.empty(0)
49 | trailing_dims = sequences[0].size()[1:]
50 | max_len = max([s.size(0) for s in sequences])
51 | if batch_first:
52 | out_dims = (len(sequences), max_len) + trailing_dims
53 | else:
54 | out_dims = (max_len, len(sequences)) + trailing_dims
55 |
56 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
57 | for i, tensor in enumerate(sequences):
58 | length = tensor.size(0)
59 | s_slice = slice(-length, None) if left else slice(None, length)
60 | s_slice = (i, s_slice) if batch_first else (s_slice, i)
61 | out_tensor[s_slice] = tensor
62 |
63 | return out_tensor
64 |
65 | def main():
66 | args = InputConfig().args
67 |
68 | trainer_config = get_trainer_config(args)
69 |
70 | set_seed(trainer_config.seed)
71 | device = torch.device(trainer_config.device)
72 | save_path = trainer_config.load_last[:trainer_config.load_last.rfind('/')]
73 | generate_file_name = args.generate_file_name
74 |
75 | logger = config_logger(os.path.join(save_path, 'inference.log'))
76 |
77 | parsed_valid_data, parsed_test_data = None, None
78 | if args.model_type == 'gpt2':
79 | if args.single_input:
80 | model = GPT2DoubleHeadsModel.from_pretrained('./gpt2-small')
81 | else:
82 | model = GPT2EncoderDecoderModel.from_pretrained('./gpt2-small')
83 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-small')
84 | elif args.model_type == 'seq2seq' or args.model_type == 'rnn-seq2seq':
85 | seq2seq_vocab = Seq2seqVocab(trainer_config.train_datasets, trainer_config.valid_datasets,
86 | trainer_config.test_datasets, args.vocab_path, data_type=args.data_type)
87 | tokenizer = seq2seq_vocab.vocab
88 | if args.model_type == 'seq2seq':
89 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size,
90 | args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger,
91 | multi_input=not args.single_input,
92 | attention_pooling_type=args.attention_pooling_type)
93 | else:
94 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size,
95 | args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger,
96 | base_model='gru')
97 | args.dialog_embeddings = False
98 |
99 | model.shared_attention = (args.shared_attention == 1)
100 | model.shared_module = (args.shared_module == 1)
101 | model.attention_pooling_type = args.attention_pooling_type
102 | if args.model_type in ['gpt', 'dialogpt', 'gpt2']:
103 | tokenizer, additional_length = modify_tokenizer(tokenizer, args.data_type)
104 | model.embeddings_size = 768
105 | model.n_embeddings = len(tokenizer)
106 | model.shared_attention = (args.shared_attention == 1)
107 | model.shared_module = (args.shared_module == 1)
108 | model.attention_pooling_type = args.attention_pooling_type
109 | model.single_input = args.single_input
110 | if args.model_type == 'gpt':
111 | model_embedding_weight = model.transformer.tokens_embed.weight
112 | model.transformer.tokens_embed = nn.Embedding(model.n_embeddings, 768)
113 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False)
114 | model.transformer.tokens_embed.weight.data[:-additional_length, :] = model_embedding_weight.data
115 | model.transformer.tokens_embed.weight.data[-additional_length:, :] = 0
116 | model.lm_head.weight = model.transformer.tokens_embed.weight
117 | else:
118 | model_embedding_weight = model.transformer.wte.weight
119 | model.transformer.wte = nn.Embedding(model.n_embeddings, 768)
120 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False)
121 | model.transformer.wte.weight.data[:-additional_length, :] = model_embedding_weight.data
122 | model.transformer.wte.weight.data[-additional_length:, :] = 0
123 | model.lm_head.weight = model.transformer.wte.weight
124 |
125 | if not args.single_input:
126 | model.reload_module_dict()
127 | model.sent_dialog_id = tokenizer.sent_dialog_id
128 |
129 | model.padding_idx = tokenizer.pad_id
130 | model.n_pos_embeddings = 512
131 |
132 | model.talker1_id = tokenizer.talker1_bos_id
133 | model.talker2_id = tokenizer.talker2_bos_id
134 | model.bos_id = tokenizer.bos_id
135 | model.eos_id = tokenizer.eos_id
136 | model.beam_size = args.beam_size
137 | model.diversity_groups = 1
138 | model.max_seq_len = 32
139 | model.dialog_embeddings = args.dialog_embeddings
140 | model.bs_temperature = args.bs_temperature
141 | model.bs_nucleus_p = args.bs_nucleus_p
142 | model.annealing_topk = args.annealing_topk
143 | model.length_penalty_coef = args.length_penalty
144 | model.vocab = None
145 | model.annealing = args.annealing
146 | model.diversity_coef = args.diversity_coef
147 | model.sample = False
148 | model.inference_mode = args.inference_mode
149 | model.response_k = args.response_k
150 |
151 | logger.info('loading datasets')
152 | valid_dataset = None
153 | test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer,
154 | max_lengths=(model.n_pos_embeddings - 1) // (3 if args.single_input else 1), # A bit restrictive here
155 | dialog_embeddings=args.dialog_embeddings,
156 | cache=trainer_config.test_datasets_cache,
157 | use_start_end=args.use_start_end,
158 | negative_samples=0, # Keep all negative samples
159 | augment=False,
160 | aug_syn_proba=0.0,
161 | limit_size=trainer_config.limit_eval_size,
162 | single_input=args.single_input,
163 | data_type=args.data_type,
164 | parsed_data=parsed_test_data)
165 | logger.info(f'test dataset {(len(test_dataset))}')
166 |
167 | state_dict = torch.load(trainer_config.load_last, map_location=device)
168 | if state_dict.__contains__('model'):
169 | model.load_state_dict(state_dict['model'], strict=False)
170 | else:
171 | model.load_state_dict(state_dict)
172 | model.to(device)
173 | logger.info('Weights loaded from {}'.format(trainer_config.load_last))
174 |
175 | trainer = Trainer(model,
176 | test_dataset,
177 | trainer_config,
178 | None,
179 | logger=logger,
180 | test_dataset=test_dataset,
181 | valid_dataset=valid_dataset,
182 | n_jobs=trainer_config.n_jobs,
183 | device=device,
184 | ignore_idxs=tokenizer.all_special_ids,
185 | local_rank=args.local_rank,
186 | apex_level=None,
187 | apex_loss_scale=trainer_config.apex_loss_scale,
188 | evaluate_full_sequences=trainer_config.evaluate_full_sequences,
189 | full_input=trainer_config.full_input,
190 | uncertainty_loss=args.uncertainty_loss)
191 |
192 | def external_metrics_func(full_references, full_predictions, epoch, metric=None, generate_entail=False):
193 | if epoch == -1:
194 | references_file_path = os.path.join(save_path, 'test_references_file.txt')
195 | predictions_file_path = os.path.join(save_path, 'test_predictions_file.txt')
196 | else:
197 | references_file_path = os.path.join(save_path, 'eval_references_file.txt')
198 | predictions_file_path = os.path.join(save_path, 'eval_predictions_file.txt')
199 | with open(references_file_path, 'w', encoding='utf-8') as f:
200 | f.write('\n'.join(full_references))
201 | with open(predictions_file_path, 'w', encoding='utf-8') as f:
202 | f.write('\n'.join(full_predictions))
203 |
204 | bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, entropy, meteor, \
205 | rouge_l, f1_score, avg_length = nlp_metrics(references_file_path, predictions_file_path)
206 |
207 | metrics = {'meteor': meteor * 100, 'avg_len': avg_length, 'rouge-l': rouge_l * 100, 'bleu': bleu, 'nist': nist,
208 | 'nist-bleu': nist_bleu, 'f1': f1_score * 100}
209 |         for name, metric_values in (
210 |                 ('bleu', bleu_list), ('nist', nist_list), ('nist_bleu', nist_bleu_list), ('entropy', entropy),
211 |                 ('sentence_div', s_dist), ('corpus_div', c_dist)):
212 |             for i, m in enumerate(metric_values, 1):
213 | if name == 'sentence_div' or name == 'corpus_div':
214 | metrics['{}_{}'.format(name, i)] = m * 100
215 | else:
216 | metrics['{}_{}'.format(name, i)] = m
217 | if args.entail_score_refs_file and epoch == -1:
218 | entailment_scorer = EntailmentScorer(predictions_file_path, args.entail_score_refs_file,
219 | args.entail_model_path, device)
220 | metrics['entail_score'] = entailment_scorer.calculate_entailment_score()
221 | if args.bert_score_model_path is not None and epoch == -1:
222 | all_preds = get_bert_score(
223 | full_predictions,
224 | full_references,
225 | model_type=args.bert_score_model_path,
226 | num_layers=16,
227 | batch_size=16,
228 | )
229 | metrics['bert_score_p'] = torch.mean(all_preds[0]).item()
230 | metrics['bert_score_r'] = torch.mean(all_preds[1]).item()
231 | metrics['bert_score_f'] = torch.mean(all_preds[2]).item()
232 | for k, v in metrics.items():
233 | metrics[k] = round(v, 6)
234 | return metrics
235 |
236 | def external_metrics_func_entail_data(full_predictions, raw_entail_data):
237 | for i, prediction in enumerate(full_predictions):
238 | raw_entail_data[i][2] = prediction
239 | with open(generate_file_name, 'w') as f:
240 | json.dump(raw_entail_data, f)
241 |
242 | metric_funcs = {'f1_score': f1_score}
243 | # trainer.test(metric_funcs, external_metrics_func, epoch=0, inference=True)
244 | if args.data_type == 'entailment':
245 | with open(args.test_datasets, 'r') as f:
246 | raw_entail_data = json.load(f)
247 | trainer.test(metric_funcs, external_metrics_func_entail_data, epoch=-1, inference=True,
248 | raw_entail_data=raw_entail_data)
249 | else:
250 | trainer.test(metric_funcs, external_metrics_func, epoch=-1, inference=True)
251 |
252 | if __name__ == '__main__':
253 | main()
254 |
--------------------------------------------------------------------------------
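A small sketch of the left/right padding behaviour of the pad_sequence helper defined in inference.py above; the tensors are toy values.

    import torch

    seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    right = pad_sequence(seqs, batch_first=True, padding_value=0)            # [[1, 2, 3], [4, 5, 0]]
    left = pad_sequence(seqs, batch_first=True, padding_value=0, left=True)  # [[1, 2, 3], [0, 4, 5]]
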
/attention_experiment/attention_exp.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 |
4 | import json
5 | import torch
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import torch.nn as nn
9 |
10 | from config import get_trainer_config, InputConfig
11 | from model.dataset import FacebookDataset
12 | from model.trainer import Trainer
13 | from model.gpt2_model import GPT2DoubleHeadsModel
14 | from transformers.tokenization_gpt2 import GPT2Tokenizer
15 | from model.utils import open, set_seed, config_logger
16 | from model.seq2seq import TransformerSeq2Seq
17 | from model.seq2seq_vocab import Seq2seqVocab
18 |
19 | PADDING_IDX = 0
20 |
21 | def modify_tokenizer(tokenizer, data_type):
22 | additional_special_tokens = ['', '', '', '', '',
23 | '']
24 | if data_type == 'emoji':
25 | with open('datasets/emoji_talk/emojis.json', 'r') as f:
26 | emojis = json.load(f)['emojis']
27 | additional_special_tokens.extend(emojis)
28 | tokenizer.add_special_tokens({'pad_token': '', 'bos_token': '', 'eos_token': '',
29 | 'additional_special_tokens': additional_special_tokens})
30 | tokenizer.eos_id, tokenizer.bos_id, tokenizer.pad_id = tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.pad_token_id
31 | tokenizer.sent_dialog_id = tokenizer.bos_token_id
32 | tokenizer.info_dialog_id, tokenizer.info_bos_id = tokenizer.added_tokens_encoder[''], \
33 | tokenizer.added_tokens_encoder[
34 | '']
35 | tokenizer.info_eos_id = tokenizer.added_tokens_encoder['']
36 | tokenizer.talker1_dialog_id, tokenizer.talker1_bos_id = tokenizer.added_tokens_encoder[''], \
37 | tokenizer.added_tokens_encoder['']
38 | tokenizer.talker1_eos_id = tokenizer.added_tokens_encoder['']
39 | tokenizer.talker2_dialog_id, tokenizer.talker2_bos_id = tokenizer.added_tokens_encoder[''], \
40 | tokenizer.added_tokens_encoder['']
41 | tokenizer.talker2_eos_id = tokenizer.added_tokens_encoder['']
42 | return tokenizer, len(additional_special_tokens) + 3
43 |
44 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False):
45 | # assuming trailing dimensions and type of all the Tensors
46 | # in sequences are same and fetching those from sequences[0]
47 | if not len(sequences):
48 | return torch.empty(0)
49 | trailing_dims = sequences[0].size()[1:]
50 | max_len = max([s.size(0) for s in sequences])
51 | if batch_first:
52 | out_dims = (len(sequences), max_len) + trailing_dims
53 | else:
54 | out_dims = (max_len, len(sequences)) + trailing_dims
55 |
56 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
57 | for i, tensor in enumerate(sequences):
58 | length = tensor.size(0)
59 | s_slice = slice(-length, None) if left else slice(None, length)
60 | s_slice = (i, s_slice) if batch_first else (s_slice, i)
61 | out_tensor[s_slice] = tensor
62 |
63 | return out_tensor
64 |
65 | def collate_func(data):
66 | persona_info, h, y, distractors_batch = zip(*data)
67 |
68 | contexts = []
69 |
70 | if max(map(len, persona_info)) > 0:
71 | persona_info = [torch.tensor(d, dtype=torch.long) for d in persona_info]
72 | contexts.append(persona_info)
73 |
74 | if max(map(len, h)) > 0:
75 | h = [torch.tensor(d, dtype=torch.long) for d in h]
76 | contexts.append(h)
77 |
78 | y_out = [torch.tensor(d, dtype=torch.long) for d in y]
79 |
80 | distractors = [torch.tensor(d, dtype=torch.long) for distractors in distractors_batch for d in distractors]
81 |
82 | # Pad now so we pad correctly when we have only a single input (context concatenated with y)
83 | y_out = pad_sequence(y_out, batch_first=True, padding_value=PADDING_IDX)
84 | distractors = pad_sequence(distractors, batch_first=True, padding_value=PADDING_IDX)
85 | contexts = [pad_sequence(c, batch_first=True, padding_value=PADDING_IDX) for c in contexts]
86 |
87 | return contexts, y_out, distractors
88 |
89 | def _s2s_loss(targets, enc_contexts, model):
90 | hidden_state, padding_mask = None, None
91 |
92 | nexts = targets[:, 1:].contiguous() if targets.dim() == 2 else targets[:, 1:, 0].contiguous()
93 | outputs = model.decode(targets[:, :-1].contiguous(), enc_contexts)
94 |
95 | outputs = outputs.view(-1, outputs.shape[-1]).float()
96 | nexts = nexts.view(-1)
97 |
98 | lm_criterion = torch.nn.CrossEntropyLoss(ignore_index=PADDING_IDX)
99 | loss = lm_criterion(outputs, nexts)
100 | return loss, hidden_state, padding_mask
101 |
102 | def _lm_loss(contexts, enc_contexts, model, ignore_idxs, device):
103 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=device)
104 |
105 | for context in contexts:
106 | enc_context = model.encode(context.clone())
107 | enc_contexts.append(enc_context)
108 |
109 | context_outputs = model.generate(enc_context[0])
110 | ignore_mask = torch.stack([context == idx for idx in ignore_idxs], dim=-1).any(dim=-1)
111 | context.masked_fill_(ignore_mask, PADDING_IDX)
112 | prevs = context_outputs[:, :-1, :].contiguous()
113 | nexts = context[:, 1:].contiguous() if context.dim() == 2 else context[:, 1:, 0].contiguous()
114 | lm_criterion = torch.nn.CrossEntropyLoss(ignore_index=PADDING_IDX)
115 | batch_lm_loss += lm_criterion(prevs.view(-1, prevs.shape[-1]).float(), nexts.view(-1)) / len(contexts)
116 | return batch_lm_loss
117 |
118 |
119 |
120 | def main():
121 | args = InputConfig().args
122 |
123 | trainer_config = get_trainer_config(args)
124 |
125 | set_seed(trainer_config.seed)
126 | device = torch.device(trainer_config.device)
127 | save_path = trainer_config.load_last[:trainer_config.load_last.rfind('/')]
128 | generate_file_name = args.generate_file_name
129 |
130 | logger = config_logger(os.path.join(save_path, 'inference.log'))
131 |
132 | parsed_valid_data, parsed_test_data = None, None
133 | if args.model_type == 'seq2seq':
134 | seq2seq_vocab = Seq2seqVocab(trainer_config.train_datasets, trainer_config.valid_datasets,
135 | trainer_config.test_datasets, args.vocab_path, data_type=args.data_type)
136 | tokenizer = seq2seq_vocab.vocab
137 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size,
138 | args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger,
139 | multi_input=not args.single_input,
140 | attention_pooling_type=args.attention_pooling_type)
141 | args.dialog_embeddings = False
142 | else:
143 | model = GPT2DoubleHeadsModel.from_pretrained('./gpt2-small')
144 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-small')
145 | tokenizer, additional_length = modify_tokenizer(tokenizer, args.data_type)
146 | model.embeddings_size = 768
147 | model.n_embeddings = len(tokenizer)
148 | model.shared_attention = (args.shared_attention == 1)
149 | model.shared_module = (args.shared_module == 1)
150 | model.attention_pooling_type = args.attention_pooling_type
151 | model.single_input = args.single_input
152 | model_embedding_weight = model.transformer.wte.weight
153 | model.transformer.wte = nn.Embedding(model.n_embeddings, 768)
154 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False)
155 | model.transformer.wte.weight.data[:-additional_length, :] = model_embedding_weight.data
156 | model.transformer.wte.weight.data[-additional_length:, :] = 0
157 | model.lm_head.weight = model.transformer.wte.weight
158 |
159 | model.padding_idx = tokenizer.pad_id
160 | model.n_pos_embeddings = 512
161 |
162 | model.talker1_id = tokenizer.talker1_bos_id
163 | model.talker2_id = tokenizer.talker2_bos_id
164 | model.bos_id = tokenizer.bos_id
165 | model.eos_id = tokenizer.eos_id
166 | model.beam_size = args.beam_size
167 | model.diversity_groups = 1
168 | model.max_seq_len = 32
169 | model.dialog_embeddings = args.dialog_embeddings
170 | model.bs_temperature = args.bs_temperature
171 | model.bs_nucleus_p = args.bs_nucleus_p
172 | model.annealing_topk = args.annealing_topk
173 | model.length_penalty_coef = args.length_penalty
174 | model.vocab = None
175 | model.annealing = args.annealing
176 | model.diversity_coef = args.diversity_coef
177 | model.sample = False
178 | model.inference_mode = args.inference_mode
179 | model.response_k = args.response_k
180 |
181 | logger.info('loading datasets')
182 | valid_dataset = None
183 | test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer,
184 | max_lengths=(model.n_pos_embeddings - 1) // (3 if args.single_input else 1), # A bit restrictive here
185 | dialog_embeddings=args.dialog_embeddings,
186 | cache=trainer_config.test_datasets_cache,
187 | use_start_end=args.use_start_end,
188 | negative_samples=0, # Keep all negative samples
189 | augment=False,
190 | aug_syn_proba=0.0,
191 | limit_size=trainer_config.limit_eval_size,
192 | max_history_size=args.max_history_size,
193 | single_input=args.single_input,
194 | data_type=args.data_type,
195 | parsed_data=parsed_test_data)
196 | # logger.info(f'valid dataset {len(valid_dataset)} test dataset {(len(test_dataset))}')
197 | logger.info(f'test dataset {(len(test_dataset))}')
198 |
199 | state_dict = torch.load(trainer_config.load_last, map_location=device)
200 | if state_dict.__contains__('model'):
201 | model.load_state_dict(state_dict['model'], strict=False)
202 | else:
203 | model.load_state_dict(state_dict)
204 | model.to(device)
205 | logger.info('Weights loaded from {}'.format(trainer_config.load_last))
206 |
207 | trainer = Trainer(model,
208 | test_dataset,
209 | trainer_config,
210 | None,
211 | logger=logger,
212 | test_dataset=test_dataset,
213 | valid_dataset=valid_dataset,
214 | n_jobs=trainer_config.n_jobs,
215 | device=device,
216 | ignore_idxs=tokenizer.all_special_ids,
217 | local_rank=args.local_rank,
218 | apex_level=None,
219 | apex_loss_scale=trainer_config.apex_loss_scale,
220 | full_input=trainer_config.full_input,
221 | uncertainty_loss=args.uncertainty_loss)
222 |
223 | _, all_attention, all_attention_inference = trainer.test_attention()
224 | torch.save(all_attention, 'augmentation/gpt2_th0.99_mix_all_attention.bin')
225 |
226 | def analysis_attention():
227 | all_attention = torch.load('augmentation/gpt2_th0.99_raw_all_attention.bin')
228 | with open('augmentation/th0.99_gpt2_positions.json', 'r') as f:
229 | consistent_pos = json.load(f)
230 | token_pos, target_persona_pos, whole_persona_pos = consistent_pos['token_positions'], \
231 | consistent_pos['target_persona_positions'], \
232 | consistent_pos['whole_persona_positions']
233 | token_level_attentions, sentence_level_attentions = [], []
234 | persona_sentence_ratio, avg_token_number = [], []
235 | for i in tqdm(range(len(target_persona_pos))):
236 | cur_attention = (all_attention[i] / torch.sum(all_attention[i], dim=-1, keepdim=True)).numpy()
237 | persona_sentence_attention = cur_attention[:, :, target_persona_pos[i][0]: target_persona_pos[i][1]]
238 | sentence_level_attentions.append(np.mean(np.sum(persona_sentence_attention, axis=-1), axis=-1))
239 | cur_token_attention = []
240 | for p in token_pos[i]:
241 | cur_token_attention.append(cur_attention[:, p[1], p[0]])
242 | token_level_attentions.append(cur_token_attention)
243 | persona_sentence_ratio.append((target_persona_pos[i][1] - target_persona_pos[i][0]) / cur_attention.shape[2])
244 | avg_token_number.append(cur_attention.shape[2])
245 | token_by_layer, sentence_by_layer = [[] for _ in range(12)], [[] for _ in range(12)]
246 | for i in range(len(all_attention)):
247 | for j in range(12):
248 | sentence_by_layer[j].append(sentence_level_attentions[i][j])
249 | if len(token_level_attentions[i]) > 0:
250 | token_by_layer[j].append(np.mean([a[j] for a in token_level_attentions[i]]))
251 | for i in range(12):
252 | print(str(i) + ' layer token-level: ' + str(np.mean(token_by_layer[i])))
253 | print(str(i) + ' layer sentence-level: ' + str(np.mean(sentence_by_layer[i])))
254 | print('mean token level value: ' + str(1/ np.mean(avg_token_number)))
255 | print('mean sentence level value: ' + str(np.mean(persona_sentence_ratio)))
256 | # cur_whole_attention, cur_persona_attention = [], []
257 | # for p in token_pos[i]:
258 |     #         cur_whole_attention.append(F.softmax(all_attention[i][:, p[1], :], dim=-1).numpy())
259 | # cur_persona_attention.append(F.softmax(all_attention[i][:, p[1], :whole_persona_pos[i][1]], dim=-1).numpy())
260 | # whole_matched_response_attentions.append(cur_whole_attention)
261 | # within_persona_matched_response_attentions.append(cur_persona_attention)
262 | # response_persona_attention = F.softmax(all_attention[i], dim=-1).numpy()[:, :, target_persona_pos[i][0]: target_persona_pos[i][1]]
263 | # sentence_level_attention.append(np.mean(np.sum(response_persona_attention, axis=-1), axis=-1))
264 | # all_probs, all_probs_within_persona, all_prob_target_persona = [[] for _ in range(6)], [[] for _ in range(6)], \
265 | # [[] for _ in range(6)]
266 | # all_acc, all_acc_within_persona = [0] * 6, [0] * 6
267 | # lengths, persona_lengths, target_persona_lengths = [], [], []
268 | # for i in tqdm(range(len(target_persona_pos))):
269 | # for j, pos in enumerate(token_pos[i]):
270 | # for m in range(6):
271 | # all_probs[m].append(whole_matched_response_attentions[i][j][m][pos[0]])
272 | # index = np.argsort(-whole_matched_response_attentions[i][j][m])
273 | # if index[0] == pos[0]:
274 | # all_acc[m] += 1
275 | # all_probs_within_persona[m].append(within_persona_matched_response_attentions[i][j][m][pos[0]])
276 | # index = np.argsort(-within_persona_matched_response_attentions[i][j][m])
277 | # if index[0] == pos[0]:
278 | # all_acc_within_persona[m] += 1
279 | # lengths.append(whole_matched_response_attentions[i][j][0].shape[0])
280 | # persona_lengths.append(within_persona_matched_response_attentions[i][j][0].shape[0])
281 | # target_persona_lengths.append(target_persona_pos[i][1] - target_persona_pos[i][0])
282 | # for m in range(6):
283 | # all_prob_target_persona[m].append(sentence_level_attention[i][m])
284 | # for i in range(6):
285 | # print(str(i) + ' layer prob values: ' + str(np.mean(all_probs[i])))
286 | # print(str(i) + ' layer acc: ' + str(all_acc[i] / len(all_probs[i])))
287 | # print(str(i) + ' layer prob values within persona: ' + str(np.mean(all_probs_within_persona[i])))
288 | # print(str(i) + ' layer acc within persona: ' + str(all_acc_within_persona[i] / len(all_probs[i])))
289 | # print(str(i) + ' layer sentence-level prob: ' + str(np.mean(all_prob_target_persona[i])))
290 | # print('mean prob value: ' + str(1 / np.mean(lengths)))
291 | # print('mean prob value within persona: ' + str(1 / np.mean(persona_lengths)))
292 | # print('mean prob value sentence-level: ' + str(1 / np.mean(target_persona_lengths)))
293 |
294 | def analysis_attention_history():
295 | all_attention = torch.load('augmentation/gpt2_th0.99_raw_attention.bin')
296 | with open('augmentation/th0.99_gpt2_positions.json', 'r') as f:
297 | positions = json.load(f)
298 | avgs = []
299 | lengths = []
300 | attentions = [[] for _ in range(6)]
301 | for i in range(len(positions)):
302 | cur_attention = F.softmax(all_attention[i], dim=-1)
303 | cur_positions = positions[i]
304 | # if len(cur_positions[1]) >= 3:
305 | # start = cur_positions[0][-1]
306 | # if len(cur_positions[1]) > 3:
307 | # start = cur_positions[1][-4]
308 | # else:
309 | # start = cur_positions[0][-1]
310 | # else:
311 | # start = cur_positions[0][-1]
312 | start = cur_positions[1][-2] if len(cur_positions[1]) > 1 else cur_positions[0][-1]
313 | end = cur_positions[1][-1]
314 | target_attention = cur_attention[:, :, start: end]
315 | avg_attention = 1 / cur_attention.size()[2] * (end - start)
316 | for j in range(6):
317 | attentions[j].append(torch.mean(torch.sum(target_attention[j], dim=-1)).item())
318 | avgs.append(avg_attention)
319 | lengths.append(cur_attention.size()[2])
320 | print(np.mean(attentions[0]))
321 | print(np.mean(attentions[1]))
322 | print(np.mean(attentions[2]))
323 | print(np.mean(attentions[3]))
324 | print(np.mean(attentions[4]))
325 | print(np.mean(attentions[5]))
326 | print(np.mean(avgs))
327 | print(1 / np.mean(lengths))
328 |
329 |
330 | if __name__ == '__main__':
331 | # main()
332 | analysis_attention()
333 | # analysis_attention_history()
334 |
--------------------------------------------------------------------------------
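The per-layer numbers printed by analysis_attention() are meant to be read against a uniform-attention baseline: a model that spread its attention evenly over the context would place mass equal to (persona span length / context length) on the persona sentence, which is what the final two print statements report. Below is a minimal, self-contained sketch of that comparison on a toy tensor; the tensor shape (layers x response tokens x context tokens) and the span boundaries are assumptions for illustration, not code taken from the repo.

# Minimal sketch (not part of the repo): comparing per-layer persona attention mass
# against a uniform-attention baseline, as analysis_attention() does.
import torch
import torch.nn.functional as F

n_layers, n_resp, n_ctx = 12, 5, 40          # assumed shapes for illustration
persona_start, persona_end = 0, 12            # hypothetical persona span in the context

attention = torch.randn(n_layers, n_resp, n_ctx)   # raw attention scores
probs = F.softmax(attention, dim=-1)                # normalize over context tokens

# attention mass each layer puts on the persona span, averaged over response tokens
persona_mass = probs[:, :, persona_start:persona_end].sum(dim=-1).mean(dim=-1)

# uniform baseline: attending evenly would put this much mass on the span
uniform_baseline = (persona_end - persona_start) / n_ctx

for layer, mass in enumerate(persona_mass.tolist()):
    print(f"layer {layer}: persona attention mass = {mass:.3f} (uniform = {uniform_baseline:.3f})")
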
/config.py:
--------------------------------------------------------------------------------
1 | from attrdict import AttrDict
2 | from copy import deepcopy
3 | import torch
4 | from model.utils import openai_transformer_config
5 | import git
6 | import argparse
7 |
8 |
9 | repo = git.Repo(search_parent_directories=True)
10 |
11 |
12 | def cast2(type_):
13 | return lambda val: val if val is None else type_(val)
14 |
15 |
16 | def get_model_config(args):
17 | default_config = openai_transformer_config()
18 | config = AttrDict({'bpe_vocab_path': './parameters/bpe.vocab',
19 | 'bpe_codes_path': './parameters/bpe.code',
20 |                        'checkpoint_path': './checkpoints/last_checkpoint', # the folder that keeps the checkpoints of the agents
21 | 'n_layers': default_config.n_layers,
22 | 'n_pos_embeddings': 512,
23 | 'embeddings_size': default_config.embeddings_size,
24 | 'n_heads': default_config.n_heads,
25 | 'dropout': default_config.dropout,
26 | 'embed_dropout': default_config.embed_dropout,
27 | 'attn_dropout': default_config.attn_dropout,
28 | 'ff_dropout': default_config.ff_dropout,
29 | 'normalize_embeddings': args.normalize_embeddings,
30 | 'max_seq_len': 128,
31 | 'beam_size': args.beam_size,
32 | 'diversity_coef': args.diversity_coef,
33 | 'diversity_groups': args.diversity_groups,
34 | 'annealing_topk': args.annealing_topk,
35 | 'annealing': args.annealing,
36 | 'length_penalty': args.length_penalty,
37 | 'n_segments': None,
38 | 'constant_embedding': args.constant_embedding,
39 | 'multiple_choice_head': args.multiple_choice_head,
40 | 'share_models': True,
41 | 'successive_attention': args.successive_attention,
42 | 'sparse_embeddings': args.sparse_embeddings,
43 | 'shared_attention': args.shared_attention,
44 | 'dialog_embeddings': args.dialog_embeddings,
45 | 'single_input': args.single_input,
46 | 'use_start_end': args.use_start_end,
47 | 'apex_level': args.apex_level, # 'O0', 'O1', 'O2', 'O3',
48 | 'bs_temperature': args.bs_temperature,
49 | 'bs_nucleus_p': args.bs_nucleus_p,
50 | 'same_embedding_lm': args.same_embedding_lm,
51 | })
52 |
53 | return config
54 |
55 |
56 | def get_trainer_config(args, curriculum_config=False):
57 | config = AttrDict({'n_epochs': args.n_epochs,
58 | 'writer_comment': args.writer_comment,
59 | 'train_batch_size': args.train_batch_size,
60 | 'meta_batch_size': args.meta_batch_size,
61 | 'batch_split': args.batch_split,
62 | 'test_batch_size': args.test_batch_size,
63 | 'lr': args.lr,
64 | 'lr_warmup': args.lr_warmup, # a fraction of total training (epoch * train_set_length) if linear_schedule == True
65 | 'weight_decay': 0.01,
66 | 's2s_weight': args.s2s_weight,
67 | 'lm_weight': args.lm_weight,
68 | 'risk_weight': args.risk_weight,
69 | 'hits_weight': args.hits_weight,
70 | 'negative_samples': args.negative_samples,
71 | 'n_jobs': 4,
72 | 'label_smoothing': args.label_smoothing,
73 | 'clip_grad': args.clip_grad,
74 | 'alpha_clip_grad': args.alpha_clip_grad,
75 | 'test_period': 1,
76 | 'seed': args.seed,
77 | 'device': 'cuda',
78 | 'zero_shot': args.zero_shot,
79 | 'persona_augment': args.persona_augment,
80 | 'persona_aug_syn_proba': args.persona_aug_syn_proba,
81 | 'apex_loss_scale': args.apex_loss_scale, # e.g. '128', 'dynamic'
82 | 'linear_schedule': args.linear_schedule,
83 | 'evaluate_full_sequences': args.evaluate_full_sequences,
84 | 'limit_eval_size': args.limit_eval_size,
85 | 'limit_train_size': args.limit_train_size,
86 | 'risk_metric': args.risk_metric,
87 |                        'load_last': args.load_last, # e.g. './checkpoints/last_checkpoint'; now that several experiments are saved, put here the path of the checkpoint file you want to load
88 | 'load_alpha_last': args.load_alpha_last,
89 | 'repo_id': str(repo),
90 | 'repo_sha': str(repo.head.object.hexsha),
91 | 'repo_branch': str(repo.active_branch),
92 | 'openai_parameters_dir': './parameters',
93 |                        'last_checkpoint_path': 'last_checkpoint', # these now live in the ./runs/XXX/ experiment folders
94 | 'eval_references_file': 'eval_references_file',
95 | 'eval_predictions_file': 'eval_predictions_file',
96 | 'test_references_file': 'test_references_file',
97 | 'test_predictions_file_best': 'test_predictions_file_best',
98 | 'test_predictions_file_last': 'test_predictions_file_last',
99 |                        'interrupt_checkpoint_path': 'interrupt_checkpoint', # these now live in the ./runs/XXX/ experiment folders
100 | 'train_datasets': args.train_datasets,
101 | 'train_datasets_cache': args.train_datasets_cache,
102 | 'test_datasets': args.test_datasets,
103 | 'test_datasets_cache': args.test_datasets_cache,
104 | 'valid_datasets': args.valid_datasets,
105 | 'valid_datasets_cache': args.valid_datasets_cache,
106 | 'full_input': args.full_input,
107 | 'single_input': args.single_input,
108 | 'max_history_size': args.max_history_size,
109 | 'model_saving_interval': args.model_saving_interval,
110 | 'patience': args.patience,
111 | 'mixup_cache': args.mixup_cache,
112 | 'data_type': args.data_type,
113 | })
114 | if curriculum_config:
115 | config.train_datasets = args.curriculum_train_datasets
116 | config.train_datasets_cache = args.curriculum_train_datasets_cache
117 | config.valid_datasets = args.curriculum_valid_datasets
118 | config.valid_datasets_cache = args.curriculum_valid_datasets_cache
119 | config.test_datasets = args.curriculum_valid_datasets
120 | config.test_datasets_cache = args.curriculum_valid_datasets_cache
121 | config.lr = args.curriculum_lr
122 | config.n_epochs = args.curriculum_n_epochs
123 | config.patience = args.patience
124 | config.max_history_size = args.curriculum_max_history_size
125 | config.eval_references_file = 'eval_references_file_stage1'
126 | config.eval_predictions_file = 'eval_predictions_file_stage1'
127 | config.test_references_file = 'test_reference_file_stage1'
128 | config.test_predictions_file_best = 'test_predictions_file_stage1_best'
129 | config.test_predictions_file_last = 'test_predictions_file_stage1_last'
130 | config.data_type = args.curriculum_data_type
131 |
132 | local_config = deepcopy(config)
133 | local_config.train_batch_size = 16
134 | local_config.batch_split = 2
135 | local_config.test_batch_size = 4
136 | local_config.n_jobs = 0
137 | local_config.device = 'cpu'
138 | local_config.risk_weight = 0
139 | local_config.zero_shot = False
140 | local_config.fp16 = False
141 | # local_config.train_datasets_cache = './datasets/train_datasets_cache.bin'
142 | # local_config.test_datasets_cache = './datasets/test_datasets_cache.bin'
143 |
144 | return config if torch.cuda.is_available() else local_config
145 |
146 | class InputConfig():
147 | def __init__(self):
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument('--seed', type=int, default=0)
150 | parser.add_argument('--normalize_embeddings', action='store_true')
151 | parser.add_argument('--beam_size', default=3, type=int)
152 | parser.add_argument('--inference_mode', default='beam', type=str)
153 | parser.add_argument('--response_k', default=1, type=int)
154 | parser.add_argument('--diversity_coef', default=0, type=int)
155 | parser.add_argument('--lr', default=6.25e-5, type=float)
156 | parser.add_argument('--meta_lr', default=1e-4, type=float)
157 | parser.add_argument('--finetune_lr', default=2e-4, type=float)
158 | parser.add_argument('--finetune_epochs', default=3, type=int)
159 | parser.add_argument('--lr_warmup', default=0.002, type=float)
160 | parser.add_argument('--clip_grad', default=None, type=float)
161 | parser.add_argument('--diversity_groups', default=1, type=int)
162 | parser.add_argument('--annealing_topk', default=None, type=int)
163 | parser.add_argument('--annealing', default=0, type=float)
164 | parser.add_argument('--length_penalty', default=0.6, type=float)
165 | parser.add_argument('--bs_temperature', default=1, type=float)
166 | parser.add_argument('--bs_nucleus_p', default=0, type=float)
167 | parser.add_argument('--apex_level', default=None, type=str)
168 | parser.add_argument('--constant_embedding', action='store_true')
169 | parser.add_argument('--multiple_choice_head', action='store_true')
170 | parser.add_argument('--successive_attention', action='store_true')
171 | parser.add_argument('--sparse_embeddings', default=False, type=bool)
172 | parser.add_argument('--dialog_embeddings', default=True, type=bool)
173 | parser.add_argument('--single_input', action='store_true')
174 | parser.add_argument('--no_persona', action='store_true')
175 | parser.add_argument('--mixup', action='store_true')
176 | parser.add_argument('--use_start_end', action='store_true')
177 | parser.add_argument('--zero_shot', action='store_true')
178 | parser.add_argument('--persona_augment', action='store_true')
179 | parser.add_argument('--linear_schedule', default=True, type=bool)
180 | parser.add_argument('--evaluate_full_sequences', default=True, type=bool)
181 | parser.add_argument('--n_epochs', default=3, type=int)
182 |         parser.add_argument('--patience', default=-1, type=int, help="training patience: if the dev result "
183 |                                                                       "does not improve, training ends")
184 | parser.add_argument('--train_batch_size', default=128, type=int)
185 | parser.add_argument('--batch_split', default=32, type=int)
186 | parser.add_argument('--test_batch_size', default=8, type=int)
187 | parser.add_argument('--writer_comment', default='', type=str)
188 | parser.add_argument('--s2s_weight', default=2, type=float)
189 | parser.add_argument('--lm_weight', default=1, type=float)
190 | parser.add_argument('--risk_weight', default=0, type=float)
191 | parser.add_argument('--hits_weight', default=0, type=float)
192 | parser.add_argument('--label_smoothing', default=-1, type=float,
193 | help='Config for Seq2Seq model, whether use label smoothing loss, -1 means no smoothing')
194 | parser.add_argument('--negative_samples', default=0, type=int)
195 | parser.add_argument('--persona_aug_syn_proba', default=0, type=float)
196 | parser.add_argument('--apex_loss_scale', default=None, type=str)
197 | parser.add_argument('--limit_eval_size', default=-1, type=int)
198 | parser.add_argument('--limit_train_size', default=-1, type=int)
199 | parser.add_argument('--risk_metric', default='f1', type=str)
200 | parser.add_argument('--load_last', default='', type=str)
201 | parser.add_argument('--load_alpha_last', default='', type=str)
202 | parser.add_argument('--data_type', default='persona', type=str, help='data set types, persona/emoji/daily')
203 | parser.add_argument('--test_data_type', default=None, type=str, help='data set types, persona/emoji/daily')
204 | parser.add_argument('--emb_dim', default=300, type=int, help='Config for Seq2Seq model')
205 | parser.add_argument('--hidden_dim', default=300, type=int, help='Config for Seq2Seq model')
206 | parser.add_argument('--num_layers', default=6, type=int, help='Config for Seq2Seq model')
207 | parser.add_argument('--heads', default=4, type=int, help='Config for Seq2Seq model')
208 | parser.add_argument('--depth_size', default=40, type=int, help='Config for Seq2Seq model')
209 | parser.add_argument('--filter_size', default=50, type=int, help='Config for Seq2Seq model')
210 | parser.add_argument('--pointer_gen', action='store_true', help='Config for Seq2Seq model')
211 | parser.add_argument('--pretrained_emb_file', default='./glove/glove.6B.300d.txt', type=str)
212 | parser.add_argument('--vocab_path', default='./datasets/persona_vocab.bin', type=str)
213 | parser.add_argument('--extend_exist_vocab', default=None, type=str)
214 | parser.add_argument('--train_datasets', default='datasets/ConvAI2/train_self_original.txt', type=str)
215 | parser.add_argument('--valid_datasets', default='datasets/ConvAI2/valid_self_original.txt', type=str)
216 | parser.add_argument('--test_datasets', default='datasets/ConvAI2/test_self_original.txt', type=str)
217 | parser.add_argument('--cache_vocab_path', default='datasets/ConvAI2/cached_vocab.pickle', type=str)
218 | parser.add_argument('--train_datasets_cache', default='datasets/train_cache.bin', type=str)
219 | parser.add_argument('--valid_datasets_cache', default='datasets/valid_cache.bin', type=str)
220 | parser.add_argument('--test_datasets_cache', default='datasets/test_cache.bin', type=str)
221 | parser.add_argument('--extra_train_path', default=None, type=str)
222 | parser.add_argument('--extra_data_type', default='persona', type=str)
223 |         parser.add_argument('--extra_cvae_utterances_path', default=None, type=str, help='The path to the CVAE-augmented utterances')
224 |         parser.add_argument('--generate_file_name', default='generation.json', type=str, help='The saved JSON file name '
225 |                                                                                               'when running inference on entailment data')
226 | parser.add_argument('--mixup_cache', default='datasets/mixup_cache.bin', type=str)
227 |         parser.add_argument('--mixup_mode', default='alternate', type=str, help='The mode used to execute the mixup operation: '
228 |                             'alternate=alternate between one batch without mixup and the next batch with mixup, '
229 |                             'all=mixup for all training samples, random=randomly mix up several samples within a batch '
230 |                             'while the rest remain unmixed')
231 | parser.add_argument('--mixup_model_path', default='./fasttext/persona_50_cbow.bin', type=str)
232 | parser.add_argument('--mixup_candidate_th', default=0.4, type=float)
233 | parser.add_argument('--mixup_ratio', default=0.15, type=float)
234 | parser.add_argument('--mixup_soft_loss_weight', default=0, type=float)
235 |         parser.add_argument('--bert_mixup', action='store_true', help='whether to use a BERT model for mixup')
236 |         parser.add_argument('--replace', action='store_true', help='whether to directly replace tokens in samples')
237 |         parser.add_argument('--few_shot', action='store_true', help='whether to do few-shot learning')
238 | parser.add_argument('--shot_num', type=int, default=5, help='The shot number for training')
239 |         parser.add_argument('--train_task_map', type=str, default=None, help='The path of the task map JSON file mapping '
240 |                                                                              'persona ids to sample ids')
241 |         parser.add_argument('--valid_task_map', type=str, default=None, help='The path of the task map JSON file mapping '
242 |                                                                              'persona ids to sample ids')
243 |         parser.add_argument('--test_task_map', type=str, default=None, help='The path of the task map JSON file mapping '
244 |                                                                              'persona ids to sample ids')
245 | parser.add_argument('--meta_batch_size', type=int, default=16, help='The meta batch size')
246 |         parser.add_argument('--meta_batch_ratio', type=float, default=1, help='The ratio between training and evaluation'
247 |                                                                               ' sample numbers in each batch')
248 |         parser.add_argument('--full_input', action='store_true', help='whether to use the concatenated persona, history'
249 |                                                                        ' and reply as the input ids')
250 | parser.add_argument('--max_history_size', type=int, default=-1, help='max history size in input ids')
251 |         parser.add_argument('--same_embedding_lm', type=int, default=1, help='whether the transformer embeddings and the '
252 |                                                                              'LM head weights are shared')
253 |         parser.add_argument('--uncertainty_loss', action='store_true', help='whether to use uncertainty loss')
254 | # parser.add_argument('--attention_weight', action='store_true', help='Whether use the learnable weight for '
255 | # 'attention obtained from different source in decoder')
256 |         parser.add_argument('--model_type', type=str, default='gpt', help='gpt/gpt2/seq2seq/rnn-seq2seq')
257 | parser.add_argument('--model_saving_interval', type=int, default=10, help='model saving interval for seq2seq')
258 |         parser.add_argument('--entail_score_refs_file', type=str, default=None, help='The persona and idx JSON file for each '
259 |                             'utterance; if None, no entailment score will be calculated')
260 |         parser.add_argument('--entail_model_path', type=str, default='./analysis/roberta_nli',
261 |                             help='The path of the NLI model used to calculate the entailment score')
262 |         parser.add_argument('--bert_score_model_path', type=str, default='../roberta_large', help='The model path '
263 |                             'indicating which model will be used for BERTScore; if None, no BERTScore is calculated')
264 | parser.add_argument('--baseline_path', type=str, default='./bert_score/rescale_baseline/en/roberta-large.tsv',
265 | help='The file path for the bert score rescale baseline')
266 | parser.add_argument('--rescale_with_baseline', action='store_true',
267 | help='Whether rescale the bert score')
268 | parser.add_argument('--ignore_sample_indices', type=str, default=None,
269 | help='The json file indicating which samples will be ignored')
270 | parser.add_argument('--shared_module', type=int, default=1)
271 | parser.add_argument('--shared_attention', type=int, default=1)
272 |         parser.add_argument('--attention_pooling_type', type=str, default='mean', help='the method used to pool attention '
273 |                             'outputs from different sources (mean/min/max/sw/dw/linear/att): '
274 |                             'sw=source-level weight, dw=dimension-level weight, linear=linear transform for concatenation, '
275 |                             'att=extra transformer attention layer to fuse attention outputs, '
276 |                             'dys=dynamically determine the scalar weight for each source with a linear layer, '
277 |                             'dyd=dynamically determine the vector weight for each dimension of each source with a linear layer, '
278 |                             'mdys=mutually determine the scalar weight for each source with a linear layer, '
279 |                             'mdyd=mutually and dynamically determine the vector weight for each dimension of each source with a linear layer')
280 |
281 | '''Configurations related to curriculum learning'''
282 | parser.add_argument('--curriculum_learning', action='store_true', help='Whether to do curriculum learning')
283 | parser.add_argument('--curriculum_train_datasets', type=str, default=None)
284 | parser.add_argument('--curriculum_valid_datasets', type=str, default=None)
285 | parser.add_argument('--curriculum_train_datasets_cache', type=str, default=None)
286 | parser.add_argument('--curriculum_valid_datasets_cache', type=str, default=None)
287 | parser.add_argument('--curriculum_lr', type=float, default=2e-4)
288 | parser.add_argument('--curriculum_n_epochs', type=int, default=50)
289 | parser.add_argument('--curriculum_patience', type=int, default=-1)
290 | parser.add_argument('--curriculum_max_history_size', type=int, default=3)
291 | parser.add_argument('--curriculum_data_type', default='entailment', type=str)
292 | parser.add_argument('--curriculum_reverse', action='store_true', help='Whether using the reverse order of training')
293 |
294 | parser.add_argument('--local_rank', type=int, default=-1, help="Distributed training.")
295 | parser.add_argument('--server_ip', type=str, default='', help="Used for debugging on GPU machine.")
296 | parser.add_argument('--server_port', type=str, default='', help="Used for debugging on GPU machine.")
297 |
298 | self.args = parser.parse_args()
299 |
--------------------------------------------------------------------------------
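A hypothetical sketch of how the pieces of config.py fit together in a training entry point: command-line arguments, the model config, and the trainer config are combined roughly as below, though the repo's actual scripts may wire them differently.

# Usage sketch (not a repo file), assuming a training script at the repo root.
from config import InputConfig, get_model_config, get_trainer_config

args = InputConfig().args                      # parse all command-line flags (defaults apply if omitted)
model_config = get_model_config(args)          # architecture and decoding settings
trainer_config = get_trainer_config(args)      # optimization and data settings

if args.curriculum_learning:
    # stage-1 config that swaps in the curriculum datasets, learning rate and epoch budget
    curriculum_config = get_trainer_config(args, curriculum_config=True)

print(model_config.beam_size, trainer_config.train_batch_size)
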
/bert_score/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import torch
4 | from math import log
5 | from itertools import chain
6 | from collections import defaultdict, Counter
7 | from multiprocessing import Pool
8 | from functools import partial
9 | from tqdm.auto import tqdm
10 | from torch.nn.utils.rnn import pad_sequence
11 | from distutils.version import LooseVersion
12 |
13 | from transformers import BertConfig, XLNetConfig, XLMConfig, RobertaConfig
14 | from transformers import AutoModel, GPT2Tokenizer
15 |
16 | from . import __version__
17 | from transformers import __version__ as trans_version
18 |
19 | __all__ = []
20 |
21 | SCIBERT_URL_DICT = {
22 |     "scibert-scivocab-uncased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_scivocab_uncased.tar", # recommended by the SciBERT authors
23 | "scibert-scivocab-cased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_scivocab_cased.tar",
24 | "scibert-basevocab-uncased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_basevocab_uncased.tar",
25 | "scibert-basevocab-cased": "https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/pytorch_models/scibert_basevocab_cased.tar",
26 | }
27 |
28 |
29 | lang2model = defaultdict(lambda: "bert-base-multilingual-cased")
30 | lang2model.update(
31 | {"en": "roberta-large", "zh": "bert-base-chinese", "en-sci": "scibert-scivocab-uncased",}
32 | )
33 |
34 |
35 | model2layers = {
36 | "bert-base-uncased": 9, # 0.6925188074454226
37 | "bert-large-uncased": 18, # 0.7210358126642836
38 | "bert-base-cased-finetuned-mrpc": 9, # 0.6721947475618048
39 | "bert-base-multilingual-cased": 9, # 0.6680687802637132
40 | "bert-base-chinese": 8,
41 | "roberta-base": 10, # 0.706288719158983
42 | "roberta-large": 17, # 0.7385974720781534
43 | "roberta-large-mnli": 19, # 0.7535618640417984
44 | "roberta-base-openai-detector": 7, # 0.7048158349432633
45 | "roberta-large-openai-detector": 15, # 0.7462770207355116
46 | "xlnet-base-cased": 5, # 0.6630103662114238
47 | "xlnet-large-cased": 7, # 0.6598800720297179
48 | "xlm-mlm-en-2048": 6, # 0.651262570131464
49 | "xlm-mlm-100-1280": 10, # 0.6475166424401905
50 | "scibert-scivocab-uncased": 8, # 0.6590354319927313
51 | "scibert-scivocab-cased": 9, # 0.6536375053937445
52 | "scibert-basevocab-uncased": 9, # 0.6748944832703548
53 | "scibert-basevocab-cased": 9, # 0.6524624150542374
54 | "distilroberta-base": 5, # 0.6797558139322964
55 | "distilbert-base-uncased": 5, # 0.6756659152782033
56 | "distilbert-base-uncased-distilled-squad": 4, # 0.6718318036382493
57 | "distilbert-base-multilingual-cased": 5, # 0.6178131050889238
58 | "albert-base-v1": 10, # 0.654237567249745
59 | "albert-large-v1": 17, # 0.6755890754323239
60 | "albert-xlarge-v1": 16, # 0.7031844211905911
61 | "albert-xxlarge-v1": 8, # 0.7508642218461096
62 | "albert-base-v2": 9, # 0.6682455591837927
63 | "albert-large-v2": 14, # 0.7008537594374035
64 | "albert-xlarge-v2": 13, # 0.7317228357869254
65 | "albert-xxlarge-v2": 8, # 0.7505160257184014
66 | "xlm-roberta-base": 9, # 0.6506799445871697
67 | "xlm-roberta-large": 17, # 0.6941551437476826
68 | "google/electra-small-generator": 9, # 0.6659421842117754
69 | "google/electra-small-discriminator": 11, # 0.6534639151385759
70 | "google/electra-base-generator": 10, # 0.6730033453857188
71 | "google/electra-base-discriminator": 9, # 0.7032089590812965
72 | "google/electra-large-generator": 18, # 0.6813370013104459
73 | "google/electra-large-discriminator": 14, # 0.6896675824733477
74 | "google/bert_uncased_L-2_H-128_A-2": 1, # 0.5887998733228855
75 | "google/bert_uncased_L-2_H-256_A-4": 1, # 0.6114863547661203
76 | "google/bert_uncased_L-2_H-512_A-8": 1, # 0.6177345529192847
77 | "google/bert_uncased_L-2_H-768_A-12": 2, # 0.6191261237956839
78 | "google/bert_uncased_L-4_H-128_A-2": 3, # 0.6076202863798991
79 | "google/bert_uncased_L-4_H-256_A-4": 3, # 0.6205239036810148
80 | "google/bert_uncased_L-4_H-512_A-8": 3, # 0.6375351621856903
81 | "google/bert_uncased_L-4_H-768_A-12": 3, # 0.6561849979644787
82 | "google/bert_uncased_L-6_H-128_A-2": 5, # 0.6200458425360283
83 | "google/bert_uncased_L-6_H-256_A-4": 5, # 0.6277501629539081
84 | "google/bert_uncased_L-6_H-512_A-8": 5, # 0.641952305130849
85 | "google/bert_uncased_L-6_H-768_A-12": 5, # 0.6762186226247106
86 | "google/bert_uncased_L-8_H-128_A-2": 7, # 0.6186876506711779
87 | "google/bert_uncased_L-8_H-256_A-4": 7, # 0.6447993208267708
88 | "google/bert_uncased_L-8_H-512_A-8": 6, # 0.6489729408169956
89 | "google/bert_uncased_L-8_H-768_A-12": 7, # 0.6705203359541737
90 | "google/bert_uncased_L-10_H-128_A-2": 8, # 0.6126762064125278
91 | "google/bert_uncased_L-10_H-256_A-4": 8, # 0.6376350032576573
92 | "google/bert_uncased_L-10_H-512_A-8": 9, # 0.6579006292799915
93 | "google/bert_uncased_L-10_H-768_A-12": 8, # 0.6861146692220176
94 | "google/bert_uncased_L-12_H-128_A-2": 10, # 0.6184105693383591
95 | "google/bert_uncased_L-12_H-256_A-4": 11, # 0.6374004994430261
96 | "google/bert_uncased_L-12_H-512_A-8": 10, # 0.65880012149526
97 | "google/bert_uncased_L-12_H-768_A-12": 9, # 0.675911357700092
98 | }
99 |
100 |
101 | def sent_encode(tokenizer, sent):
102 |     "Encode a sentence with the given tokenizer."
103 | sent = sent.strip()
104 | if sent == "":
105 | return tokenizer.build_inputs_with_special_tokens([])
106 | elif isinstance(tokenizer, GPT2Tokenizer):
107 | # for RoBERTa and GPT-2
108 | import transformers
109 |
110 | if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
111 | return tokenizer.encode(
112 | sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True
113 | )
114 | else:
115 | return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len)
116 | else:
117 | import transformers
118 |
119 | if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
120 | return tokenizer.encode(sent, add_special_tokens=True, truncation=True)
121 | else:
122 | return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len)
123 |
124 |
125 | def get_model(model_type, num_layers, all_layers=None):
126 | print(model_type)
127 | if model_type.startswith("scibert"):
128 | model = AutoModel.from_pretrained(cache_scibert(model_type))
129 | else:
130 | model = AutoModel.from_pretrained(model_type)
131 | model.eval()
132 |
133 | # drop unused layers
134 | if not all_layers:
135 | if hasattr(model, "n_layers"): # xlm
136 | assert (
137 | 0 <= num_layers <= model.n_layers
138 | ), f"Invalid num_layers: num_layers should be between 0 and {model.n_layers} for {model_type}"
139 | model.n_layers = num_layers
140 | elif hasattr(model, "layer"): # xlnet
141 | assert (
142 | 0 <= num_layers <= len(model.layer)
143 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.layer)} for {model_type}"
144 | model.layer = torch.nn.ModuleList([layer for layer in model.layer[:num_layers]])
145 | elif hasattr(model, "encoder"): # albert
146 | if hasattr(model.encoder, "albert_layer_groups"):
147 | assert (
148 | 0 <= num_layers <= model.encoder.config.num_hidden_layers
149 | ), f"Invalid num_layers: num_layers should be between 0 and {model.encoder.config.num_hidden_layers} for {model_type}"
150 | model.encoder.config.num_hidden_layers = num_layers
151 | else: # bert, roberta
152 | assert (
153 | 0 <= num_layers <= len(model.encoder.layer)
154 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.encoder.layer)} for {model_type}"
155 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]])
156 | elif hasattr(model, "transformer"): # bert, roberta
157 | assert (
158 | 0 <= num_layers <= len(model.transformer.layer)
159 | ), f"Invalid num_layers: num_layers should be between 0 and {len(model.transformer.layer)} for {model_type}"
160 | model.transformer.layer = torch.nn.ModuleList([layer for layer in model.transformer.layer[:num_layers]])
161 | else:
162 | raise ValueError("Not supported")
163 | else:
164 | if hasattr(model, "output_hidden_states"):
165 | model.output_hidden_states = True
166 | elif hasattr(model, "encoder"):
167 | model.encoder.output_hidden_states = True
168 | elif hasattr(model, "transformer"):
169 | model.transformer.output_hidden_states = True
170 | else:
171 | raise ValueError(f"Not supported model architecture: {model_type}")
172 |
173 | return model
174 |
175 |
176 | def padding(arr, pad_token, dtype=torch.long):
177 | lens = torch.LongTensor([len(a) for a in arr])
178 | max_len = lens.max().item()
179 | padded = torch.ones(len(arr), max_len, dtype=dtype) * pad_token
180 | mask = torch.zeros(len(arr), max_len, dtype=torch.long)
181 | for i, a in enumerate(arr):
182 | padded[i, : lens[i]] = torch.tensor(a, dtype=dtype)
183 | mask[i, : lens[i]] = 1
184 | return padded, lens, mask
185 |
186 |
187 | def bert_encode(model, x, attention_mask, all_layers=False):
188 | model.eval()
189 | with torch.no_grad():
190 | out = model(x, attention_mask=attention_mask)
191 | if all_layers:
192 | emb = torch.stack(out[-1], dim=2)
193 | else:
194 | emb = out[0]
195 | return emb
196 |
197 |
198 | def process(a, tokenizer=None):
199 | if tokenizer is not None:
200 | a = sent_encode(tokenizer, a)
201 | return set(a)
202 |
203 |
204 | def get_idf_dict(arr, tokenizer, nthreads=4):
205 | """
206 | Returns mapping from word piece index to its inverse document frequency.
207 |
208 |
209 | Args:
210 | - :param: `arr` (list of str) : sentences to process.
211 |     - :param: `tokenizer` : a BERT tokenizer that corresponds to `model`.
212 | - :param: `nthreads` (int) : number of CPU threads to use
213 | """
214 | idf_count = Counter()
215 | num_docs = len(arr)
216 |
217 | process_partial = partial(process, tokenizer=tokenizer)
218 |
219 | with Pool(nthreads) as p:
220 | idf_count.update(chain.from_iterable(p.map(process_partial, arr)))
221 |
222 | idf_dict = defaultdict(lambda: log((num_docs + 1) / (1)))
223 | idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()})
224 | return idf_dict
225 |
226 |
227 | def collate_idf(arr, tokenizer, idf_dict, device="cuda:0"):
228 | """
229 |     Helper function that pads a list of sentences to have the same length and
230 | loads idf score for words in the sentences.
231 |
232 | Args:
233 | - :param: `arr` (list of str): sentences to process.
234 | - :param: `tokenize` : a function that takes a string and return list
235 | of tokens.
236 | - :param: `numericalize` : a function that takes a list of tokens and
237 | return list of token indexes.
238 | - :param: `idf_dict` (dict): mapping a word piece index to its
239 | inverse document frequency
240 | - :param: `pad` (str): the padding token.
241 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'
242 | """
243 | arr = [sent_encode(tokenizer, a) for a in arr]
244 |
245 | idf_weights = [[idf_dict[i] for i in a] for a in arr]
246 |
247 | pad_token = tokenizer.pad_token_id
248 |
249 | padded, lens, mask = padding(arr, pad_token, dtype=torch.long)
250 | padded_idf, _, _ = padding(idf_weights, 0, dtype=torch.float)
251 |
252 | padded = padded.to(device=device)
253 | mask = mask.to(device=device)
254 | lens = lens.to(device=device)
255 | return padded, padded_idf, lens, mask
256 |
257 |
258 | def get_bert_embedding(all_sens, model, tokenizer, idf_dict, batch_size=-1, device="cuda:0", all_layers=False):
259 | """
260 | Compute BERT embedding in batches.
261 |
262 | Args:
263 | - :param: `all_sens` (list of str) : sentences to encode.
264 | - :param: `model` : a BERT model from `pytorch_pretrained_bert`.
265 |     - :param: `tokenizer` : a BERT tokenizer that corresponds to `model`.
266 | - :param: `idf_dict` (dict) : mapping a word piece index to its
267 | inverse document frequency
268 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'
269 | """
270 |
271 | padded_sens, padded_idf, lens, mask = collate_idf(all_sens, tokenizer, idf_dict, device=device)
272 |
273 | if batch_size == -1:
274 | batch_size = len(all_sens)
275 |
276 | embeddings = []
277 | with torch.no_grad():
278 | for i in range(0, len(all_sens), batch_size):
279 | batch_embedding = bert_encode(
280 | model, padded_sens[i : i + batch_size], attention_mask=mask[i : i + batch_size], all_layers=all_layers
281 | )
282 | embeddings.append(batch_embedding)
283 | del batch_embedding
284 |
285 | total_embedding = torch.cat(embeddings, dim=0)
286 |
287 | return total_embedding, mask, padded_idf
288 |
289 |
290 | def greedy_cos_idf(ref_embedding, ref_masks, ref_idf, hyp_embedding, hyp_masks, hyp_idf, all_layers=False):
291 | """
292 | Compute greedy matching based on cosine similarity.
293 |
294 | Args:
295 | - :param: `ref_embedding` (torch.Tensor):
296 | embeddings of reference sentences, BxKxd,
297 |         B: batch size, K: longest length, d: bert dimension
298 | - :param: `ref_lens` (list of int): list of reference sentence length.
299 | - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for
300 | reference sentences.
301 | - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word
302 |       piece in the reference sentence
303 | - :param: `hyp_embedding` (torch.Tensor):
304 | embeddings of candidate sentences, BxKxd,
305 |         B: batch size, K: longest length, d: bert dimension
306 | - :param: `hyp_lens` (list of int): list of candidate sentence length.
307 | - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for
308 | candidate sentences.
309 | - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word
310 |       piece in the candidate sentence
311 | """
312 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
313 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
314 |
315 | if all_layers:
316 | B, _, L, D = hyp_embedding.size()
317 | hyp_embedding = hyp_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, hyp_embedding.size(1), D)
318 | ref_embedding = ref_embedding.transpose(1, 2).transpose(0, 1).contiguous().view(L * B, ref_embedding.size(1), D)
319 | batch_size = ref_embedding.size(0)
320 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
321 | masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float())
322 | if all_layers:
323 | masks = masks.unsqueeze(0).expand(L, -1, -1, -1).contiguous().view_as(sim)
324 | else:
325 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim)
326 |
327 | masks = masks.float().to(sim.device)
328 | sim = sim * masks
329 |
330 | word_precision = sim.max(dim=2)[0]
331 | word_recall = sim.max(dim=1)[0]
332 |
333 | hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True))
334 | ref_idf.div_(ref_idf.sum(dim=1, keepdim=True))
335 | precision_scale = hyp_idf.to(word_precision.device)
336 | recall_scale = ref_idf.to(word_recall.device)
337 | if all_layers:
338 | precision_scale = precision_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_precision)
339 | recall_scale = recall_scale.unsqueeze(0).expand(L, B, -1).contiguous().view_as(word_recall)
340 | P = (word_precision * precision_scale).sum(dim=1)
341 | R = (word_recall * recall_scale).sum(dim=1)
342 | F = 2 * P * R / (P + R)
343 |
344 | hyp_zero_mask = hyp_masks.sum(dim=1).eq(2)
345 | ref_zero_mask = ref_masks.sum(dim=1).eq(2)
346 |
347 | if all_layers:
348 | P = P.view(L, B)
349 | R = R.view(L, B)
350 | F = F.view(L, B)
351 |
352 | if torch.any(hyp_zero_mask):
353 | print("Warning: Empty candidate sentence detected; setting precision to be 0.", file=sys.stderr)
354 | P = P.masked_fill(hyp_zero_mask, 0.0)
355 |
356 | if torch.any(ref_zero_mask):
357 | print("Warning: Empty reference sentence detected; setting recall to be 0.", file=sys.stderr)
358 | R = R.masked_fill(ref_zero_mask, 0.0)
359 |
360 | F = F.masked_fill(torch.isnan(F), 0.0)
361 |
362 | return P, R, F
363 |
364 |
365 | def bert_cos_score_idf(
366 | model, refs, hyps, tokenizer, idf_dict, verbose=False, batch_size=64, device="cuda:0", all_layers=False
367 | ):
368 | """
369 | Compute BERTScore.
370 |
371 | Args:
372 | - :param: `model` : a BERT model in `pytorch_pretrained_bert`
373 | - :param: `refs` (list of str): reference sentences
374 | - :param: `hyps` (list of str): candidate sentences
375 |     - :param: `tokenizer` : a BERT tokenizer that corresponds to `model`
376 | - :param: `idf_dict` : a dictionary mapping a word piece index to its
377 | inverse document frequency
378 | - :param: `verbose` (bool): turn on intermediate status update
379 | - :param: `batch_size` (int): bert score processing batch size
380 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'
381 | """
382 | preds = []
383 |
384 | def dedup_and_sort(l):
385 | return sorted(list(set(l)), key=lambda x: len(x.split(" ")), reverse=True)
386 |
387 | sentences = dedup_and_sort(refs + hyps)
388 | embs = []
389 | iter_range = range(0, len(sentences), batch_size)
390 | print(len(sentences))
391 | if verbose:
392 | print("computing bert embedding.")
393 | iter_range = tqdm(iter_range)
394 | stats_dict = dict()
395 | for batch_start in tqdm(iter_range):
396 | sen_batch = sentences[batch_start : batch_start + batch_size]
397 | embs, masks, padded_idf = get_bert_embedding(
398 | sen_batch, model, tokenizer, idf_dict, device=device, all_layers=all_layers
399 | )
400 | embs = embs.cpu()
401 | masks = masks.cpu()
402 | padded_idf = padded_idf.cpu()
403 | for i, sen in enumerate(sen_batch):
404 | sequence_len = masks[i].sum().item()
405 | emb = embs[i, :sequence_len]
406 | idf = padded_idf[i, :sequence_len]
407 | stats_dict[sen] = (emb, idf)
408 |
409 | def pad_batch_stats(sen_batch, stats_dict, device):
410 | stats = [stats_dict[s] for s in sen_batch]
411 | emb, idf = zip(*stats)
412 | emb = [e.to(device) for e in emb]
413 | idf = [i.to(device) for i in idf]
414 | lens = [e.size(0) for e in emb]
415 | emb_pad = pad_sequence(emb, batch_first=True, padding_value=2.0)
416 | idf_pad = pad_sequence(idf, batch_first=True)
417 |
418 | def length_to_mask(lens):
419 | lens = torch.tensor(lens, dtype=torch.long)
420 | max_len = max(lens)
421 | base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len)
422 | return base < lens.unsqueeze(1)
423 |
424 | pad_mask = length_to_mask(lens).to(device)
425 | return emb_pad, pad_mask, idf_pad
426 |
427 | device = next(model.parameters()).device
428 | iter_range = range(0, len(refs), batch_size)
429 | if verbose:
430 | print("computing greedy matching.")
431 | iter_range = tqdm(iter_range)
432 |
433 | with torch.no_grad():
434 | for batch_start in iter_range:
435 | batch_refs = refs[batch_start : batch_start + batch_size]
436 | batch_hyps = hyps[batch_start : batch_start + batch_size]
437 | ref_stats = pad_batch_stats(batch_refs, stats_dict, device)
438 | hyp_stats = pad_batch_stats(batch_hyps, stats_dict, device)
439 |
440 | P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats, all_layers)
441 | preds.append(torch.stack((P, R, F1), dim=-1).cpu())
442 | preds = torch.cat(preds, dim=1 if all_layers else 0)
443 | return preds
444 |
445 |
446 | def get_hash(model, num_layers, idf, rescale_with_baseline, use_custom_baseline):
447 | msg = "{}_L{}{}_version={}(hug_trans={})".format(
448 | model, num_layers, "_idf" if idf else "_no-idf", __version__, trans_version
449 | )
450 | if rescale_with_baseline:
451 | if use_custom_baseline:
452 | msg += "-custom-rescaled"
453 | else:
454 | msg += "-rescaled"
455 | return msg
456 |
457 |
458 | def cache_scibert(model_type, cache_folder="~/.cache/torch/transformers"):
459 | if not model_type.startswith("scibert"):
460 | return model_type
461 |
462 | underscore_model_type = model_type.replace("-", "_")
463 | cache_folder = os.path.abspath(cache_folder)
464 | filename = os.path.join(cache_folder, underscore_model_type)
465 |
466 | # download SciBERT models
467 | if not os.path.exists(filename):
468 | cmd = f"mkdir -p {cache_folder}; cd {cache_folder};"
469 | cmd += f"wget {SCIBERT_URL_DICT[model_type]}; tar -xvf {underscore_model_type}.tar;"
470 | cmd += (
471 | f"rm -f {underscore_model_type}.tar ; cd {underscore_model_type}; tar -zxvf weights.tar.gz; mv weights/* .;"
472 | )
473 | cmd += f"rm -f weights.tar.gz; rmdir weights; mv bert_config.json config.json;"
474 | print(cmd)
475 | print(f"downloading {model_type} model")
476 | os.system(cmd)
477 |
478 | # fix the missing files in scibert
479 | json_file = os.path.join(filename, "special_tokens_map.json")
480 | if not os.path.exists(json_file):
481 | with open(json_file, "w") as f:
482 | print(
483 | '{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}',
484 | file=f,
485 | )
486 |
487 | json_file = os.path.join(filename, "added_tokens.json")
488 | if not os.path.exists(json_file):
489 | with open(json_file, "w") as f:
490 | print("{}", file=f)
491 |
492 | if "uncased" in model_type:
493 | json_file = os.path.join(filename, "tokenizer_config.json")
494 | if not os.path.exists(json_file):
495 | with open(json_file, "w") as f:
496 | print('{"do_lower_case": true, "max_len": 512, "init_inputs": []}', file=f)
497 |
498 | return filename
499 |
--------------------------------------------------------------------------------
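A hedged end-to-end sketch of how these helpers can be wired together to compute BERTScore, roughly the way the package's scorer does. The model and tokenizer loading shown here, the no-idf weighting, and the CPU device are illustrative assumptions, and the code relies on a transformers version that still exposes tokenizer.max_len (as sent_encode above does).

# Usage sketch (not part of bert_score/utils.py).
from collections import defaultdict
from transformers import AutoTokenizer

from bert_score.utils import get_model, bert_cos_score_idf, model2layers

model_type = "roberta-large"
refs = ["the cat sat on the mat"]
hyps = ["a cat was sitting on the mat"]

tokenizer = AutoTokenizer.from_pretrained(model_type)
model = get_model(model_type, model2layers[model_type])   # keeps only the best-correlating layers

# no-idf weighting: every word piece weighted 1.0, special tokens weighted 0
idf_dict = defaultdict(lambda: 1.0)
idf_dict[tokenizer.sep_token_id] = 0
idf_dict[tokenizer.cls_token_id] = 0

scores = bert_cos_score_idf(model, refs, hyps, tokenizer, idf_dict,
                            batch_size=64, device="cpu")   # shape (N, 3): P, R, F1 per pair
P, R, F1 = scores[..., 0], scores[..., 1], scores[..., 2]
print(F1)
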