├── misc
│   ├── __init__.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── bleu.py
│   │   └── rouge.py
│   ├── use.py
│   ├── utils.py
│   ├── evaluate.py
│   ├── evaluate_attacks.py
│   ├── input_data.py
│   └── acc_transformer.py
├── models
│   ├── __init__.py
│   ├── bert
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── optimization.py
│   │   ├── measures.py
│   │   ├── tokenization.py
│   │   └── input_data.py
│   ├── utils.py
│   ├── myBertClassifier.py
│   ├── myCopyDecoder.py
│   ├── myCNNClassifier.py
│   ├── myClassifier.py
│   └── mySeq2Seq.py
├── .idea
│   └── vcs.xml
├── requirements.txt
├── download.sh
├── train_ae.sh
├── README.md
├── test_cls.sh
├── train_cls.sh
├── test_adv.sh
├── train_adv.sh
├── config.py
└── yelp_preprocessing.py
/misc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/bert/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/misc/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | requests
3 | six
4 | nltk
5 | spacy
6 | sacremoses
7 | Sphinx
8 | sphinx_rtd_theme
9 | pillow >= 4.1.1
10 | torch==1.3.0
11 | torchvision==0.4.1
12 | torchtext
13 | git+https://github.com/jekbradbury/revtok.git
14 | transformers
15 | gensim
16 | boto
17 | annoy
18 | urllib3==1.25.6
19 | scipy==1.3.1
20 | tensorflow_hub
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | mkdir embeddings
4 | cd embeddings
5 |
6 | # download glove embeddings
7 | wget http://nlp.stanford.edu/data/glove.840B.300d.zip
8 | unzip glove.840B.300d.zip
9 |
10 | # download counter-fitted embeddings
11 | wget https://github.com/nmrksic/counter-fitting/raw/master/word_vectors/counter-fitted-vectors.txt.zip
12 | unzip counter-fitted-vectors.txt.zip
13 |
14 | cd ..
15 |
16 |
--------------------------------------------------------------------------------
/train_ae.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source /Users/yxu132/pyflow3.6/bin/activate
4 | DATA_DIR=data
5 | MODEL_DIR=saved_models
6 |
7 | python train.py --do_train \
8 | --vocab_file=$DATA_DIR/vocab.in \
9 | --emb_file=$DATA_DIR/emb.json \
10 | --input_file=$DATA_DIR/train.in \
11 | --output_file=$DATA_DIR/train.out \
12 | --dev_file=$DATA_DIR/dev.in \
13 | --dev_output=$DATA_DIR/dev.out \
14 | --enc_type=bi --attention --enc_num_units=512 --dec_num_units=512 \
15 | --learning_rate=0.001 --batch_size=32 --max_len=50 \
16 | --num_epochs=10 --print_every_steps=100 --stop_steps=20000 \
17 | --output_dir=$MODEL_DIR/ae_output \
18 | --save_checkpoints \
19 | --num_gpus=0
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # adv-def-text
2 |
3 | ## Requirements:
4 |
5 | - Python 3.6+
6 | - TensorFlow 1.3+
7 | - spacy
8 | - transformers
9 | - tensorflow_hub
10 |
11 | Please refer to requirements.txt for the detailed package list.
12 |
13 | ## Introduction
14 |
15 | This is an implementation of the paper "Grey-box Adversarial Attack and Defence for Text".
16 | The paper is currently under submission to EMNLP'20.
17 |
18 | ## Run training and test
19 |
20 | 1. Download the Yelp review dataset from the official website: [link](https://www.yelp.com/dataset).
21 |
22 | 2. Download the GloVe embeddings and the counter-fitted embeddings using
23 |
24 | ```
25 | ./download.sh
26 | ```
27 |
28 | 3. Run dataset preprocessing using
29 |
30 | ```
31 | python yelp_preprocessing.py --data_dir YELP_DATASET_PATH --embed_file GLOVE_EMB_PATH
32 | ```
33 |
34 | 4. Train the target classification models using the script
35 |
36 | ```
37 | ./train_cls.sh
38 | ```
39 |
40 | 5. Pre-train the auto-encoder for reconstruction using the script
41 |
42 | ```
43 | ./train_ae.sh
44 | ```
45 |
46 | 6. Train the adversarial attack/defence models using the script (multiple variants of our model are available as commented-out commands in the script)
47 |
48 | ```
49 | ./train_adv.sh
50 | ```
51 |
52 | 7. Perform an independent test of adversarial attack/defence using
53 |
54 | ```
55 | ./test_adv.sh
56 | ```
57 |
--------------------------------------------------------------------------------
/test_cls.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source /Users/yxu132/pyflow3.6/bin/activate
4 | DATA_DIR=data
5 | MODEL_DIR=saved_models
6 |
7 | python train.py --do_test \
8 | --vocab_file=$DATA_DIR/vocab.in \
9 | --emb_file=$DATA_DIR/emb.json \
10 | --test_file=$DATA_DIR/test.in \
11 | --test_output=$DATA_DIR/test.out \
12 | --load_model=$MODEL_DIR/cls_output/bi_att \
13 | --classification --classification_model=RNN --output_classes=2 \
14 | --enc_type=bi --enc_num_units=256 --cls_attention --cls_attention_size=50 \
15 | --learning_rate=0.001 --batch_size=32 --max_len=50 \
16 | --num_epochs=10 --print_every_steps=100 --stop_steps=5000 \
17 | --output_dir=$MODEL_DIR/cls_output_test \
18 | --save_checkpoints \
19 | --num_gpus=0
20 |
21 |
22 | ## Test against augmented classifier from the AE+LS+CF
23 | #python train.py --do_test \
24 | # --vocab_file=$DATA_DIR/vocab.in \
25 | # --emb_file=$DATA_DIR/emb.json \
26 | # --test_file=$DATA_DIR/test.in \
27 | # --test_output=$DATA_DIR/test.out \
28 | # --load_model=$MODEL_DIR/adv_output_lscf/nmt-T2.ckpt \
29 | # --classification --classification_model=RNN --output_classes=2 \
30 | # --enc_type=bi --enc_num_units=256 --cls_attention --cls_attention_size=50 \
31 | # --learning_rate=0.001 --batch_size=32 --max_len=50 \
32 | # --num_epochs=10 --print_every_steps=100 --stop_steps=5000 \
33 | # --output_dir=$MODEL_DIR/cls_output_test \
34 | # --save_checkpoints \
35 | # --num_gpus=0 \
36 | # --use_defending_as_target
37 |
38 |
--------------------------------------------------------------------------------
/misc/use.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_hub as hub
3 | import os
4 |
5 |
6 | class USE(object):
7 | def __init__(self, cache_path):
8 | super(USE, self).__init__()
9 | os.environ['TFHUB_CACHE_DIR'] = cache_path
10 | module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
11 | self.embed = hub.Module(module_url)
12 | # config = tf.ConfigProto()
13 | # config.gpu_options.allow_growth = True
14 | # self.sess = None
15 | self.build_graph()
16 | # self.sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
17 |
18 | def set_sess(self, sess):
19 | self.sess = sess
20 |
21 | def build_graph(self):
22 | self.sts_input1 = tf.placeholder(tf.string, shape=(None))
23 | self.sts_input2 = tf.placeholder(tf.string, shape=(None))
24 |
25 | sts_encode1 = tf.nn.l2_normalize(self.embed(self.sts_input1), axis=1)
26 | sts_encode2 = tf.nn.l2_normalize(self.embed(self.sts_input2), axis=1)
27 | self.cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
28 | clip_cosine_similarities = tf.clip_by_value(self.cosine_similarities, -1.0, 1.0)
29 | self.sim_scores = 1.0 - tf.acos(clip_cosine_similarities)
30 |
31 | def semantic_sim(self, sents1, sents2):
32 | scores = self.sess.run(
33 | self.sim_scores,
34 | feed_dict={
35 | self.sts_input1: sents1,
36 | self.sts_input2: sents2,
37 | })
38 | return scores
--------------------------------------------------------------------------------
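
A minimal usage sketch for the USE wrapper above (illustrative only; the cache path and sentences are made up, and a TF1-style session with variable/table initializers is assumed, as in the commented-out lines of the class):

```
import tensorflow as tf
from misc.use import USE

use_model = USE(cache_path='/tmp/use_cache')   # hypothetical cache directory
sess = tf.Session()
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
use_model.set_sess(sess)

# Higher scores mean the sentence pairs are more semantically similar.
scores = use_model.semantic_sim(["the food was great"],
                                ["the meal was excellent"])
print(scores)
```
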
/misc/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys, six
3 |
4 | def prepare_batch(inputs):
5 | sequence_lengths = [len(seq) for seq in inputs]
6 | batch_size = len(inputs)
7 | max_sequence_length = max(sequence_lengths)
8 |
9 | inputs_batch_major = np.zeros(shape=[batch_size, max_sequence_length],
10 | dtype=np.int32)
11 |
12 | for i, seq in enumerate(inputs):
13 | for j, element in enumerate(seq):
14 | inputs_batch_major[i, j] = element
15 |
16 | return inputs_batch_major, sequence_lengths
17 |
18 |
19 | def batch_generator(x, y):
20 | while True:
21 | i = np.random.randint(0, len(x))
22 | yield [x[i], y[i]]
23 |
24 | def input_generator(x, y, batch_size):
25 | gen_batch = batch_generator(x, y)
26 |
27 | x_batch = []
28 | y_batch = []
29 | for i in range(batch_size):
30 | a, b= next(gen_batch)
31 | x_batch += [a]
32 | y_batch += [b]
33 | return x_batch, y_batch
34 |
35 | def print_out(s, f=None, new_line=True):
36 | """Similar to print but with support to flush and output to a file."""
37 | if isinstance(s, bytes):
38 | s = s.decode("utf-8")
39 |
40 | if f:
41 | f.write(s.encode("utf-8"))
42 | if new_line:
43 | f.write(b"\n")
44 |
45 | # stdout
46 | if six.PY2:
47 | sys.stdout.write(s.encode("utf-8"))
48 | else:
49 | sys.stdout.buffer.write(s.encode("utf-8"))
50 |
51 | if new_line:
52 | sys.stdout.write("\n")
53 | sys.stdout.flush()
54 |
55 | def readlines(input_file):
56 | ret = []
57 | for line in open(input_file, 'r'):
58 | ret.append(line.strip())
59 | return ret
60 |
61 | def write_lines(arr_list, output_path):
62 | with open(output_path, 'w') as output_file:
63 | output_file.write('\n'.join(arr_list))
64 | return
65 |
66 | import json
67 | def write_numpy_array(emb_mat, output_path):
68 | with open(output_path, 'w') as outfile:
69 | json.dump(emb_mat.tolist(), outfile)
--------------------------------------------------------------------------------
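
An illustrative sketch of the padding helpers above, with toy token ids (not taken from the repository's data):

```
from misc.utils import prepare_batch, input_generator

x = [[4, 9, 7], [12, 5], [3]]
y = [[1, 0], [0, 1], [1, 0]]

batch, lengths = prepare_batch(x)
# batch   -> [[ 4  9  7]
#             [12  5  0]
#             [ 3  0  0]]   zero-padded to the longest sequence
# lengths -> [3, 2, 1]

# input_generator draws `batch_size` random (x, y) pairs.
x_batch, y_batch = input_generator(x, y, batch_size=2)
```
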
/train_cls.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source /Users/yxu132/pyflow3.6/bin/activate
4 | DATA_DIR=data
5 | MODEL_DIR=saved_models
6 |
7 | # RNN train
8 | python train.py --do_train \
9 | --vocab_file=$DATA_DIR/vocab.in \
10 | --emb_file=$DATA_DIR/emb.json \
11 | --input_file=$DATA_DIR/train.in \
12 | --output_file=$DATA_DIR/train.out \
13 | --dev_file=$DATA_DIR/dev.in \
14 | --dev_output=$DATA_DIR/dev.out \
15 | --classification --classification_model=RNN --output_classes=2 \
16 | --enc_type=bi --enc_num_units=256 --cls_attention --cls_attention_size=50 \
17 | --learning_rate=0.001 --batch_size=32 --max_len=50 \
18 | --num_epochs=10 --print_every_steps=100 --stop_steps=5000 \
19 | --output_dir=$MODEL_DIR/cls_output_rnn \
20 | --save_checkpoints \
21 | --num_gpus=0
22 |
23 | ## CNN train
24 | #python train.py --do_train \
25 | # --vocab_file=$DATA_DIR/vocab.in \
26 | # --emb_file=$DATA_DIR/emb.json \
27 | # --input_file=$DATA_DIR/train.in \
28 | # --output_file=$DATA_DIR/train.out \
29 | # --dev_file=$DATA_DIR/dev.in \
30 | # --dev_output=$DATA_DIR/dev.out \
31 | # --classification --classification_model=CNN --output_classes=2 \
32 | # --enc_type=bi --enc_num_units=256 --cls_attention_size=50 \
33 | # --learning_rate=0.001 --batch_size=32 --max_len=50 --dropout_keep_prob=0.8 \
34 | # --num_epochs=10 --print_every_steps=100 --stop_steps=5000 \
35 | # --output_dir=$MODEL_DIR/cls_output_cnn \
36 | # --save_checkpoints \
37 | # --num_gpus=0
38 |
39 | ## BERT train
40 | #python train.py --do_train \
41 | # --vocab_file=/Users/yxu132/data/bert/uncased_L-12_H-768_A-12/vocab.txt \
42 | # --input_file=$DATA_DIR/train.in \
43 | # --output_file=$DATA_DIR/train.out \
44 | # --dev_file=$DATA_DIR/dev.in \
45 | # --dev_output=$DATA_DIR/dev.out \
46 | # --classification --classification_model=BERT --output_classes=2 \
47 | # --bert_config_file=/Users/yxu132/data/bert/uncased_L-12_H-768_A-12/bert_config.json \
48 | # --bert_init_chk=/Users/yxu132/data/bert/uncased_L-12_H-768_A-12/bert_model.ckpt \
49 | # --learning_rate=1e-5 --batch_size=32 --max_len=50 \
50 | # --num_epochs=10 --print_every_steps=100 --stop_steps=5000 \
51 | # --output_dir=$MODEL_DIR/cls_output_bert \
52 | # --save_checkpoints \
53 | # --num_gpus=0
--------------------------------------------------------------------------------
/test_adv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source /Users/yxu132/pyflow3.6/bin/activate
4 | DATA_DIR=data
5 | MODEL_DIR=saved_models
6 |
7 | # AE+LS+CF
8 | python train.py --do_test \
9 | --vocab_file=$DATA_DIR/vocab.in \
10 | --emb_file=$DATA_DIR/emb.json \
11 | --test_file=$DATA_DIR/test.in \
12 | --test_output=$DATA_DIR/test.out \
13 | --load_model=$MODEL_DIR/adv_output_lscf/nmt-T2.ckpt \
14 | --adv --classification_model=RNN --output_classes=2 \
15 | --gumbel_softmax_temporature=0.1 \
16 | --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
17 | --cls_attention --cls_attention_size=50 --attention \
18 | --batch_size=16 --max_len=50 \
19 | --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
20 | --output_dir=$MODEL_DIR/adv_test_output \
21 | --save_checkpoints \
22 | --num_gpus=0 \
23 | --ae_vocab_file=$DATA_DIR/cf_vocab.in \
24 | --ae_emb_file=$DATA_DIR/cf_emb.json \
25 | --use_cache_dir=/dccstor/ddig/ying/use_cache \
26 | --accept_name=xlnet
27 |
28 |
29 | ## AE+LS+CF+CPY
30 | #python train.py --do_test \
31 | # --vocab_file=$DATA_DIR/vocab.in \
32 | # --emb_file=$DATA_DIR/emb.json \
33 | # --test_file=$DATA_DIR/test.in \
34 | # --test_output=$DATA_DIR/test.out \
35 | # --load_model=$MODEL_DIR/adv_output_lscf/nmt-T2.ckpt \
36 | # --adv --classification_model=RNN --output_classes=2 \
37 | # --gumbel_softmax_temporature=0.1 \
38 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
39 | # --cls_attention --cls_attention_size=50 --attention \
40 | # --batch_size=16 --max_len=50 \
41 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
42 | # --output_dir=adv_test \
43 | # --save_checkpoints \
44 | # --num_gpus=0 \
45 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
46 | # --ae_emb_file=$DATA_DIR/cf_emb.json \
47 | # --use_cache_dir=/dccstor/ddig/ying/use_cache \
48 | # --accept_name=xlnet \
49 | # --copy --attention_copy_mask --use_stop_words --top_k_attack=9
50 | #
51 | # Test Conditional Generation: AE+LS+CF
52 | #python train.py --do_test \
53 | # --vocab_file=$DATA_DIR/vocab.in \
54 | # --emb_file=$DATA_DIR/emb.json \
55 | # --test_file=$DATA_DIR/test.in \
56 | # --test_output=$DATA_DIR/test.out \
57 | # --load_model_pos=$MODEL_DIR/adv_output_lscfcp_ptn \
58 | # --load_model_neg=$MODEL_DIR/adv_output_lscfcp_ntp \
59 | # --adv --classification_model=RNN --output_classes=2 \
60 | # --gumbel_softmax_temporature=0.1 \
61 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
62 | # --cls_attention --cls_attention_size=50 --attention \
63 | # --batch_size=16 --max_len=50 \
64 | # --output_dir=$MODEL_DIR/adv_test_lscfcp_ptn \
65 | # --save_checkpoints \
66 | # --num_gpus=0 \
67 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
68 | # --ae_emb_file=$DATA_DIR/cf_emb.json
69 | #
70 | #
71 | ### AE+LS+CF+DEFENCE
72 | #python train.py --do_test \
73 | # --vocab_file=$DATA_DIR/vocab.in \
74 | # --emb_file=$DATA_DIR/emb.json \
75 | # --test_file=$DATA_DIR/test.in \
76 | # --test_output=$DATA_DIR/test.out \
77 | # --load_model=$MODEL_DIR/adv_output_lscf/nmt-T2.ckpt \
78 | # --adv --classification_model=RNN --output_classes=2 --defending \
79 | # --gumbel_softmax_temporature=0.1 \
80 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
81 | # --cls_attention --cls_attention_size=50 --attention \
82 | # --batch_size=16 --max_len=50 \
83 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
84 | # --output_dir=$MODEL_DIR/adv_def_test \
85 | # --save_checkpoints \
86 | # --num_gpus=0 \
87 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
88 | # --ae_emb_file=$DATA_DIR/cf_emb.json \
89 | # --use_cache_dir=/dccstor/ddig/ying/use_cache \
90 | # --accept_name=xlnet
91 |
92 |
93 |
--------------------------------------------------------------------------------
/misc/scripts/bleu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Python implementation of BLEU and smooth-BLEU.
17 |
18 | This module provides a Python implementation of BLEU and smooth-BLEU.
19 | Smooth BLEU is computed following the method outlined in the paper:
20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
21 | evaluation metrics for machine translation. COLING 2004.
22 | """
23 |
24 | import collections
25 | import math
26 | import sys
27 |
28 | def _get_ngrams(segment, max_order):
29 | """Extracts all n-grams upto a given maximum order from an input segment.
30 |
31 | Args:
32 | segment: text segment from which n-grams will be extracted.
33 | max_order: maximum length in tokens of the n-grams returned by this
34 | method.
35 |
36 | Returns:
37 | The Counter containing all n-grams up to max_order in segment
38 | with a count of how many times each n-gram occurred.
39 | """
40 | ngram_counts = collections.Counter()
41 | for order in range(1, max_order + 1):
42 | for i in range(0, len(segment) - order + 1):
43 | ngram = tuple(segment[i:i+order])
44 | ngram_counts[ngram] += 1
45 | return ngram_counts
46 |
47 |
48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4,
49 | smooth=False):
50 | """Computes BLEU score of translated segments against one or more references.
51 |
52 | Args:
53 | reference_corpus: list of lists of references for each translation. Each
54 | reference should be tokenized into a list of tokens.
55 | translation_corpus: list of translations to score. Each translation
56 | should be tokenized into a list of tokens.
57 | max_order: Maximum n-gram order to use when computing BLEU score.
58 | smooth: Whether or not to apply Lin et al. 2004 smoothing.
59 |
60 | Returns:
61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
62 | precisions and brevity penalty.
63 | """
64 | matches_by_order = [0] * max_order
65 | possible_matches_by_order = [0] * max_order
66 | reference_length = 0
67 | translation_length = 0
68 | for (references, translation) in zip(reference_corpus,
69 | translation_corpus):
70 | reference_length += min(len(r) for r in references)
71 | translation_length += len(translation)
72 |
73 | merged_ref_ngram_counts = collections.Counter()
74 | for reference in references:
75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
76 | translation_ngram_counts = _get_ngrams(translation, max_order)
77 | overlap = translation_ngram_counts & merged_ref_ngram_counts
78 | for ngram in overlap:
79 | matches_by_order[len(ngram)-1] += overlap[ngram]
80 | for order in range(1, max_order+1):
81 | possible_matches = len(translation) - order + 1
82 | if possible_matches > 0:
83 | possible_matches_by_order[order-1] += possible_matches
84 |
85 | precisions = [0] * max_order
86 | for i in range(0, max_order):
87 | if smooth:
88 | precisions[i] = ((matches_by_order[i] + 1.) /
89 | (possible_matches_by_order[i] + 1.))
90 | else:
91 | if possible_matches_by_order[i] > 0:
92 | precisions[i] = (float(matches_by_order[i]) /
93 | possible_matches_by_order[i])
94 | else:
95 | precisions[i] = 0.0
96 |
97 | if min(precisions) > 0:
98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
99 | geo_mean = math.exp(p_log_sum)
100 | else:
101 | geo_mean = 0
102 |
103 | ratio = float(translation_length) / reference_length
104 |
105 | if ratio > 1.0:
106 | bp = 1.
107 | else:
108 | if ratio == 0:
109 | ratio = sys.float_info.epsilon
110 | bp = math.exp(1 - 1. / ratio)
111 |
112 | bleu = geo_mean * bp
113 |
114 | return (bleu, precisions, bp, ratio, translation_length, reference_length)
115 |
--------------------------------------------------------------------------------
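
A quick, self-contained sanity check of `compute_bleu` (toy tokenized sentences, not repository data):

```
from misc.scripts.bleu import compute_bleu

references = [[["the", "cat", "sat", "on", "the", "mat"]]]  # one reference list per translation
translations = [["the", "cat", "sat", "on", "mat"]]

bleu_score, precisions, bp, ratio, trans_len, ref_len = compute_bleu(
    references, translations, max_order=4, smooth=True)
print(round(100 * bleu_score, 2))  # smoothed corpus BLEU on a single pair
```
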
/models/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.framework import ops
3 | from tensorflow.python.ops import array_ops
4 | from tensorflow.python.ops import nn_ops
5 | from tensorflow.python.ops import math_ops
6 |
7 | def sequence_loss(logits, targets, weights,
8 | average_across_timesteps=True, average_across_batch=True,
9 | softmax_loss_function=None, name=None,
10 | max_across_timesteps=False):
11 | """Weighted cross-entropy loss for a sequence of logits (per example).
12 |
13 | Args:
14 | logits: A 3D Tensor of shape
15 | [batch_size x sequence_length x num_decoder_symbols] and dtype float.
16 | The logits correspond to the prediction across all classes at each
17 | timestep.
18 | targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
19 | int. The target represents the true class at each timestep.
20 | weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
21 | float. Weights constitutes the weighting of each prediction in the
22 | sequence. When using weights as masking set all valid timesteps to 1 and
23 | all padded timesteps to 0.
24 | average_across_timesteps: If set, sum the cost across the sequence
25 | dimension and divide the cost by the total label weight across
26 | timesteps.
27 | average_across_batch: If set, sum the cost across the batch dimension and
28 | divide the returned cost by the batch size.
29 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
30 | to be used instead of the standard softmax (the default if this is None).
31 | name: Optional name for this operation, defaults to "sequence_loss".
32 | Returns:
33 | A scalar float Tensor: The average log-perplexity per symbol (weighted).
34 |
35 | Raises:
36 | ValueError: logits does not have 3 dimensions or targets does not have 2
37 | dimensions or weights does not have 2 dimensions.
38 | """
39 | if len(logits.get_shape()) != 3:
40 | raise ValueError("Logits must be a "
41 | "[batch_size x sequence_length x logits] tensor")
42 | if len(targets.get_shape()) != 2:
43 | raise ValueError("Targets must be a [batch_size x sequence_length] "
44 | "tensor")
45 | if len(weights.get_shape()) != 2:
46 | raise ValueError("Weights must be a [batch_size x sequence_length] "
47 | "tensor")
48 | with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
49 | num_classes = array_ops.shape(logits)[2]
50 | probs_flat = array_ops.reshape(logits, [-1, num_classes])
51 | targets = array_ops.reshape(targets, [-1])
52 | if softmax_loss_function is None:
53 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=probs_flat)
54 | else:
55 | crossent = softmax_loss_function(probs_flat, targets)
56 | crossent = crossent * array_ops.reshape(weights, [-1])
57 | if average_across_timesteps and average_across_batch:
58 | crossent = math_ops.reduce_sum(crossent)
59 | total_size = math_ops.reduce_sum(weights)
60 | total_size += 1e-12 # to avoid division by 0 for all-0 weights
61 | crossent /= total_size
62 | else:
63 | batch_size = array_ops.shape(logits)[0]
64 | sequence_length = array_ops.shape(logits)[1]
65 | crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
66 | if average_across_timesteps and not average_across_batch:
67 | crossent = math_ops.reduce_sum(crossent, axis=[1])
68 | total_size = math_ops.reduce_sum(weights, axis=[1])
69 | total_size += 1e-12 # to avoid division by 0 for all-0 weights
70 | crossent /= total_size
71 | if not average_across_timesteps and average_across_batch:
72 | crossent = math_ops.reduce_sum(crossent, axis=[0])
73 | total_size = math_ops.reduce_sum(weights, axis=[0])
74 | total_size += 1e-12 # to avoid division by 0 for all-0 weights
75 | crossent /= total_size
76 | if max_across_timesteps:
77 | crossent = math_ops.reduce_max(crossent, axis=[1])
78 | crossent = math_ops.reduce_mean(crossent, axis=[0])
79 | return crossent
80 |
81 | def hinge_loss(logits, targets, delta):
82 | logits = math_ops.to_float(logits)
83 | targets = math_ops.to_float(targets)
84 | correct_label_scores = math_ops.reduce_sum(math_ops.multiply(logits, 1-targets), axis=-1)
85 | incorrect_label_scores = math_ops.reduce_sum(math_ops.multiply(logits, targets), axis=-1)
86 | incrrect_correct_different = (incorrect_label_scores - correct_label_scores)
87 | target_output = tf.cast(targets[:, -1], dtype=tf.float32)
88 | loss = math_ops.maximum(delta - tf.reduce_sum(math_ops.multiply(incrrect_correct_different, target_output)) / tf.reduce_sum(target_output),
89 | delta - tf.reduce_sum(math_ops.multiply(incrrect_correct_different, (1-target_output))) / tf.reduce_sum(1-target_output)
90 | )
91 | # loss = tf.reduce_mean(math_ops.maximum(0.0, delta - incrrect_correct_different))
92 | return loss
93 |
94 | def cos_dist_loss(emb1, emb2):
95 | normalize_a = tf.nn.l2_normalize(emb1, 1)
96 | normalize_b = tf.nn.l2_normalize(emb2, 1)
97 | cos_distance = 1 - tf.reduce_sum(tf.multiply(normalize_a, normalize_b), axis=-1)
98 | return cos_distance
99 |
100 | def get_device_str(num_gpus, gpu_rellocate=False):
101 | """Return a device string for multi-GPU setup."""
102 | if num_gpus == 0:
103 | return "/cpu:0"
104 | device_str_output = "/gpu:0"
105 | if num_gpus > 1 and gpu_rellocate:
106 | device_str_output = "/gpu:1"
107 | return device_str_output
108 |
109 | def make_cell(rnn_size, device_str, trainable=True):
110 | enc_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, trainable=trainable)
111 | enc_cell = tf.contrib.rnn.DeviceWrapper(enc_cell, device_str)
112 | print(" %s, device=%s" % (type(enc_cell).__name__, device_str))
113 | return enc_cell
--------------------------------------------------------------------------------
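
A minimal TF1 sketch of `sequence_loss` above; the weights act as a padding mask, so only unmasked timesteps contribute to the averaged cross-entropy (toy shapes and values):

```
import numpy as np
import tensorflow as tf
from models.utils import sequence_loss

batch, time, vocab = 2, 4, 10
logits = tf.constant(np.random.rand(batch, time, vocab), dtype=tf.float32)
targets = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32)
weights = tf.constant([[1., 1., 1., 0.], [1., 1., 0., 0.]])  # 1 = real token, 0 = pad

loss = sequence_loss(logits, targets, weights)  # scalar: average loss per unmasked token
with tf.Session() as sess:
    print(sess.run(loss))
```
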
/models/myBertClassifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Bert-based sentiment classification model.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | import tensorflow as tf
23 | from tensorflow.python.layers import core as layers_core
24 | from misc import input_data
25 | from models.bert import modeling
26 |
27 |
28 | def get_device_str(num_gpus):
29 | """Return a device string for multi-GPU setup."""
30 | if num_gpus == 0:
31 | return "/cpu:0"
32 | device_str_output = "/gpu:0"
33 | return device_str_output
34 |
35 |
36 | class BertClassificationModel():
37 | def __init__(self, args, bert_config, mode=None):
38 |
39 | self.mode = mode
40 | self.bidirectional = True if args.enc_type == 'bi' else False
41 | self.args = args
42 | self.batch_size = args.batch_size
43 |
44 | self._make_graph(bert_config)
45 |
46 | def _make_graph(self, bert_config):
47 |
48 | self._init_placeholders()
49 |
50 | self.input_mask = tf.sequence_mask(
51 | tf.to_int32(self.encoder_inputs_length),
52 | tf.reduce_max(self.encoder_inputs_length),
53 | dtype=tf.int32)
54 |
55 | self.segment_ids = tf.sequence_mask(
56 | tf.to_int32(self.encoder_inputs_length),
57 | tf.reduce_max(self.encoder_inputs_length),
58 | dtype=tf.int32)
59 |
60 | self.segment_ids = 0 * self.segment_ids
61 |
62 | old_ = True if ((self.args.test_file is not None and 'yelp' in self.args.test_file) or
63 | (self.args.input_file is not None and 'yelp' in self.args.input_file)) else False
64 |
65 | self.model = modeling.BertModel(
66 | config=bert_config,
67 | is_training=(self.mode == 'Train'),
68 | input_ids=self.encoder_inputs,
69 | input_mask=self.input_mask,
70 | token_type_ids=self.segment_ids,
71 | use_one_hot_embeddings=False,
72 | word_embedding_trainable=(self.mode == 'Train'),
73 | )
74 |
75 | encoder_outputs = self.model.get_pooled_output()
76 |
77 | with tf.variable_scope("classification") as scope:
78 | fc_output = tf.layers.dense(encoder_outputs, 1024, activation=tf.nn.relu)
79 | projection_layer = layers_core.Dense(units=self.args.output_classes, name="projection_layer")
80 | with tf.device(get_device_str(self.args.num_gpus)):
81 | self.logits = tf.nn.tanh(projection_layer(fc_output)) # [batch size, output_classes]
82 |
83 | # if self.mode == "Train":
84 | self._init_optimizer()
85 |
86 | def _init_placeholders(self):
87 | self.encoder_inputs = tf.placeholder(
88 | shape=(None, None),
89 | dtype=tf.int32,
90 | name='encoder_inputs'
91 | )
92 |
93 | self.segment_ids = tf.placeholder(
94 | shape=(None, None),
95 | dtype=tf.int32,
96 | name='segment_ids'
97 | )
98 |
99 | self.encoder_inputs_length = tf.placeholder(
100 | shape=(None,),
101 | dtype=tf.int32,
102 | name='encoder_inputs_length',
103 | )
104 |
105 | self.classification_outputs = tf.placeholder(
106 | shape=(None, None),
107 | dtype=tf.int32,
108 | name='classification_outputs',
109 | )
110 |
111 | def _init_embedding(self):
112 | self.embedding_encoder = input_data._create_pretrained_emb_from_txt(
113 | vocab_file=self.args.vocab_file, embed_file=self.args.emb_file)
114 | self.encoder_embedding_inputs = tf.nn.embedding_lookup(
115 | self.embedding_encoder,
116 | self.encoder_inputs) # [batch size, sequence len, h_dim]
117 |
118 |
119 | def _init_optimizer(self):
120 | if self.args.output_classes > 2:
121 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
122 | logits=self.logits, labels=self.classification_outputs))
123 | else:
124 | self.target_output = tf.cast(self.classification_outputs[:, -1], dtype=tf.int32)
125 | self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
126 | logits=self.logits, labels=self.target_output))
127 | tf.summary.scalar('loss', self.loss)
128 | self.summary_op = tf.summary.merge_all()
129 |
130 | learning_rate = self.args.learning_rate
131 | optimizer = tf.train.AdamOptimizer(learning_rate)
132 | gradients = optimizer.compute_gradients(self.loss)
133 | capped_gradients = [
134 | (tf.clip_by_value(grad, -1.0 * self.args.max_gradient_norm, self.args.max_gradient_norm), var) for grad, var
135 | in gradients if grad is not None]
136 | self.train_op = optimizer.apply_gradients(capped_gradients)
137 |
138 | # inputs and outputs for train/infer
139 |
140 | def make_train_inputs(self, x):
141 | return {
142 | self.encoder_inputs: x[0],
143 | self.classification_outputs: x[1],
144 | self.encoder_inputs_length: x[2]
145 | }
146 |
147 | def embedding_encoder_fn(self):
148 | return self.embedding_encoder
149 |
150 | def get_bert_embedding(self):
151 | return self.model.embedding_table
152 |
153 | def make_train_outputs(self, full_loss_step=True, defence=False):
154 | return [self.train_op, self.loss, self.logits, self.summary_op]
155 |
156 | def make_eval_outputs(self):
157 | return self.loss
158 |
159 | def make_test_outputs(self):
160 | return [self.loss, self.logits, self.encoder_inputs, self.classification_outputs]
--------------------------------------------------------------------------------
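
A hypothetical sketch of one training step with `BertClassificationModel`, assuming `model` has already been built with mode='Train' and `sess` is an initialized TF1 session; the WordPiece ids and one-hot labels below are toy values:

```
from misc.utils import prepare_batch

token_ids = [[101, 2023, 2003, 2307, 102], [101, 2919, 3185, 102]]  # toy WordPiece ids
labels = [[0, 1], [1, 0]]                                           # one-hot sentiment labels

inputs, lengths = prepare_batch(token_ids)
feed = model.make_train_inputs([inputs, labels, lengths])
# make_train_outputs() returns [train_op, loss, logits, summary_op]
_, loss, logits, summary = sess.run(model.make_train_outputs(), feed_dict=feed)
```
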
/models/myCopyDecoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A class of Decoders that may sample to generate the next input.
16 | """
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from tensorflow.contrib.seq2seq.python.ops import decoder
25 | from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
26 | from tensorflow.python.framework import ops
27 | from tensorflow.python.framework import tensor_shape
28 | from tensorflow.python.layers import base as layers_base
29 | from tensorflow.python.ops import rnn_cell_impl
30 | from tensorflow.python.util import nest
31 | import tensorflow as tf
32 |
33 | __all__ = [
34 | "BasicDecoderOutput",
35 | "CopyDecoder",
36 | ]
37 |
38 |
39 | class BasicDecoderOutput(
40 | collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
41 | pass
42 |
43 |
44 | class CopyDecoder(decoder.Decoder):
45 | """Basic sampling decoder."""
46 |
47 | def __init__(self, cell, helper, initial_state, copy_mask, encoder_input_ids, vocab_size, output_layer=None):
48 | """Initialize BasicDecoder.
49 |
50 | Args:
51 | cell: An `RNNCell` instance.
52 | helper: A `Helper` instance.
53 | initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
54 | The initial state of the RNNCell.
55 | output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
56 | `tf.layers.Dense`. Optional layer to apply to the RNN output prior
57 | to storing the result or sampling.
58 |
59 | Raises:
60 | TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
61 | """
62 | rnn_cell_impl.assert_like_rnncell("cell", cell)
63 | if not isinstance(helper, helper_py.Helper):
64 | raise TypeError("helper must be a Helper, received: %s" % type(helper))
65 | if (output_layer is not None
66 | and not isinstance(output_layer, layers_base.Layer)):
67 | raise TypeError(
68 | "output_layer must be a Layer, received: %s" % type(output_layer))
69 | self._cell = cell
70 | self._helper = helper
71 | self._initial_state = initial_state
72 | self._output_layer = output_layer
73 | self._copy_mask = copy_mask
74 | self._encoder_input_ids = encoder_input_ids
75 | self._vocab_size = vocab_size
76 |
77 | @property
78 | def batch_size(self):
79 | return self._helper.batch_size
80 |
81 | def _rnn_output_size(self):
82 | size = self._cell.output_size
83 | if self._output_layer is None:
84 | return size
85 | else:
86 | # To use layer's compute_output_shape, we need to convert the
87 | # RNNCell's output_size entries into shapes with an unknown
88 | # batch size. We then pass this through the layer's
89 | # compute_output_shape and read off all but the first (batch)
90 | # dimensions to get the output size of the rnn with the layer
91 | # applied to the top.
92 | output_shape_with_unknown_batch = nest.map_structure(
93 | lambda s: tensor_shape.TensorShape([None]).concatenate(s),
94 | size)
95 | layer_output_shape = self._output_layer.compute_output_shape(
96 | output_shape_with_unknown_batch)
97 | return nest.map_structure(lambda s: s[1:], layer_output_shape)
98 |
99 | @property
100 | def output_size(self):
101 | # Return the cell output and the id
102 | return BasicDecoderOutput(
103 | rnn_output=self._rnn_output_size(),
104 | sample_id=self._helper.sample_ids_shape)
105 |
106 | @property
107 | def output_dtype(self):
108 | # Assume the dtype of the cell is the output_size structure
109 | # containing the input_state's first component's dtype.
110 | # Return that structure and the sample_ids_dtype from the helper.
111 | dtype = nest.flatten(self._initial_state)[0].dtype
112 | return BasicDecoderOutput(
113 | nest.map_structure(lambda _: dtype, self._rnn_output_size()),
114 | self._helper.sample_ids_dtype)
115 |
116 | def initialize(self, name=None):
117 | """Initialize the decoder.
118 |
119 | Args:
120 | name: Name scope for any created operations.
121 |
122 | Returns:
123 | `(finished, first_inputs, initial_state)`.
124 | """
125 | return self._helper.initialize() + (self._initial_state,)
126 |
127 | def step(self, time, inputs, state, name=None):
128 | """Perform a decoding step.
129 |
130 | Args:
131 | time: scalar `int32` tensor.
132 | inputs: A (structure of) input tensors.
133 | state: A (structure of) state tensors and TensorArrays.
134 | name: Name scope for any created operations.
135 |
136 | Returns:
137 | `(outputs, next_state, next_inputs, finished)`.
138 | """
139 | with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
140 | cell_outputs, cell_state = self._cell(inputs, state)
141 |
142 | generate_score = self._output_layer(cell_outputs)
143 |
144 | copy_mask_cur = self._copy_mask[:, time] # [batch_size, ?]
145 | ids_cur = self._encoder_input_ids[:, time]
146 |
147 | prob_one_hot = tf.one_hot(ids_cur, self._vocab_size) * 1e7 # [batch_size, vocab_size]
148 | prob_c_one_hot = prob_one_hot * tf.expand_dims(copy_mask_cur, -1)
149 |
150 | cell_outputs = prob_c_one_hot + generate_score
151 |
152 |
153 | sample_ids = self._helper.sample(
154 | time=time, outputs=cell_outputs, state=cell_state)
155 | (finished, next_inputs, next_state) = self._helper.next_inputs(
156 | time=time,
157 | outputs=cell_outputs,
158 | state=cell_state,
159 | sample_ids=sample_ids)
160 | outputs = BasicDecoderOutput(cell_outputs, sample_ids)
161 | return (outputs, next_state, next_inputs, finished)
--------------------------------------------------------------------------------
/models/bert/config.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import models.bert.tokenization as tokenization
3 |
4 | flags = tf.flags
5 |
6 | FLAGS = flags.FLAGS
7 |
8 | ## Required parameters
9 | flags.DEFINE_string(
10 | "bert_config_file", None,
11 | "The config json file corresponding to the pre-trained BERT model. "
12 | "This specifies the model architecture.")
13 |
14 | flags.DEFINE_string("vocab_file", None,
15 | "The vocabulary file that the BERT model was trained on.")
16 |
17 | flags.DEFINE_string(
18 | "output_dir", None,
19 | "The output directory where the model checkpoints will be written.")
20 |
21 | ## Other parameters
22 | flags.DEFINE_string("train_file", None,
23 | "Json for training. E.g., train-v1.1.json")
24 |
25 | flags.DEFINE_string("eval_file", None, "Json for validation, e.g. squad.biomedical.dev.json")
26 |
27 | flags.DEFINE_string(
28 | "predict_file", None,
29 | "Json for predictions. E.g., test-v1.1.json")
30 |
31 | flags.DEFINE_string("squad_dev_file", None, "SQuAD json for validation. E.g., test-v1.1.json")
32 |
33 | flags.DEFINE_string(
34 | "init_checkpoint", None,
35 | "Initial checkpoint (usually from a pre-trained BERT model).")
36 |
37 | flags.DEFINE_bool(
38 | "do_lower_case", True,
39 | "Whether to lower case the input text. Should be True for uncased "
40 | "models and False for cased models.")
41 |
42 | flags.DEFINE_integer(
43 | "max_seq_length", 384,
44 | "The maximum total input sequence length after WordPiece tokenization. "
45 | "Sequences longer than this will be truncated, and sequences shorter "
46 | "than this will be padded.")
47 |
48 | flags.DEFINE_integer(
49 | "doc_stride", 128,
50 | "When splitting up a long document into chunks, how much stride to "
51 | "take between chunks.")
52 |
53 | flags.DEFINE_integer(
54 | "max_query_length", 64,
55 | "The maximum number of tokens for the question. Questions longer than "
56 | "this will be truncated to this length.")
57 |
58 | flags.DEFINE_bool("do_train", True, "Whether to run training.")
59 |
60 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
61 |
62 | flags.DEFINE_bool("do_fisher", False, "Whether to run eval on the dev set.")
63 |
64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
65 |
66 | flags.DEFINE_integer("predict_batch_size", 8,
67 | "Total batch size for predictions.")
68 |
69 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
70 |
71 | flags.DEFINE_float("num_train_epochs", 3.0,
72 | "Total number of training epochs to perform.")
73 |
74 | flags.DEFINE_float(
75 | "warmup_proportion", 0.1,
76 | "Proportion of training to perform linear learning rate warmup for. "
77 | "E.g., 0.1 = 10% of training.")
78 |
79 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
80 | "How often to save the model checkpoint.")
81 |
82 | flags.DEFINE_integer("iterations_per_loop", 1000,
83 | "How many steps to make in each estimator call.")
84 |
85 | flags.DEFINE_integer(
86 | "n_best_size", 20,
87 | "The total number of n-best predictions to generate in the "
88 | "nbest_predictions.json output file.")
89 |
90 | flags.DEFINE_integer(
91 | "max_answer_length", 30,
92 | "The maximum length of an answer that can be generated. This is needed "
93 | "because the start and end predictions are not conditioned on one another.")
94 |
95 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
96 |
97 | tf.flags.DEFINE_string(
98 | "tpu_name", None,
99 | "The Cloud TPU to use for training. This should be either the name "
100 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
101 | "url.")
102 |
103 | tf.flags.DEFINE_string(
104 | "tpu_zone", None,
105 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
106 | "specified, we will attempt to automatically detect the GCE project from "
107 | "metadata.")
108 |
109 | tf.flags.DEFINE_string(
110 | "gcp_project", None,
111 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
112 | "specified, we will attempt to automatically detect the GCE project from "
113 | "metadata.")
114 |
115 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
116 |
117 | flags.DEFINE_integer(
118 | "num_tpu_cores", 8,
119 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
120 |
121 | flags.DEFINE_bool(
122 | "verbose_logging", False,
123 | "If true, all of the warnings related to data processing will be printed. "
124 | "A number of warnings are expected for a normal SQuAD evaluation.")
125 |
126 | flags.DEFINE_bool(
127 | "version_2_with_negative", False,
128 | "If true, the SQuAD examples contain some that do not have an answer.")
129 |
130 | flags.DEFINE_float(
131 | "null_score_diff_threshold", 0.0,
132 | "If null_score - best_non_null is greater than the threshold predict null.")
133 |
134 | flags.DEFINE_integer("eval_per_step", 1000, "Steps per evaluation")
135 |
136 | flags.DEFINE_bool("word_embedding_trainable", False, "Whether make bert word embedding trainable")
137 |
138 |
139 | ## added flags for EWC
140 | tf.flags.DEFINE_bool("adapt", False, " domain adaptation or not")
141 | tf.flags.DEFINE_string("ewc_filename", None, " file for save ewc fisher matrix")
142 | tf.flags.DEFINE_bool("l2_norm", False, "l2 norm or not")
143 | tf.flags.DEFINE_float("ewc_lambda", "0.0", "")
144 | tf.flags.DEFINE_float("pcd_gamma", "0.0", "")
145 | tf.flags.DEFINE_float("pl2_sigma", "0.0", "")
146 |
147 | ## added flags for GEM
148 | tf.flags.DEFINE_bool("gem", False, " set to use Gradient Episodic Memory")
149 | tf.flags.DEFINE_string("squad_train_file", None, " SQuAD training file")
150 |
151 |
152 |
153 | def validate_flags_or_throw(bert_config):
154 | """Validate the input FLAGS or throw an exception."""
155 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
156 | FLAGS.init_checkpoint)
157 | if FLAGS.do_train:
158 | if not FLAGS.train_file:
159 | raise ValueError(
160 | "If `do_train` is True, then `train_file` must be specified.")
161 | if FLAGS.do_predict:
162 | if not FLAGS.predict_file:
163 | raise ValueError(
164 | "If `do_predict` is True, then `predict_file` must be specified.")
165 |
166 | if FLAGS.max_seq_length > bert_config.max_position_embeddings:
167 | raise ValueError(
168 | "Cannot use sequence length %d because the BERT model "
169 | "was only trained up to sequence length %d" %
170 | (FLAGS.max_seq_length, bert_config.max_position_embeddings))
171 |
172 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
173 | raise ValueError(
174 | "The max_seq_length (%d) must be greater than max_query_length "
175 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))
--------------------------------------------------------------------------------
/misc/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Evaluation metrics.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | import re
23 | from misc.scripts import bleu, rouge
24 | from misc import acc_transformer
25 |
26 |
27 | ##################### Sequence Reconstruction evaluation scores ####################
28 |
29 | # from acceptability import test
30 | def _clean(sentence, subword_option):
31 | """Clean and handle BPE or SPM outputs."""
32 | sentence = sentence.strip()
33 |
34 | # BPE
35 | if subword_option == "bpe":
36 | sentence = re.sub("@@ ", "", sentence)
37 |
38 | # SPM
39 | elif subword_option == "spm":
40 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip()
41 |
42 | return sentence
43 |
44 |
45 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py
46 | def _bleu(references, translations, subword_option=None):
47 | """Compute BLEU scores and handling BPE."""
48 | max_order = 4
49 | smooth = False
50 | references_bleu = [[reference] for reference in references]
51 | # bleu_score, precisions, bp, ratio, translation_length, reference_length
52 | bleu_score, _, _, _, _, _ = bleu.compute_bleu(
53 | references_bleu, translations, max_order, smooth)
54 | return 100 * bleu_score
55 |
56 |
57 | def _rouge(references, translations, subword_option=None):
58 | """Compute ROUGE scores and handling BPE."""
59 | translations_sent = [' '.join(translation) for translation in translations]
60 | references_sent = [' '.join(reference) for reference in references]
61 | rouge_score_map = rouge.rouge(translations_sent, references_sent)
62 | return 100 * rouge_score_map["rouge_l/f_score"]
63 |
64 |
65 | def _accuracy(references, translations):
66 | """Compute accuracy, each line contains a label."""
67 |
68 | count = 0.0
69 | match = 0.0
70 | for ind, label in enumerate(references):
71 | label_sentence = ' '.join(label)
72 | pred_sentence = ' '.join(translations[ind])
73 | if label_sentence == pred_sentence:
74 | match += 1
75 | count += 1
76 | return 100 * match / count
77 |
78 |
79 | def _word_accuracy(references, translations):
80 | """Compute accuracy on per word basis."""
81 |
82 | total_acc, total_count = 0., 0.
83 | for ind, reference in enumerate(references):
84 | translation = translations[ind]
85 | match = 0.0
86 | for pos in range(min(len(reference), len(translation))):
87 | label = reference[pos]
88 | pred = translation[pos]
89 | if label == pred:
90 | match += 1
91 | total_acc += 100 * match / max(len(reference), len(translation))
92 | total_count += 1
93 | return total_acc / total_count
94 |
95 |
96 | ##################### Classification evaluation scores ####################
97 | def max_index(arr):
98 | max_v, max_p = -1, -1
99 | for ind, a in enumerate(arr):
100 | a = float(a)
101 | if a > max_v:
102 | max_v = a
103 | max_p = ind
104 | return max_p
105 |
106 | def _clss_accuracy(labels, predicts):
107 | """Compute accuracy for classification"""
108 | total_count = 0.
109 | match = 0.0
110 | for ind, label in enumerate(labels):
111 | max_lab_index = max_index(label)
112 | max_pred_index = max_index(predicts[ind])
113 | if max_pred_index == max_lab_index:
114 | match += 1
115 | total_count += 1
116 | return 100.0 * match / total_count
117 |
118 | def _clss_accuracy_micro(labels, predicts, orig_label=1):
119 | """Compute accuracy for classification"""
120 | total_count = 0.
121 | match = 0.0
122 | for ind, label in enumerate(labels):
123 | max_lab_index = max_index(label)
124 | if max_lab_index == orig_label:
125 | max_pred_index = max_index(predicts[ind])
126 | if max_pred_index == max_lab_index:
127 | match += 1
128 | total_count += 1
129 | return 100.0 * match / total_count
130 |
131 | import numpy as np
132 | from sklearn import metrics
133 | def _clss_auc(labels, predicts):
134 | """c Compute auc for classification"""
135 | fpr, tpr, thresholds = metrics.roc_curve(np.array(labels)[:, 1], np.array(predicts)[:, 1], pos_label=1.0)
136 | auc = metrics.auc(fpr, tpr)
137 | return 100.0 * auc
138 |
139 |
140 |
141 |
142 | ##################### EMB similarity score ##################
143 | def gen_mask(len_lists, max_len):
144 | ret = []
145 | for a in len_lists:
146 | mask_array = np.zeros(max_len, dtype=int)
147 | mask_array[np.arange(a)] = 1
148 | ret.append(mask_array)
149 | return np.array(ret)
150 |
151 | from sklearn.metrics.pairwise import cosine_similarity
152 | def emb_cosine_dist(emb1, emb2, emb_len1, emb_len2):
153 | mask1 = gen_mask(emb_len1, len(emb1[0]))
154 | mask1 = np.expand_dims(mask1, axis=-1)
155 | mask2 = gen_mask(emb_len2, len(emb2[0]))
156 | mask2 = np.expand_dims(mask2, axis=-1)
157 | emb1 = np.multiply(emb1, mask1)
158 | emb2 = np.multiply(emb2, mask2)
159 | emb1 = np.sum(emb1, axis=1) / np.expand_dims(emb_len1, axis=-1)
160 | emb2 = np.sum(emb2, axis=1) / np.expand_dims(emb_len2, axis=-1)
161 | scores = []
162 | for ind, emb in enumerate(emb1):
163 | scores.append(cosine_similarity([emb], [emb2[ind]])[0][0])
164 | return scores
165 |
166 | ##################### ACPT score ##################
167 | def _accept_score(references, translations, args, lim=100):
168 | if args.accept_name is None:
169 | return 0.0
170 | references_sent = [' '.join(reference) for reference in references[:lim]]
171 | translations_sent = [' '.join(translation) for translation in translations[:lim]]
172 | scores = acc_transformer.evaluate_accept(references_sent, translations_sent, args)
173 | return sum(scores)/len(scores)
174 |
175 |
176 | ##################### USE score ##################
177 | def _use_scores(ref_changed, trans_changed, use_model, eval_num=200):
178 | if use_model is None:
179 | return 0.0
180 | cnt = 0
181 | sim_scores = []
182 | while cnt < min(len(ref_changed), eval_num):
183 | batch_src = ref_changed[cnt: cnt + 32]
184 | batch_tgt = trans_changed[cnt: cnt + 32]
185 | src_sent = [' '.join(a) for a in batch_src]
186 | tgt_sent = [' '.join(a) for a in batch_tgt]
187 | scores = use_model.semantic_sim(src_sent, tgt_sent)
188 | sim_scores.extend(scores)
189 | cnt += 32
190 | return sum(sim_scores) / len(sim_scores)
191 |
192 |
193 |
194 |
--------------------------------------------------------------------------------
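
A toy example of the reconstruction and classification metrics above (made-up token lists; `_rouge` relies on misc/scripts/rouge.py):

```
from misc.evaluate import _bleu, _rouge, _word_accuracy, _clss_accuracy

references = [["the", "service", "was", "great"]]
translations = [["the", "service", "was", "good"]]

print(_bleu(references, translations))            # corpus BLEU x 100
print(_rouge(references, translations))           # ROUGE-L F-score x 100
print(_word_accuracy(references, translations))   # per-position accuracy

labels = [[0, 1], [1, 0]]             # gold one-hot labels
predicts = [[0.2, 0.8], [0.9, 0.1]]   # classifier scores
print(_clss_accuracy(labels, predicts))
```
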
/models/bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def get_optimizer(init_lr, num_train_steps, num_warmup_steps, use_tpu=False):
26 | global_step = tf.train.get_or_create_global_step()
27 |
28 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
29 |
30 | # Implements linear decay of the learning rate.
31 | learning_rate = tf.train.polynomial_decay(
32 | learning_rate,
33 | global_step,
34 | num_train_steps,
35 | end_learning_rate=0.0,
36 | power=1.0,
37 | cycle=False)
38 |
39 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
40 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
41 | if num_warmup_steps:
42 | global_steps_int = tf.cast(global_step, tf.int32)
43 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
44 |
45 | global_steps_float = tf.cast(global_steps_int, tf.float32)
46 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
47 |
48 | warmup_percent_done = global_steps_float / warmup_steps_float
49 | warmup_learning_rate = init_lr * warmup_percent_done
50 |
51 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
52 | learning_rate = (
53 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
54 |
55 | # It is recommended that you use this optimizer for fine tuning, since this
56 | # is how the model was trained (note that the Adam m/v variables are NOT
57 | # loaded from init_checkpoint.)
58 | optimizer = AdamWeightDecayOptimizer(
59 | learning_rate=learning_rate,
60 | weight_decay_rate=0.01,
61 | beta_1=0.9,
62 | beta_2=0.999,
63 | epsilon=1e-6,
64 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
65 |
66 | if use_tpu:
67 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
68 | return optimizer
69 |
70 | def create_optimizer(tvars, grads, optimizer):
71 | """Creates an optimizer training op."""
72 |
73 | global_step = tf.train.get_or_create_global_step()
74 |
75 | # This is how the model was pre-trained.
76 | (capped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
77 |
78 | train_op = optimizer.apply_gradients(
79 | zip(capped_grads, tvars), global_step=global_step)
80 |
81 | # Normally the global step update is done inside of `apply_gradients`.
82 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
83 | # a different optimizer, you should probably take this line out.
84 | new_global_step = global_step + 1
85 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
86 | return train_op
87 |
88 |
89 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
90 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
91 |
92 | def __init__(self,
93 | learning_rate,
94 | weight_decay_rate=0.0,
95 | beta_1=0.9,
96 | beta_2=0.999,
97 | epsilon=1e-6,
98 | exclude_from_weight_decay=None,
99 | name="AdamWeightDecayOptimizer"):
100 | """Constructs a AdamWeightDecayOptimizer."""
101 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
102 |
103 | self.learning_rate = learning_rate
104 | self.weight_decay_rate = weight_decay_rate
105 | self.beta_1 = beta_1
106 | self.beta_2 = beta_2
107 | self.epsilon = epsilon
108 | self.exclude_from_weight_decay = exclude_from_weight_decay
109 |
110 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
111 | """See base class."""
112 | assignments = []
113 | for (grad, param) in grads_and_vars:
114 | if grad is None or param is None:
115 | continue
116 |
117 | param_name = self._get_variable_name(param.name)
118 |
119 | with tf.variable_scope("AdamWeightDecayOptimizer_adam", reuse=tf.AUTO_REUSE):
120 | m = tf.get_variable(
121 | name=param_name + "/adam_m",
122 | shape=param.shape.as_list(),
123 | dtype=tf.float32,
124 | trainable=False,
125 | initializer=tf.zeros_initializer())
126 | v = tf.get_variable(
127 | name=param_name + "/adam_v",
128 | shape=param.shape.as_list(),
129 | dtype=tf.float32,
130 | trainable=False,
131 | initializer=tf.zeros_initializer())
132 |
133 | # Standard Adam update.
134 | next_m = (
135 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
136 | next_v = (
137 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
138 | tf.square(grad)))
139 |
140 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
141 |
142 | # Just adding the square of the weights to the loss function is *not*
143 | # the correct way of using L2 regularization/weight decay with Adam,
144 | # since that will interact with the m and v parameters in strange ways.
145 | #
146 |       # Instead we want to decay the weights in a manner that doesn't interact
147 | # with the m/v parameters. This is equivalent to adding the square
148 | # of the weights to the loss with plain (non-momentum) SGD.
149 | if self._do_use_weight_decay(param_name):
150 | update += self.weight_decay_rate * param
151 |
152 | update_with_lr = self.learning_rate * update
153 |
154 | next_param = param - update_with_lr
155 |
156 | assignments.extend(
157 | [param.assign(next_param),
158 | m.assign(next_m),
159 | v.assign(next_v)])
160 | return tf.group(*assignments, name=name)
161 |
162 | def _do_use_weight_decay(self, param_name):
163 | """Whether to use L2 weight decay for `param_name`."""
164 | if not self.weight_decay_rate:
165 | return False
166 | if self.exclude_from_weight_decay:
167 | for r in self.exclude_from_weight_decay:
168 | if re.search(r, param_name) is not None:
169 | return False
170 | return True
171 |
172 | def _get_variable_name(self, param_name):
173 | """Get the variable name from the tensor name."""
174 | m = re.match("^(.*):\\d+$", param_name)
175 | if m is not None:
176 | param_name = m.group(1)
177 | return param_name
178 |
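179 | # Usage sketch (illustrative only): wiring the two helpers above together,
180 | # assuming `loss` is an existing scalar loss tensor. With the settings below,
181 | # warmup scales the rate linearly, e.g. at global_step=100 the effective rate
182 | # is 100/1000 * 5e-5 = 5e-6; afterwards the polynomial (linear) decay to 0
183 | # over num_train_steps takes over.
184 | #
185 | #   optimizer = get_optimizer(init_lr=5e-5, num_train_steps=10000,
186 | #                             num_warmup_steps=1000)
187 | #   tvars = tf.trainable_variables()
188 | #   grads = tf.gradients(loss, tvars)
189 | #   train_op = create_optimizer(tvars, grads, optimizer)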
--------------------------------------------------------------------------------
/models/myCNNClassifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | CNN-based sentiment classification model.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | import tensorflow as tf
23 | from tensorflow.python.layers import core as layers_core
24 | from misc import input_data
25 | import numpy as np
26 | import models.utils as utils
27 |
28 |
29 | class CNNClassificationModel():
30 | def __init__(self, args, mode=None):
31 |
32 | self.mode = mode
33 | self.bidirectional = True if args.enc_type == 'bi' else False
34 | self.args = args
35 | self.batch_size = args.batch_size
36 |
37 | self._init_placeholders()
38 |
39 | self.encoder_outputs, self.cls_logits, self.acc, _ = self._make_graph()
40 |
41 | # if self.mode == "Train":
42 | self._init_optimizer()
43 |
44 | def _make_graph(self, encoder_embedding_inputs=None):
45 |
46 | with tf.variable_scope("my_classifier", reuse=tf.AUTO_REUSE) as scope:
47 | if encoder_embedding_inputs is None:
48 | self._init_embedding()
49 | encoder_outputs = self._init_encoder(encoder_embedding_inputs=(self.encoder_embedding_inputs
50 | if encoder_embedding_inputs is None else
51 | encoder_embedding_inputs))
52 |
53 | with tf.variable_scope("classification", reuse=tf.AUTO_REUSE) as scope:
54 | output_flatten = tf.reduce_mean(encoder_outputs, axis=1)
55 | fc_output = tf.layers.dense(output_flatten, 1024, activation=tf.nn.relu)
56 | projection_layer = layers_core.Dense(units=self.args.output_classes, name="projection_layer")
57 | with tf.device(utils.get_device_str(self.args.num_gpus)):
58 | logits = tf.nn.tanh(projection_layer(fc_output)) # [batch size, output_classes]
59 | ybar = tf.argmax(logits, axis=1, output_type=tf.int32)
60 | ylabel = tf.cast(self.classification_outputs[:, -1], dtype=tf.int32)
61 | count = tf.equal(ylabel, ybar)
62 | acc = tf.reduce_mean(tf.cast(count, tf.float32), name='acc')
63 | return encoder_outputs, logits, acc, None
64 |
65 | def _init_placeholders(self):
66 | self.encoder_inputs = tf.placeholder(
67 | shape=(None, None),
68 | dtype=tf.int32,
69 | name='encoder_inputs'
70 | )
71 |
72 | self.encoder_inputs_length = tf.placeholder(
73 | shape=(None,),
74 | dtype=tf.int32,
75 | name='encoder_inputs_length',
76 | )
77 |
78 | self.classification_outputs = tf.placeholder(
79 | shape=(None, None),
80 | dtype=tf.int32,
81 | name='classification_outputs',
82 | )
83 |
84 | def _init_embedding(self):
85 | self.embedding_encoder = input_data._create_pretrained_emb_from_txt(
86 | vocab_file=self.args.vocab_file, embed_file=self.args.emb_file)
87 | self.encoder_embedding_inputs = tf.nn.embedding_lookup(
88 | self.embedding_encoder,
89 | self.encoder_inputs) # [batch size, sequence len, h_dim]
90 |
91 | def _init_encoder(self, encoder_embedding_inputs=None, initializer=tf.random_normal_initializer(stddev=0.1)):
92 | with tf.variable_scope("Encoder", reuse=tf.AUTO_REUSE) as scope:
93 | with tf.device(utils.get_device_str(self.args.num_gpus)):
94 | for i, filter_size in enumerate(self.args.filter_sizes):
95 | filter = tf.get_variable("filter_"+str(filter_size), [filter_size, 300 if i==0 else self.args.enc_num_units,
96 | self.args.enc_num_units],
97 | initializer=initializer)
98 | conv = tf.nn.conv1d(encoder_embedding_inputs if i==0 else h, filter, stride=2, padding="VALID")
99 | conv = tf.contrib.layers.batch_norm(conv, is_training=True if (self.mode=='Train') else False, scope='cnn_bn_'+str(filter_size))
100 | b = tf.get_variable("b_"+str(filter_size), [self.args.enc_num_units])
101 | h = tf.nn.relu(tf.nn.bias_add(conv, b), "relu")
102 |
103 | with tf.name_scope("dropout"):
104 | h_drop = tf.nn.dropout(h, keep_prob=self.args.dropout_keep_prob)
105 | return h_drop
106 |
107 | def _init_optimizer(self):
108 | if self.args.output_classes > 2:
109 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
110 | logits=self.cls_logits, labels=self.classification_outputs))
111 | else:
112 | self.target_output = tf.cast(self.classification_outputs[:, -1], dtype=tf.int32)
113 | self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
114 | logits=self.cls_logits, labels=self.target_output))
115 | tf.summary.scalar('loss', self.loss)
116 | self.summary_op = tf.summary.merge_all()
117 |
118 | learning_rate = self.args.learning_rate
119 | optimizer = tf.train.AdamOptimizer(learning_rate)
120 | gradients = optimizer.compute_gradients(self.loss)
121 | capped_gradients = [
122 | (tf.clip_by_value(grad, -1.0 * self.args.max_gradient_norm, self.args.max_gradient_norm), var) for grad, var
123 | in gradients if grad is not None]
124 | self.train_op = optimizer.apply_gradients(capped_gradients)
125 |
126 | # inputs and outputs for train/infer
127 |
128 | def make_train_inputs(self, x, X_data=None):
129 | x_input = x[0]
130 | if X_data is not None:
131 | x_input = X_data
132 | if len(self.args.def_train_set) > 0:
133 | x_input_def = x[-2]
134 | x_input = np.concatenate([x_input, x_input_def], axis=0)
135 | y_input = np.concatenate([x[1], x[1]], axis=0)
136 | x_lenghts = np.concatenate([x[2], x[-1]], axis=0)
137 | else:
138 | y_input = x[1]
139 | x_lenghts = x[2]
140 | return {
141 | self.encoder_inputs: x_input,
142 | self.classification_outputs: y_input,
143 | self.encoder_inputs_length: x_lenghts
144 | }
145 |
146 | def embedding_encoder_fn(self):
147 | return self.embedding_encoder
148 |
149 | def make_train_outputs(self, full_loss_step=True, defence=False):
150 | return [self.train_op, self.loss, self.cls_logits, self.summary_op]
151 |
152 | def make_eval_outputs(self):
153 | return self.loss
154 |
155 | def make_test_outputs(self):
156 | return [self.loss, self.cls_logits, self.acc, self.encoder_inputs, self.classification_outputs]
157 |
158 | def make_encoder_output(self):
159 | return self.encoder_outputs
--------------------------------------------------------------------------------
/train_adv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source /Users/yxu132/pyflow3.6/bin/activate
4 | DATA_DIR=data
5 | MODEL_DIR=saved_models
6 |
7 | # AE+bal
8 | python train.py --do_train \
9 | --vocab_file=$DATA_DIR/vocab.in \
10 | --emb_file=$DATA_DIR/emb.json \
11 | --input_file=$DATA_DIR/train.in \
12 | --output_file=$DATA_DIR/train.out \
13 | --dev_file=$DATA_DIR/dev.in \
14 | --dev_output=$DATA_DIR/dev.out \
15 | --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
16 | --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att \
17 | --adv --classification_model=RNN --output_classes=2 --balance \
18 | --gumbel_softmax_temporature=0.1 \
19 | --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
20 | --cls_attention --cls_attention_size=50 --attention \
21 | --batch_size=16 --max_len=50 \
22 | --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
23 | --learning_rate=0.0001 --ae_lambda=0.2 --seq_lambda=0.7 \
24 | --output_dir=$MODEL_DIR/adv_train_bal \
25 | --save_checkpoints \
26 | --num_gpus=0
27 |
28 |
29 | ## AE+LS
30 | #python train.py --do_train \
31 | # --vocab_file=$DATA_DIR/vocab.in \
32 | # --emb_file=$DATA_DIR/emb.json \
33 | # --input_file=$DATA_DIR/train.in \
34 | # --output_file=$DATA_DIR/train.out \
35 | # --dev_file=$DATA_DIR/dev.in \
36 | # --dev_output=$DATA_DIR/dev.out \
37 | # --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
38 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att \
39 | # --adv --classification_model=RNN --output_classes=2 \
40 | # --gumbel_softmax_temporature=0.1 \
41 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
42 | # --cls_attention --cls_attention_size=50 --attention \
43 | # --batch_size=16 --max_len=50 \
44 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
45 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
46 | # --output_dir=$MODEL_DIR/adv_train_ls \
47 | # --save_checkpoints \
48 | # --num_gpus=0 \
49 | # --label_beta=0.95
50 | #
51 | #
52 | #
53 | ## AE+LS+GAN
54 | #python train.py --do_train \
55 | # --vocab_file=$DATA_DIR/vocab.in \
56 | # --emb_file=$DATA_DIR/emb.json \
57 | # --input_file=$DATA_DIR/train.in \
58 | # --output_file=$DATA_DIR/train.out \
59 | # --dev_file=$DATA_DIR/dev.in \
60 | # --dev_output=$DATA_DIR/dev.out \
61 | # --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
62 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att \
63 | # --adv --classification_model=RNN --output_classes=2 \
64 | # --gumbel_softmax_temporature=0.1 \
65 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
66 | # --cls_attention --cls_attention_size=50 --attention \
67 | # --batch_size=16 --max_len=50 \
68 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
69 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
70 | # --output_dir=$MODEL_DIR/adv_train_lsgan \
71 | # --save_checkpoints \
72 | # --num_gpus=0 \
73 | # --label_beta=0.95 \
74 | # --gan --at_steps=2
75 | #
76 | #
77 | ## AE+LS+CF
78 | #python train.py --do_train \
79 | # --vocab_file=$DATA_DIR/vocab.in \
80 | # --emb_file=$DATA_DIR/emb.json \
81 | # --input_file=$DATA_DIR/train.in \
82 | # --output_file=$DATA_DIR/train.out \
83 | # --dev_file=$DATA_DIR/dev.in \
84 | # --dev_output=$DATA_DIR/dev.out \
85 | # --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
86 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att_cf_fixed \
87 | # --adv --classification_model=RNN --output_classes=2 \
88 | # --gumbel_softmax_temporature=0.1 \
89 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
90 | # --cls_attention --cls_attention_size=50 --attention \
91 | # --batch_size=16 --max_len=50 \
92 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
93 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
94 | # --output_dir=$MODEL_DIR/adv_train_lscf \
95 | # --save_checkpoints \
96 | # --num_gpus=0 \
97 | # --label_beta=0.95 \
98 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
99 | # --ae_emb_file=$DATA_DIR/cf_emb.json
100 | #
101 | ## AE+LS+CF+CPY
102 | #python train.py --do_train \
103 | # --vocab_file=$DATA_DIR/vocab.in \
104 | # --emb_file=$DATA_DIR/emb.json \
105 | # --input_file=$DATA_DIR/train.in \
106 | # --output_file=$DATA_DIR/train.out \
107 | # --dev_file=$DATA_DIR/dev.in \
108 | # --dev_output=$DATA_DIR/dev.out \
109 | # --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
110 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att_cf_fixed \
111 | # --adv --classification_model=RNN --output_classes=2 \
112 | # --gumbel_softmax_temporature=0.1 \
113 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
114 | # --cls_attention --cls_attention_size=50 --attention \
115 | # --batch_size=16 --max_len=50 \
116 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
117 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
118 | # --output_dir=$MODEL_DIR/adv_train_lscfcp \
119 | # --save_checkpoints \
120 | # --num_gpus=0 \
121 | # --label_beta=0.95 \
122 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
123 | # --ae_emb_file=$DATA_DIR/cf_emb.json \
124 | # --copy --attention_copy_mask --use_stop_words --top_k_attack=9
125 | #
126 | ## Conditional PTN: AE+LS+CF
127 | #python train.py --do_train \
128 | # --vocab_file=$DATA_DIR/vocab.in \
129 | # --emb_file=$DATA_DIR/emb.json \
130 | # --input_file=$DATA_DIR/train.pos.in \
131 | # --output_file=$DATA_DIR/train.pos.out \
132 | # --dev_file=$DATA_DIR/dev.in \
133 | # --dev_output=$DATA_DIR/dev.out \
134 | # --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
135 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att_cf_fixed \
136 | # --adv --classification_model=RNN --output_classes=2 \
137 | # --gumbel_softmax_temporature=0.1 \
138 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
139 | # --cls_attention --cls_attention_size=50 --attention \
140 | # --batch_size=16 --max_len=50 \
141 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
142 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
143 | # --output_dir=$MODEL_DIR/adv_train_lscfcp_ptn \
144 | # --save_checkpoints \
145 | # --num_gpus=0 \
146 | # --label_beta=0.95 \
147 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
148 | # --ae_emb_file=$DATA_DIR/cf_emb.json \
149 | # --target_label=0
150 | #
151 | ## Conditional NTP: AE+LS+CF
152 | #python train.py --do_train \
153 | # --vocab_file=$DATA_DIR/vocab.in \
154 | # --emb_file=$DATA_DIR/emb.json \
155 | # --input_file=$DATA_DIR/train.neg.in \
156 | # --output_file=$DATA_DIR/train.neg.out \
157 | # --dev_file=$DATA_DIR/dev.in \
158 | # --dev_output=$DATA_DIR/dev.out \
159 | # --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
160 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att_cf_fixed \
161 | # --adv --classification_model=RNN --output_classes=2 \
162 | # --gumbel_softmax_temporature=0.1 \
163 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
164 | # --cls_attention --cls_attention_size=50 --attention \
165 | # --batch_size=16 --max_len=50 \
166 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
167 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
168 | # --output_dir=$MODEL_DIR/adv_train_lscfcp_ntp \
169 | # --save_checkpoints \
170 | # --num_gpus=0 \
171 | # --label_beta=0.95 \
172 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
173 | # --ae_emb_file=$DATA_DIR/cf_emb.json \
174 | # --target_label=1
175 | #
176 | # AE+LS+CF+DEFENCE
177 | #python train.py --do_train \
178 | # --vocab_file=$DATA_DIR/vocab.in \
179 | # --emb_file=$DATA_DIR/emb.json \
180 | # --input_file=$DATA_DIR/train.in \
181 | # --output_file=$DATA_DIR/train.out \
182 | # --dev_file=$DATA_DIR/dev.in \
183 | # --dev_output=$DATA_DIR/dev.out \
184 | # --load_model_cls=$MODEL_DIR/yelp50_x3-cls/bi_att \
185 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att_cf_fixed \
186 | # --adv --classification_model=RNN --output_classes=2 \
187 | # --gumbel_softmax_temporature=0.1 \
188 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
189 | # --cls_attention --cls_attention_size=50 --attention \
190 | # --batch_size=16 --max_len=50 \
191 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
192 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
193 | # --output_dir=$MODEL_DIR/def_train \
194 | # --save_checkpoints \
195 | # --num_gpus=0 \
196 | # --label_beta=0.95 \
197 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
198 | # --ae_emb_file=$DATA_DIR/cf_emb.json \
199 | # --defending --at_steps=2
200 | #
201 | #
202 | ## Attacking an augmented AE+LS+CF model: AE+LS+CF
203 | #python train.py --do_train \
204 | # --vocab_file=$DATA_DIR/vocab.in \
205 | # --emb_file=$DATA_DIR/emb.json \
206 | # --input_file=$DATA_DIR/train.in \
207 | # --output_file=$DATA_DIR/train.out \
208 | # --dev_file=$DATA_DIR/dev.in \
209 | # --dev_output=$DATA_DIR/dev.out \
210 | # --load_model_ae=$MODEL_DIR/yelp50_x3-ae/bi_att_cf_fixed \
211 | # --adv --classification_model=RNN --output_classes=2 \
212 | # --gumbel_softmax_temporature=0.1 \
213 | # --enc_type=bi --cls_enc_num_units=256 --cls_enc_type=bi \
214 | # --cls_attention --cls_attention_size=50 --attention \
215 | # --batch_size=16 --max_len=50 \
216 | # --num_epochs=20 --print_every_steps=100 --total_steps=200000 \
217 | # --learning_rate=0.00001 --ae_lambda=0.8 --seq_lambda=1.0 \
218 | # --output_dir=$MODEL_DIR/adv_aeaug_lscf \
219 | # --save_checkpoints \
220 | # --num_gpus=0 \
221 | # --label_beta=0.95 \
222 | # --ae_vocab_file=$DATA_DIR/cf_vocab.in \
223 | # --ae_emb_file=$DATA_DIR/cf_emb.json \
224 | # --load_model_cls=$MODEL_DIR/def_output/nmt-T2.ckpt \
225 | # --use_defending_as_target
226 |
--------------------------------------------------------------------------------
/misc/evaluate_attacks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Adversarial attack and defence evaluation.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | import numpy as np
23 | from misc import evaluate as general_evaluate
24 | from misc import utils
25 | from misc import input_data
26 |
27 | def printSentence(tokenized_sentences, vocab):
28 | train_sentence = ''
29 | for word in tokenized_sentences:
30 | train_sentence += vocab[word] + ' '
31 | utils.print_out(train_sentence)
32 |
33 |
34 | def getSentencesFromIDs(tokenized_sentences_list, vocab, eos_id=input_data.EOS_ID):
35 | train_sentences = []
36 | for tokenized_sentences in tokenized_sentences_list:
37 | train_sentence = []
38 | for word in tokenized_sentences:
39 | if word == eos_id:
40 | break
41 | train_sentence.append(vocab[word])
42 | train_sentences.append(train_sentence)
43 | return train_sentences
44 |
45 |
46 | def read_false_record(file_name):
47 | ret = []
48 | for line in open(file_name, 'r'):
49 | if line.startswith('Example '):
50 | comps = line.strip().split(': ')
51 | example_id = comps[0].split(' ')[1]
52 | ret.append(int(example_id))
53 | return ret
54 |
55 | from sklearn.metrics.pairwise import cosine_similarity
56 | def avgcos(emb1, emb2):
57 | avg_emb1 = [np.mean(a, axis=0) for a in emb1]
58 | avg_emb2 = [np.mean(a, axis=0) for a in emb2]
59 | cos_sim = [cosine_similarity([a], [b])[0][0] for a, b in zip(avg_emb1, avg_emb2)]
60 | avg_cos = sum(cos_sim) / len(cos_sim)
61 | return avg_cos
62 |
63 | def evaluate_attack(args, step, decoder_reference_list, decoder_prediction_list,
64 | cls_logits, cls_orig_logits, cls_labels, vocab,
65 | sent_embs, adv_sent_embs,
66 | is_test=False, X_adv_flip_num=None,
67 | orig_alphas=None, trans_alphas=None,
68 | cls_logits_def=None, cls_origs_def=None,
69 | copy_masks=None):
70 |
71 | cls_orig_acc = general_evaluate._clss_accuracy(cls_labels, cls_orig_logits)
72 | cls_orig_auc = general_evaluate._clss_auc(cls_labels, cls_orig_logits)
73 |
74 | cls_acc = general_evaluate._clss_accuracy(cls_labels, cls_logits)
75 | cls_auc = general_evaluate._clss_auc(cls_labels, cls_logits)
76 | cls_acc_pos = general_evaluate._clss_accuracy_micro(cls_labels, cls_logits, orig_label=1)
77 | cls_acc_neg = general_evaluate._clss_accuracy_micro(cls_labels, cls_logits, orig_label=0)
78 |
79 | if cls_logits_def is not None and len(cls_logits_def) > 0:
80 | cls_def_acc = general_evaluate._clss_accuracy(cls_labels, cls_logits_def)
81 | cls_def_auc = general_evaluate._clss_auc(cls_labels, cls_logits_def)
82 | org_def_acc = general_evaluate._clss_accuracy(cls_labels, cls_origs_def)
83 | org_def_auc = general_evaluate._clss_auc(cls_labels, cls_origs_def)
84 |
85 | reference_list = getSentencesFromIDs(decoder_reference_list, vocab)
86 | translation_list = getSentencesFromIDs(decoder_prediction_list, vocab)
87 |
88 | ref_pos, ref_neg, trans_pos, trans_neg, ref_changed, trans_changed = [], [], [], [], [], []
89 | label_changed, logits_changed, flip_num_changed, ids_changed = [], [], [], []
90 | ref_emb_pos, trans_emb_pos, ref_emb_neg, trans_emb_neg, ref_emb_cha, trans_emb_cha = [], [], [], [], [], []
91 |
92 | for ind, references in enumerate(reference_list):
93 | ref_pos.append(references) if cls_labels[ind][1] > 0 else ref_neg.append(references)
94 | trans_pos.append(translation_list[ind]) if cls_labels[ind][1] > 0 else trans_neg.append(translation_list[ind])
95 | ref_emb_pos.append(sent_embs[ind]) if cls_labels[ind][1] > 0 else ref_emb_neg.append(sent_embs[ind])
96 | trans_emb_pos.append(adv_sent_embs[ind]) if cls_labels[ind][1] > 0 else trans_emb_neg.append(adv_sent_embs[ind])
97 | if np.argmax(cls_logits[ind]) != np.argmax(cls_orig_logits[ind]):
98 | ids_changed.append(ind)
99 | ref_changed.append(references)
100 | trans_changed.append(translation_list[ind])
101 | label_changed.append(cls_labels[ind])
102 | logits_changed.append(cls_logits[ind])
103 | ref_emb_cha.append(sent_embs[ind])
104 | trans_emb_cha.append(adv_sent_embs[ind])
105 | if X_adv_flip_num is not None:
106 | flip_num_changed.append(X_adv_flip_num[ind])
107 |
108 | ae_acc = general_evaluate._accuracy(reference_list, translation_list)
109 | word_acc = general_evaluate._word_accuracy(reference_list, translation_list)
110 | rouge = general_evaluate._rouge(reference_list, translation_list)
111 | bleu = general_evaluate._bleu(reference_list, translation_list)
112 | use = general_evaluate._use_scores(reference_list, translation_list, args.use_model)
113 | accept = general_evaluate._accept_score(reference_list, translation_list, args)
114 |
115 | # positive examples
116 | pos_rouge = general_evaluate._rouge(ref_pos, trans_pos)
117 | pos_bleu = general_evaluate._bleu(ref_pos, trans_pos)
118 | pos_accept = general_evaluate._accept_score(ref_pos, trans_pos, args)
119 | pos_semsim = avgcos(ref_emb_pos, trans_emb_pos)
120 | pos_use = general_evaluate._use_scores(ref_pos, trans_pos, args.use_model)
121 |
122 | # negative examples
123 | neg_rouge = general_evaluate._rouge(ref_neg, trans_neg)
124 | neg_bleu = general_evaluate._bleu(ref_neg, trans_neg)
125 | neg_accept = general_evaluate._accept_score(ref_neg, trans_neg, args)
126 | neg_semsim = avgcos(ref_emb_neg, trans_emb_neg)
127 | neg_use = general_evaluate._use_scores(ref_neg, trans_neg, args.use_model)
128 |
129 |
130 | # changed examples
131 | if len(ref_changed) == 0:
132 | changed_rouge = -1.0
133 | changed_bleu = -1.0
134 | changed_accept = -1.0
135 | changed_semsim = -1.0
136 | changed_use = -1.0
137 | else:
138 | changed_rouge = general_evaluate._rouge(ref_changed, trans_changed)
139 | changed_bleu = general_evaluate._bleu(ref_changed, trans_changed)
140 | changed_accept = general_evaluate._accept_score(ref_changed, trans_changed, args)
141 | changed_semsim = avgcos(ref_emb_cha, trans_emb_cha)
142 | changed_use = general_evaluate._use_scores(ref_changed, trans_changed, args.use_model)
143 | # changed_use = 0.0
144 |
145 | # print out src, spl, and nmt
146 | for i in range(len(ref_changed)):
147 | reference_changed = ref_changed[i]
148 | translation_changed = trans_changed[i]
149 | if orig_alphas is not None and len(orig_alphas) > 0:
150 | orig_alpha = orig_alphas[ids_changed[i]]
151 | reference_changed = [s + '('+'{:.3f}'.format(orig_alpha[ind][0])+')' for ind, s in enumerate(ref_changed[i])]
152 | trans_alpha = trans_alphas[ids_changed[i]]
153 | translation_changed = [s + '('+'{:.3f}'.format(trans_alpha[ind][0])+')' for ind, s in enumerate(trans_changed[i])]
154 | utils.print_out('Example ' + str(ids_changed[i]) + ': src:\t' + ' '.join(reference_changed) + '\t' + str(label_changed[i]))
155 | utils.print_out('Example ' + str(ids_changed[i]) + ': nmt:\t' + ' '.join(translation_changed) + '\t' + str(logits_changed[i]))
156 | if copy_masks is not None and len(copy_masks)>0:
157 | copy_mask = copy_masks[ids_changed[i]]
158 | copy_mask_str = [str(mask) for mask in copy_mask]
159 | utils.print_out('Example ' + str(ids_changed[i]) + ': msk:\t' + ' '.join(copy_mask_str))
160 | if X_adv_flip_num is not None:
161 | utils.print_out('Example ' + str(ids_changed[i]) + ' flipped tokens: ' + str(flip_num_changed[i]))
162 | utils.print_out(' ')
163 |
164 | if X_adv_flip_num is not None:
165 |         num_changed = 0
166 |         for num in X_adv_flip_num:
167 |             if num > 0:
168 |                 num_changed += 1
169 |         utils.print_out('Average flipped tokens: ' + str(sum(X_adv_flip_num) / num_changed))
170 |
171 | utils.print_out('Step: ' + str(step) + ', cls_acc_pos=' + str(cls_acc_pos) + ', cls_acc_neg=' + str(cls_acc_neg))
172 | utils.print_out('Step: ' + str(step) + ', rouge_pos=' + str(pos_rouge) + ', rouge_neg=' + str(neg_rouge) + ', rouge_changed=' + str(changed_rouge))
173 | utils.print_out('Step: ' + str(step) + ', bleu_pos=' + str(pos_bleu) + ', bleu_neg=' + str(neg_bleu) + ', bleu_changed=' + str(changed_bleu))
174 | utils.print_out('Step: ' + str(step) + ', accept_pos=' + str(pos_accept) + ', accept_neg=' + str(neg_accept) + ', accept_changed=' + str(changed_accept))
175 | utils.print_out('Step: ' + str(step) + ', semsim_pos=' + str(pos_semsim) + ', semsim_neg=' + str(neg_semsim) + ', semsim_changed=' + str(changed_semsim))
176 | utils.print_out('Step: ' + str(step) + ', use_pos=' + str(pos_use) + ', use_neg=' + str(neg_use) + ', use_changed=' + str(changed_use))
177 | utils.print_out('Step: ' + str(step) + ', ae_acc=' + str(ae_acc) + ', word_acc=' + str(word_acc) + ', rouge=' + str(rouge) + ', bleu=' + str(bleu) +
178 | ', accept=' + str(accept) + ', use=' + str(use) + ', semsim=' + str(avgcos(sent_embs, adv_sent_embs)))
179 | utils.print_out('Step: ' + str(step) + ', cls_orig_acc=' + str(cls_orig_acc) + ', cls_orig_auc=' + str(cls_orig_auc))
180 | utils.print_out('Step: ' + str(step) + ', cls_acc=' + str(cls_acc) + ', cls_auc=' + str(cls_auc))
181 | if cls_logits_def is not None and len(cls_logits_def) > 0:
182 | utils.print_out('Step: ' + str(step) + ', org_def_acc=' + str(org_def_acc) + ', org_def_auc=' + str(org_def_auc))
183 | utils.print_out('Step: ' + str(step) + ', cls_def_acc=' + str(cls_def_acc) + ', cls_def_auc=' + str(cls_def_auc))
184 |
185 | if is_test:
186 | with open(args.output_dir+'/src_changed.txt', 'w') as output_file:
187 | output_file.write('\n'.join([' '.join(a) for a in ref_changed]))
188 | with open(args.output_dir+'/adv_changed.txt', 'w') as output_file:
189 | output_file.write('\n'.join([' '.join(a) for a in trans_changed]))
190 |
191 | with open(args.output_dir+'/adv.txt', 'w') as output_file:
192 | output_file.write('\n'.join([' '.join(a) for a in translation_list]))
193 | with open(args.output_dir+'/adv_score.txt', 'w') as output_file:
194 | for score in cls_logits:
195 | output_file.write(' '.join([str(a) for a in score])+'\n')
196 |
197 | return cls_acc, cls_acc_pos, cls_acc_neg, changed_bleu
198 |
199 |
200 |
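201 | # Worked example for avgcos above (illustrative only): each element of emb1 and
202 | # emb2 is a [num_tokens, emb_dim] array; tokens are mean-pooled into a sentence
203 | # vector and the pairwise cosine similarities are averaged.
204 | #
205 | #   emb1 = [np.array([[1., 0.], [1., 0.]])]   # mean -> [1., 0.]
206 | #   emb2 = [np.array([[0., 1.], [0., 1.]])]   # mean -> [0., 1.]
207 | #   avgcos(emb1, emb2)                        # -> 0.0 (orthogonal sentence vectors)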
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Input parameter definitions.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | from misc import input_data
23 | from argparse import ArgumentParser
24 | import os
25 | from misc.use import USE
26 | from transformers import XLNetTokenizer, XLNetLMHeadModel, BertTokenizer, BertForMaskedLM
27 | import torch
28 |
29 |
30 | def add_arguments():
31 | parser = ArgumentParser()
32 |
33 | # basic
34 | parser.add_argument('--do_train', action='store_true', help="do training")
35 | parser.add_argument('--do_test', action='store_true', help="do independent test")
36 | parser.add_argument('--do_cond_test', action='store_true', help="do test for conditional generation")
37 |
38 | parser.add_argument('--input_file', type=str, default=None, help="")
39 | parser.add_argument('--dev_file', type=str, default=None, help="")
40 | parser.add_argument('--test_file', type=str, default=None, help="")
41 | parser.add_argument('--vocab_file', type=str, default=None, help="")
42 | parser.add_argument('--emb_file', type=str, default=None, help="")
43 | parser.add_argument('--output_dir', type=str, default=None, help="")
44 | parser.add_argument('--attention', action='store_true', help='whether use attention in seq2seq')
45 | parser.add_argument('--cls_attention', action='store_true', help="")
46 | parser.add_argument('--cls_attention_size', type=int, default=300, help="")
47 |
48 | # hyper-parameters
49 | parser.add_argument('--batch_size', type=int, default=32, help="")
50 | parser.add_argument('--num_epochs', type=int, default=5, help="")
51 | parser.add_argument('--learning_rate', type=float, default=0.001, help="")
52 | parser.add_argument('--enc_type', type=str, default='bi', help="")
53 | parser.add_argument('--enc_num_units', type=int, default=512, help="")
54 | parser.add_argument('--enc_layers', type=int, default=2, help="")
55 | parser.add_argument('--dec_num_units', type=int, default=512, help="")
56 | parser.add_argument('--dec_layers', type=int, default=2, help="")
57 | parser.add_argument('--epochs', type=int, default=2, help="")
58 | parser.add_argument("--max_gradient_norm", type=float, default=5.0, help="Clip gradients to this norm.")
59 | parser.add_argument('--max_to_keep', type=int, default=5, help="")
60 |     parser.add_argument('--lowest_bound_score', type=float, default=10.0, help="Stop training once the lowest_bound_score is reached")
61 |
62 | parser.add_argument('--beam_width', type=int, default=0, help="")
63 | parser.add_argument("--num_buckets", type=int, default=5, help="Put data into similar-length buckets.")
64 |     parser.add_argument("--max_len", type=int, default=50, help="Maximum length of input sentences")
65 |     parser.add_argument('--tgt_min_len', type=int, default=0, help='Minimum length of target sentences')
66 |
67 | # training control
68 | parser.add_argument('--print_every_steps', type=int, default=1, help="")
69 | parser.add_argument('--save_every_epoch', type=int, default=1, help="")
70 |     parser.add_argument('--stop_steps', type=int, default=20000, help="number of steps without improvement before training is terminated")
71 | parser.add_argument('--total_steps', type=int, default=None, help="total number of steps for training")
72 | parser.add_argument('--random_seed', type=int, default=1, help="")
73 | parser.add_argument('--num_gpus', type=int, default=0, help="")
74 | parser.add_argument('--save_checkpoints', action='store_true', help='Whether save models while training')
75 |
76 | # classification
77 | parser.add_argument('--classification', action='store_true', help="Perform classification")
78 | parser.add_argument('--classification_model', type=str, default='RNN', help='')
79 | parser.add_argument('--output_classes', type=int, default=2, help="number of classes for classification")
80 | parser.add_argument('--output_file', type=str, default=None, help="Classification output for train set")
81 | parser.add_argument('--dev_output', type=str, default=None, help="Classification output for dev set")
82 | parser.add_argument('--test_output', type=str, default=None, help="Classification output for test set")
83 | parser.add_argument('--filter_sizes', nargs='+', default=[5, 3], type=int, help='filter sizes, only for CNN')
84 | parser.add_argument('--dropout_keep_prob', type=float, default=0.8, help='dropout, only for CNN')
85 | parser.add_argument('--bert_config_file', type=str, default=None, help='pretrained bert config file')
86 | parser.add_argument('--bert_init_chk', type=str, default=None, help='checkpoint for pretrained Bert')
87 |
88 | # adversarial attack and defence
89 | parser.add_argument('--adv', action='store_true', help="Perform adversarial attack training/testing")
90 | parser.add_argument('--cls_enc_type', type=str, default='bi', help="")
91 | parser.add_argument('--cls_enc_num_units', type=int, default=256, help="")
92 | parser.add_argument('--cls_enc_layers', type=int, default=2, help="")
93 | parser.add_argument('--gumbel_softmax_temporature', type=float, default=0.1, help="")
94 | parser.add_argument('--load_model_cls', type=str, default=None, help="Path to target classification model")
95 | parser.add_argument('--load_model_ae', type=str, default=None, help="Path to pretrained AE")
96 | parser.add_argument('--load_model', type=str, default=None, help="Trained model for testing")
97 | parser.add_argument('--load_model_pos', type=str, default=None, help="PTN attack model for testing")
98 | parser.add_argument('--load_model_neg', type=str, default=None, help="NTP attack model for testing")
99 |
100 |
101 | # balanced attack
102 |     parser.add_argument('--balance', action='store_true', help="Whether to balance between pos/neg attacks")
103 |     # label smoothing
104 |     parser.add_argument('--label_beta', type=float, default=None, help='label smoothing parameter, must be > 0.5')
105 | # use counter-fitted embedding for AE (AE embedding different from CLS embeddings)
106 | parser.add_argument('--ae_vocab_file', type=str, default=None, help='Path to counter-fitted vocabulary')
107 | parser.add_argument('--ae_emb_file', type=str, default=None, help='Path to counter-fitted embeddings')
108 | # gan auxiliary loss
109 | parser.add_argument('--gan', action='store_true', help='Whether use GAN as regularization')
110 | # conditional generation (1 or 0)
111 | parser.add_argument('--target_label', type=int, default=None, help="Target label for conditional generation, 0 (PTN) or 1 (NTP)")
112 | # include defending
113 |     parser.add_argument('--defending', action='store_true', help="whether to train C* for more robust classification models")
114 | # train defending classifier with augmented dataset
115 | parser.add_argument('--def_train_set', nargs='+', default=[], type=str, help='Set of adversarial examples to include in adv training')
116 | # attack an AE model using the augmented classifier as the target classifier
117 | parser.add_argument('--use_defending_as_target', action='store_true', help='Use the defending component as the target classifier')
118 |
119 | # loss control
120 | parser.add_argument('--at_steps', type=int, default=1, help='Alternative steps for GAN/Defending')
121 |     parser.add_argument('--ae_lambda', type=float, default=0.8, help='weighting of ae_loss+sent_loss vs. adv_loss')
122 |     parser.add_argument('--seq_lambda', type=float, default=1.0, help='weighting of ae_loss vs. sent_loss')
123 |     parser.add_argument('--aux_lambda', type=float, default=1.0, help='weighting of ae_loss vs. auxiliary losses')
124 |     parser.add_argument('--sentiment_emb_dist', type=str, default='avgcos', help="whether to include embedding distance as an auxiliary loss")
125 |     parser.add_argument('--loss_attention', action='store_true', help="whether to weight the embedding distance")
126 |     parser.add_argument('--loss_attention_norm', action='store_true', help="whether to apply minimax normalization to ae_loss_attention")
127 |
128 | # copy mechanism
129 | parser.add_argument('--copy', action='store_true', help="Whether use copy mechanism")
130 | parser.add_argument('--attention_copy_mask', action='store_true', help="Whether use attention to calculate copy mask")
131 |     parser.add_argument('--use_stop_words', action='store_true', help="whether to mask stop words")
132 | parser.add_argument('--top_k_attack', type=int, default=None, help="number of words to attack in copy mechanism, only set when args.copy is set to true.")
133 | parser.add_argument('--load_copy_model', type=str, default=None, help="Pretrained attention layer from the bi_att model")
134 |
135 | # evaluation options
136 | parser.add_argument('--use_cache_dir', type=str, default=None, help='cache dir for use (sem) eval')
137 |     parser.add_argument('--accept_name', type=str, default=None, help="model name for acceptability scores (xlnet), only used when set")
138 |
139 |
140 | args=parser.parse_args()
141 | if args.save_checkpoints and not os.path.exists(args.output_dir):
142 | os.mkdir(args.output_dir)
143 | vocab_size, vocab_file = input_data.check_vocab(args.vocab_file, args.output_dir, check_special_token=False if (args.classification_model == 'BERT') else True,
144 | vocab_base_name='vocab.txt')
145 | args.vocab_file = vocab_file
146 | args.vocab_size = vocab_size
147 |
148 | if args.ae_vocab_file is not None:
149 | ae_vocab_size, ae_vocab_file = input_data.check_vocab(args.ae_vocab_file, args.output_dir, check_special_token=False if (args.classification_model == 'BERT') else True,
150 | vocab_base_name='ae_vocab.txt')
151 | args.ae_vocab_size = ae_vocab_size
152 | args.ae_vocab_file = ae_vocab_file
153 |
154 | args.use_model = None
155 | if args.use_cache_dir is not None:
156 | args.use_model = USE(args.use_cache_dir)
157 |
158 | if args.accept_name is not None:
159 | if args.accept_name == 'bert':
160 | args.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
161 | args.acpt_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
162 | elif args.accept_name == 'xlnet':
163 | args.tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
164 | args.acpt_model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
165 |
166 | args.device = torch.device('cpu') if args.num_gpus == 0 else torch.device('cuda:0')
167 | args.acpt_model.to(args.device)
168 | args.acpt_model.eval()
169 |
170 | return args
171 |
172 |
173 |
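174 | # Usage sketch (illustrative only; the repository's shell scripts invoke
175 | # train.py, which is assumed to consume this configuration):
176 | #
177 | #   from config import add_arguments
178 | #   args = add_arguments()   # parses sys.argv
179 | #
180 | # Besides the raw flags, the returned namespace carries derived fields:
181 | # args.vocab_size plus a rewritten args.vocab_file, args.use_model (a USE
182 | # instance when --use_cache_dir is given, otherwise None) and, when
183 | # --accept_name is set, args.tokenizer, args.acpt_model and args.device.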
--------------------------------------------------------------------------------
/models/myClassifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | RNN-based sentiment classification model.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | import tensorflow as tf
23 | from tensorflow.python.layers import core as layers_core
24 | from misc import input_data
25 | import numpy as np
26 | import models.utils as utils
27 |
28 |
29 | class ClassificationModel():
30 | def __init__(self, args, mode=None):
31 |
32 | self.mode = mode
33 | self.bidirectional = True if args.enc_type == 'bi' else False
34 | self.args = args
35 | self.batch_size = args.batch_size
36 |
37 | self._init_placeholders()
38 |
39 | self.encoder_outputs, self.cls_logits, self.acc, self.alphas = self._make_graph()
40 |
41 | if self.mode == "Train":
42 | self._init_optimizer()
43 |
44 | def _make_graph(self, encoder_embedding_inputs=None):
45 |
46 | with tf.variable_scope("my_classifier", reuse=tf.AUTO_REUSE) as scope:
47 | if encoder_embedding_inputs is None:
48 | self._init_embedding()
49 | encoder_outputs, _ = self._init_encoder(encoder_embedding_inputs=(self.encoder_embedding_inputs
50 | if encoder_embedding_inputs is None else
51 | encoder_embedding_inputs))
52 |
53 | with tf.variable_scope("classification") as scope:
54 | alphas = None
55 | if self.args.cls_attention:
56 | with tf.variable_scope("attention") as scope:
57 | x = tf.reshape(encoder_outputs, [-1, self.args.enc_num_units*2
58 | if self.bidirectional else self.args.enc_num_units])
59 | self.cls_attention_layer = layers_core.Dense(self.args.cls_attention_size, name="cls_attention_layer")
60 | self.cls_attention_fc_layer = layers_core.Dense(1, name="cls_attention_fc_layer")
61 | with tf.device(utils.get_device_str(self.args.num_gpus)):
62 | x = tf.nn.relu(self.cls_attention_layer(x))
63 | x = self.cls_attention_fc_layer(x)
64 | logits = tf.reshape(x, [-1, tf.shape(encoder_outputs)[1], 1])
65 | alphas = tf.nn.softmax(logits, dim=1)
66 | encoder_outputs = encoder_outputs * alphas
67 | output_rnn_last = tf.reduce_sum(encoder_outputs, axis=1) #[batch size, h_dim]
68 | else:
69 | output_rnn_last = tf.reduce_mean(encoder_outputs, axis=1) # [batch size, h_dim]
70 | projection_layer = layers_core.Dense(units=self.args.output_classes, name="projection_layer")
71 | with tf.device(utils.get_device_str(self.args.num_gpus)):
72 | cls_logits = tf.nn.tanh(projection_layer(output_rnn_last)) #[batch size, output_classes]
73 | ybar = tf.argmax(cls_logits, axis=1, output_type=tf.int32)
74 | self.categorical_logits = tf.one_hot(ybar, depth=2, on_value=1.0, off_value=0.0)
75 | ylabel = tf.cast(self.classification_outputs[:, -1], dtype=tf.int32)
76 | count = tf.equal(ylabel, ybar)
77 | acc = tf.reduce_mean(tf.cast(count, tf.float32), name='acc')
78 | return encoder_outputs, cls_logits, acc, alphas
79 |
80 | def _init_placeholders(self):
81 | self.encoder_inputs = tf.placeholder(
82 | shape=(None, None),
83 | dtype=tf.int32,
84 | name='encoder_inputs'
85 | )
86 |
87 | self.encoder_inputs_length = tf.placeholder(
88 | shape=(None,),
89 | dtype=tf.int32,
90 | name='encoder_inputs_length',
91 | )
92 |
93 | self.classification_outputs = tf.placeholder(
94 | shape=(None, None),
95 | dtype=tf.int32,
96 | name='classification_outputs',
97 | )
98 |
99 | def _init_embedding(self, trainable=False):
100 | if trainable:
101 | with tf.variable_scope("my_word_embeddings", reuse=tf.AUTO_REUSE) as scope:
102 | emb_mat, emb_size = input_data.load_embed_json(self.args.emb_file, vocab_size=self.args.vocab_size)
103 | self.embedding_encoder = tf.get_variable(name="embedding_matrix", shape=emb_mat.shape,
104 | initializer=tf.constant_initializer(emb_mat),
105 | trainable=True)
106 | else:
107 | self.embedding_encoder = input_data._create_pretrained_emb_from_txt(
108 | vocab_file=self.args.vocab_file, embed_file=self.args.emb_file,
109 | trainable_tokens=3)
110 |
111 | self.encoder_embedding_inputs = tf.nn.embedding_lookup(
112 | self.embedding_encoder,
113 | self.encoder_inputs) #[batch size, sequence len, h_dim]
114 |
115 | def _init_encoder(self, encoder_embedding_inputs=None, trainable=True):
116 |
117 | with tf.variable_scope("Encoder") as scope:
118 | if self.bidirectional:
119 | fw_cell = tf.contrib.rnn.MultiRNNCell(
120 | [utils.make_cell(self.args.enc_num_units, utils.get_device_str(self.args.num_gpus), trainable=trainable) for _ in range(self.args.enc_layers)])
121 | bw_cell = tf.contrib.rnn.MultiRNNCell(
122 | [utils.make_cell(self.args.enc_num_units, utils.get_device_str(self.args.num_gpus), trainable=trainable) for _ in range(self.args.enc_layers)])
123 | bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
124 | fw_cell,
125 | bw_cell,
126 | encoder_embedding_inputs,
127 | dtype=tf.float32)
128 | #sequence_length=self.encoder_inputs_length)
129 |
130 | encoder_outputs = tf.concat(bi_outputs, -1) #[batch size, sequence len, h_dim*2]
131 | bi_encoder_state = bi_state
132 |
133 | if self.args.enc_layers == 1:
134 | encoder_state = bi_encoder_state
135 | else:
136 | # alternatively concat forward and backward states
137 | encoder_state = []
138 | for layer_id in range(self.args.enc_layers):
139 | encoder_state.append(bi_encoder_state[0][layer_id]) # forward
140 | encoder_state.append(bi_encoder_state[1][layer_id]) # backward
141 | encoder_state = tuple(encoder_state)
142 | else:
143 | encoder_cell = tf.contrib.rnn.MultiRNNCell([utils.make_cell(self.args.enc_num_units,
144 | utils.get_device_str(self.args.num_gpus),
145 | trainable=trainable) for _ in range(self.args.enc_layers)])
146 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
147 | cell=encoder_cell,
148 | inputs=encoder_embedding_inputs,
149 | #sequence_length=self.encoder_inputs_length,
150 | dtype=tf.float32
151 | ) #[batch size, sequence len, h_dim]
152 | return encoder_outputs, encoder_state
153 |
154 | def _init_optimizer(self):
155 |
156 | if self.args.output_classes > 2:
157 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
158 | logits=self.cls_logits, labels=self.classification_outputs))
159 | else:
160 | self.target_output = tf.cast(self.classification_outputs[:, -1], dtype=tf.int32)
161 | self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
162 | logits=self.cls_logits, labels=self.target_output))
163 |
164 | if self.args.defending:
165 | orig, augmented = tf.split(self.encoder_embedding_inputs, num_or_size_splits=2, axis=0)
166 | self.aux_loss = tf.reduce_sum(tf.keras.losses.MSE(orig, augmented))
167 | self.loss += self.args.aux_lambda * self.aux_loss
168 |
169 | tf.summary.scalar('loss', self.loss)
170 | self.summary_op = tf.summary.merge_all()
171 |
172 | learning_rate = self.args.learning_rate
173 | optimizer = tf.train.AdamOptimizer(learning_rate)
174 | self.gradients = optimizer.compute_gradients(self.loss)
175 | capped_gradients = [(tf.clip_by_value(grad, -1.0*self.args.max_gradient_norm, self.args.max_gradient_norm), var) for grad, var in self.gradients if grad is not None]
176 | self.train_op = optimizer.apply_gradients(capped_gradients)
177 |
178 | # inputs and outputs for train/infer
179 |
180 | def make_train_inputs(self, x, X_data=None):
181 | x_input = x[0]
182 | if X_data is not None:
183 | x_input = X_data
184 | if len(self.args.def_train_set) == 1:
185 | x_input_def = x[-2]
186 | if (len(x_input[0]) - len(x_input_def[0])) > 0:
187 | x_input_def = np.concatenate([x_input_def, np.ones([len(x_input_def), len(x_input[0])-len(x_input_def[0])], dtype=np.int32)*input_data.EOS_ID],
188 | axis=1)
189 | if (len(x_input[0]) - len(x_input_def[0])) < 0:
190 | x_input = np.concatenate([x_input, np.ones([len(x_input), len(x_input_def[0])-len(x_input[0])], dtype=np.int32)*input_data.EOS_ID],
191 | axis=1)
192 | x_input = np.concatenate([x_input, x_input_def], axis=0)
193 | y_input = np.concatenate([x[1], x[1]], axis=0)
194 | x_lenghts = np.concatenate([x[2], x[-1]], axis=0)
195 | else:
196 | y_input = x[1]
197 | x_lenghts = x[2]
198 | return {
199 | self.encoder_inputs: x_input,
200 | self.classification_outputs: y_input,
201 | self.encoder_inputs_length: x_lenghts
202 | }
203 |
204 | def embedding_encoder_fn(self):
205 | return self.embedding_encoder
206 |
207 | def make_train_outputs(self, full_loss_step=True, defence=False):
208 | if self.args.cls_attention:
209 | return [self.train_op, self.loss, self.cls_logits, self.summary_op, self.alphas]
210 | else:
211 | return [self.train_op, self.loss, self.cls_logits, self.summary_op]
212 |
213 | def make_eval_outputs(self):
214 | return self.loss
215 |
216 | def make_test_outputs(self):
217 | if self.args.cls_attention:
218 | return [self.loss, self.cls_logits, self.acc, self.alphas, self.encoder_inputs, self.classification_outputs]
219 | else:
220 | return [self.loss, self.cls_logits, self.acc, self.encoder_inputs, self.classification_outputs]
221 |
222 | def make_encoder_output(self):
223 | return self.encoder_outputs
224 |
225 |
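226 | # Usage sketch (illustrative only; `args` is assumed to be the namespace
227 | # returned by config.add_arguments(), and `batch` a tuple from the input
228 | # pipeline with batch[0] = token ids, batch[1] = labels (the last column is
229 | # the class id) and batch[2] = sentence lengths):
230 | #
231 | #   model = ClassificationModel(args, mode='Train')
232 | #   feed = model.make_train_inputs(batch)
233 | #   with tf.Session() as sess:
234 | #       sess.run(tf.global_variables_initializer())
235 | #       _, loss, logits, summary = sess.run(model.make_train_outputs()[:4],
236 | #                                           feed_dict=feed)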
--------------------------------------------------------------------------------
/misc/scripts/rouge.py:
--------------------------------------------------------------------------------
1 | """ROUGE metric implementation.
2 |
3 | Copied from tf_seq2seq/seq2seq/metrics/rouge.py.
4 | This is a modified and slightly extended version of
5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | from __future__ import unicode_literals
12 |
13 | import itertools
14 | import numpy as np
15 |
16 | #pylint: disable=C0103
17 |
18 |
19 | def _get_ngrams(n, text):
20 |   """Calculates n-grams.
21 |
22 | Args:
23 | n: which n-grams to calculate
24 | text: An array of tokens
25 |
26 | Returns:
27 | A set of n-grams
28 | """
29 | ngram_set = set()
30 | text_length = len(text)
31 | max_index_ngram_start = text_length - n
32 | for i in range(max_index_ngram_start + 1):
33 | ngram_set.add(tuple(text[i:i + n]))
34 | return ngram_set
35 |
36 |
37 | def _split_into_words(sentences):
38 | """Splits multiple sentences into words and flattens the result"""
39 | return list(itertools.chain(*[_.split(" ") for _ in sentences]))
40 |
41 |
42 | def _get_word_ngrams(n, sentences):
43 | """Calculates word n-grams for multiple sentences.
44 | """
45 | assert len(sentences) > 0
46 | assert n > 0
47 |
48 | words = _split_into_words(sentences)
49 | return _get_ngrams(n, words)
50 |
51 |
52 | def _len_lcs(x, y):
53 | """
54 | Returns the length of the Longest Common Subsequence between sequences x
55 | and y.
56 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
57 |
58 | Args:
59 | x: sequence of words
60 | y: sequence of words
61 |
62 | Returns
63 | integer: Length of LCS between x and y
64 | """
65 | table = _lcs(x, y)
66 | n, m = len(x), len(y)
67 | return table[n, m]
68 |
69 |
70 | def _lcs(x, y):
71 | """
72 | Computes the length of the longest common subsequence (lcs) between two
73 |   strings. The implementation below uses a dynamic programming algorithm and runs
74 | in O(nm) time where n = len(x) and m = len(y).
75 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
76 |
77 | Args:
78 | x: collection of words
79 | y: collection of words
80 |
81 | Returns:
82 | Table of dictionary of coord and len lcs
83 | """
84 | n, m = len(x), len(y)
85 | table = dict()
86 | for i in range(n + 1):
87 | for j in range(m + 1):
88 | if i == 0 or j == 0:
89 | table[i, j] = 0
90 | elif x[i - 1] == y[j - 1]:
91 | table[i, j] = table[i - 1, j - 1] + 1
92 | else:
93 | table[i, j] = max(table[i - 1, j], table[i, j - 1])
94 | return table
95 |
96 |
97 | def _recon_lcs(x, y):
98 | """
99 |   Returns the Longest Common Subsequence between x and y.
100 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
101 |
102 | Args:
103 | x: sequence of words
104 | y: sequence of words
105 |
106 | Returns:
107 | sequence: LCS of x and y
108 | """
109 | i, j = len(x), len(y)
110 | table = _lcs(x, y)
111 |
112 | def _recon(i, j):
113 | """private recon calculation"""
114 | if i == 0 or j == 0:
115 | return []
116 | elif x[i - 1] == y[j - 1]:
117 | return _recon(i - 1, j - 1) + [(x[i - 1], i)]
118 | elif table[i - 1, j] > table[i, j - 1]:
119 | return _recon(i - 1, j)
120 | else:
121 | return _recon(i, j - 1)
122 |
123 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j)))
124 | return recon_tuple
125 |
126 |
127 | def rouge_n(evaluated_sentences, reference_sentences, n=2):
128 | """
129 | Computes ROUGE-N of two text collections of sentences.
130 |   Source: http://research.microsoft.com/en-us/um/people/cyl/download/
131 | papers/rouge-working-note-v1.3.1.pdf
132 |
133 | Args:
134 | evaluated_sentences: The sentences that have been picked by the summarizer
135 |     reference_sentences: The sentences from the reference set
136 | n: Size of ngram. Defaults to 2.
137 |
138 | Returns:
139 | A tuple (f1, precision, recall) for ROUGE-N
140 |
141 | Raises:
142 | ValueError: raises exception if a param has len <= 0
143 | """
144 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
145 | raise ValueError("Collections must contain at least 1 sentence.")
146 |
147 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
148 | reference_ngrams = _get_word_ngrams(n, reference_sentences)
149 | reference_count = len(reference_ngrams)
150 | evaluated_count = len(evaluated_ngrams)
151 |
152 | # Gets the overlapping ngrams between evaluated and reference
153 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
154 | overlapping_count = len(overlapping_ngrams)
155 |
156 | # Handle edge case. This isn't mathematically correct, but it's good enough
157 | if evaluated_count == 0:
158 | precision = 0.0
159 | else:
160 | precision = overlapping_count / evaluated_count
161 |
162 | if reference_count == 0:
163 | recall = 0.0
164 | else:
165 | recall = overlapping_count / reference_count
166 |
167 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
168 |
169 | # return overlapping_count / reference_count
170 | return f1_score, precision, recall
171 |
172 |
173 | def _f_p_r_lcs(llcs, m, n):
174 | """
175 | Computes the LCS-based F-measure score
176 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
177 | rouge-working-note-v1.3.1.pdf
178 |
179 | Args:
180 | llcs: Length of LCS
181 | m: number of words in reference summary
182 | n: number of words in candidate summary
183 |
184 | Returns:
185 | Float. LCS-based F-measure score
186 | """
187 | r_lcs = llcs / m
188 | p_lcs = llcs / n
189 | beta = p_lcs / (r_lcs + 1e-12)
190 | num = (1 + (beta**2)) * r_lcs * p_lcs
191 | denom = r_lcs + ((beta**2) * p_lcs)
192 | f_lcs = num / (denom + 1e-12)
193 | return f_lcs, p_lcs, r_lcs
194 |
195 |
196 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences):
197 | """
198 | Computes ROUGE-L (sentence level) of two text collections of sentences.
199 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/
200 | rouge-working-note-v1.3.1.pdf
201 |
202 | Calculated according to:
203 | R_lcs = LCS(X,Y)/m
204 | P_lcs = LCS(X,Y)/n
205 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
206 |
207 | where:
208 | X = reference summary
209 | Y = Candidate summary
210 | m = length of reference summary
211 | n = length of candidate summary
212 |
213 | Args:
214 | evaluated_sentences: The sentences that have been picked by the summarizer
215 | reference_sentences: The sentences from the reference set
216 |
217 | Returns:
218 | A float: F_lcs
219 |
220 | Raises:
221 | ValueError: raises exception if a param has len <= 0
222 | """
223 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
224 | raise ValueError("Collections must contain at least 1 sentence.")
225 | reference_words = _split_into_words(reference_sentences)
226 | evaluated_words = _split_into_words(evaluated_sentences)
227 | m = len(reference_words)
228 | n = len(evaluated_words)
229 | lcs = _len_lcs(evaluated_words, reference_words)
230 | return _f_p_r_lcs(lcs, m, n)
231 |
232 |
233 | def _union_lcs(evaluated_sentences, reference_sentence):
234 | """
235 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common
236 | subsequence between reference sentence ri and candidate summary C. For example
237 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and
238 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is
239 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The
240 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and
241 | LCS_u(r_i, C) = 4/5.
242 |
243 | Args:
244 | evaluated_sentences: The sentences that have been picked by the summarizer
245 | reference_sentence: One of the sentences in the reference summaries
246 |
247 | Returns:
248 | float: LCS_u(r_i, C)
249 |
250 | ValueError:
251 | Raises exception if a param has len <= 0
252 | """
253 | if len(evaluated_sentences) <= 0:
254 | raise ValueError("Collections must contain at least 1 sentence.")
255 |
256 | lcs_union = set()
257 | reference_words = _split_into_words([reference_sentence])
258 | combined_lcs_length = 0
259 | for eval_s in evaluated_sentences:
260 | evaluated_words = _split_into_words([eval_s])
261 | lcs = set(_recon_lcs(reference_words, evaluated_words))
262 | combined_lcs_length += len(lcs)
263 | lcs_union = lcs_union.union(lcs)
264 |
265 | union_lcs_count = len(lcs_union)
266 | union_lcs_value = union_lcs_count / combined_lcs_length
267 | return union_lcs_value
268 |
269 |
270 | def rouge_l_summary_level(evaluated_sentences, reference_sentences):
271 | """
272 | Computes ROUGE-L (summary level) of two text collections of sentences.
273 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/
274 | rouge-working-note-v1.3.1.pdf
275 |
276 | Calculated according to:
277 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m
278 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n
279 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
280 |
281 | where:
282 | SUM(i,u) = SUM from i through u
283 | u = number of sentences in reference summary
284 | C = Candidate summary made up of v sentences
285 | m = number of words in reference summary
286 | n = number of words in candidate summary
287 |
288 | Args:
289 | evaluated_sentences: The sentences that have been picked by the summarizer
290 | reference_sentences: The sentences from the reference set
291 |
292 | Returns:
293 | A float: F_lcs
294 |
295 | Raises:
296 | ValueError: raises exception if a param has len <= 0
297 | """
298 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
299 | raise ValueError("Collections must contain at least 1 sentence.")
300 |
301 | # total number of words in reference sentences
302 | m = len(_split_into_words(reference_sentences))
303 |
304 | # total number of words in evaluated sentences
305 | n = len(_split_into_words(evaluated_sentences))
306 |
307 | union_lcs_sum_across_all_references = 0
308 | for ref_s in reference_sentences:
309 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences,
310 | ref_s)
311 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n)
312 |
313 |
314 | def rouge(hypotheses, references):
315 | """Calculates average rouge scores for a list of hypotheses and
316 | references"""
317 |
318 | # Filter out hyps that are of 0 length
319 | # hyps_and_refs = zip(hypotheses, references)
320 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0]
321 | # hypotheses, references = zip(*hyps_and_refs)
322 |
323 | # Calculate ROUGE-1 F1, precision, recall scores
324 | rouge_1 = [
325 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references)
326 | ]
327 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1))
328 |
329 | # Calculate ROUGE-2 F1, precision, recall scores
330 | rouge_2 = [
331 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references)
332 | ]
333 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2))
334 |
335 | # Calculate ROUGE-L F1, precision, recall scores
336 | rouge_l = [
337 | rouge_l_sentence_level([hyp], [ref])
338 | for hyp, ref in zip(hypotheses, references)
339 | ]
340 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l))
341 |
342 | return {
343 | "rouge_1/f_score": rouge_1_f,
344 | "rouge_1/r_score": rouge_1_r,
345 | "rouge_1/p_score": rouge_1_p,
346 | "rouge_2/f_score": rouge_2_f,
347 | "rouge_2/r_score": rouge_2_r,
348 | "rouge_2/p_score": rouge_2_p,
349 | "rouge_l/f_score": rouge_l_f,
350 | "rouge_l/r_score": rouge_l_r,
351 | "rouge_l/p_score": rouge_l_p,
352 | }
353 |
--------------------------------------------------------------------------------
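
Usage sketch for the ROUGE helpers above (not repository code): rouge() takes parallel lists of already-tokenised hypothesis and reference strings and returns the averaged ROUGE-1/2/L precision, recall and F1 values under the keys shown in its return dict. The import path below is assumed from the tree layout; numpy is the only extra dependency.

from misc.scripts.rouge import rouge

# one hypothesis/reference pair; both are whitespace-tokenised sentences
scores = rouge(hypotheses=["the cat sat on the mat"],
               references=["the cat was sitting on the mat"])
print(scores["rouge_1/f_score"], scores["rouge_2/f_score"], scores["rouge_l/f_score"])
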
/misc/input_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Data preparation for training and testing.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | import tensorflow as tf
23 | import json
24 | import codecs
25 | import numpy as np
26 | import os
27 | from misc import iterator
28 | import misc.utils as utils
29 |
30 | VOCAB_SIZE_THRESHOLD_CPU = 50000
31 | UNK = "<unk>"
32 | SOS = "<s>"
33 | EOS = "</s>"
34 | SEP = "<sep>"
35 | UNK_ID = 0
36 | SOS_ID = 1
37 | EOS_ID = 2
38 | SEP_ID = 3
39 |
40 | def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None,
41 | eos=None, unk=None, vocab_base_name=None):
42 | """Check if vocab_file doesn't exist, create from corpus_file."""
43 | if tf.gfile.Exists(vocab_file):
44 | print("# Vocab file "+vocab_file+" exists")
45 | vocab, vocab_size = load_vocab(vocab_file)
46 | if check_special_token:
47 | # Verify if the vocab starts with unk, sos, eos
48 | # If not, prepend those tokens & generate a new vocab file
49 | if not unk: unk = UNK
50 | if not sos: sos = SOS
51 | if not eos: eos = EOS
52 | assert len(vocab) >= 3
53 | if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
54 | vocab = [unk, sos, eos] + vocab
55 | vocab_size += 3
56 | new_vocab_file = os.path.join(out_dir, vocab_base_name)
57 | with codecs.getwriter("utf-8")(
58 | tf.gfile.GFile(new_vocab_file, "wb")) as f:
59 | for word in vocab:
60 | f.write("%s\n" % word)
61 | vocab_file = new_vocab_file
62 | else:
63 | raise ValueError("vocab_file '%s' does not exist." % vocab_file)
64 |
65 | vocab_size = len(vocab)
66 | return vocab_size, vocab_file
67 |
68 | def load_vocab(vocab_file):
69 | vocab = []
70 | with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
71 | vocab_size = 0
72 | for word in f:
73 | vocab_size += 1
74 | vocab.append(word.strip())
75 | return vocab, vocab_size
76 |
77 | def load_embed_json(embed_file, vocab_size=None, dtype=tf.float32):
78 | with codecs.open(embed_file, "r", "utf-8") as fh:
79 | emb_dict = json.load(fh)
80 | emb_size = len(emb_dict[0])
81 | emb_mat = np.array(emb_dict, dtype=dtype.as_numpy_dtype())
82 | if vocab_size > len(emb_mat):
83 | np.random.seed(0)
84 | emb_mat_var = np.random.rand(vocab_size-len(emb_mat), emb_size)
85 | emb_mat_var = np.array(emb_mat_var, dtype=dtype.as_numpy_dtype())
86 | emb_mat = np.concatenate([emb_mat_var, emb_mat], axis=0)
87 | return emb_mat, emb_size
88 |
89 | def _load_simlex(lex_file, vocab_size):
90 | lexis = utils.readlines(lex_file)
91 | sim_num = int(lex_file.split('-')[-2])+1
92 | base = 0
93 | if vocab_size > len(lexis):
94 | base = 3
95 | lexis = np.concatenate([['']*base, lexis], axis=0)
96 | lex = []
97 | for ind, line in enumerate(lexis):
98 | if line == '':
99 | lex.append(np.array([ind]+[-1]*(sim_num-1)))
100 | else:
101 | comps = line.split(' ')
102 | lex.append(np.array([int(a)+base for a in comps]+[ind]+[-1]*(sim_num-(len(comps)+1))))
103 | return np.array(lex)
104 |
105 | def _get_embed_device(vocab_size, num_gpus):
106 | """Decide on which device to place an embed matrix given its vocab size."""
107 | if num_gpus == 0 or vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
108 | return "/cpu:0"
109 | else:
110 | return "/gpu:0"
111 |
112 | def _create_pretrained_embeddings_from_jsons(
113 | vocab_file, embed_file, cls_vocab_file, cls_embed_file,
114 | dtype=tf.float32, cls_model='RNN'):
115 |
116 | vocab, _ = load_vocab(vocab_file)
117 | emb_mat, emb_size = load_embed_json(embed_file, vocab_size=len(vocab), dtype=dtype)
118 |
119 | cls_vocab, _ = load_vocab(cls_vocab_file)
120 | cls_emb_mat, cls_emb_size = load_embed_json(cls_embed_file, vocab_size=len(cls_vocab), dtype=dtype)
121 |
122 | transfer_emb_mat = []
123 | unk_id = 100 if cls_model == 'BERT' else UNK_ID
124 | if cls_model == 'BERT':
125 | for word in vocab:
126 | if word == '<s>':
127 | transfer_emb_mat.append(cls_emb_mat[cls_vocab.index('[CLS]')])
128 | elif word == '</s>':
129 | transfer_emb_mat.append(cls_emb_mat[cls_vocab.index('[SEP]')])
130 | elif word == '<unk>':
131 | transfer_emb_mat.append(cls_emb_mat[cls_vocab.index('[UNK]')])
132 | elif word in cls_vocab:
133 | transfer_emb_mat.append(cls_emb_mat[cls_vocab.index(word)])
134 | else:
135 | transfer_emb_mat.append(cls_emb_mat[unk_id])
136 | else:
137 | for word in vocab:
138 | if word in cls_vocab:
139 | transfer_emb_mat.append(cls_emb_mat[cls_vocab.index(word)])
140 | else:
141 | transfer_emb_mat.append(cls_emb_mat[unk_id])
142 |
143 | # with tf.device("/cpu:0"):
144 | cls_emb_mat = tf.constant(cls_emb_mat)
145 | emb_mat = tf.constant(emb_mat)
146 | transfer_emb_mat = tf.constant(np.array(transfer_emb_mat))
147 | return cls_emb_mat, emb_mat, transfer_emb_mat
148 |
149 | def _create_pretrained_emb_from_txt(
150 | vocab_file, embed_file, trainable_tokens=3, dtype=tf.float32):
151 |
152 | vocab, _ = load_vocab(vocab_file)
153 | emb_mat, emb_size = load_embed_json(embed_file, vocab_size=len(vocab), dtype=dtype)
154 |
155 | emb_mat = tf.constant(emb_mat)
156 | return emb_mat
157 |
158 |
159 | def get_labels(args):
160 | output_labels = []
161 | for line in open(args.output_file, 'r'):
162 | output_labels.append(int(line.strip()))
163 | return output_labels
164 |
165 | def get_dataset_iter(args, input_file, output_file, task, is_training=True, is_test=False, is_bert=False):
166 | unk = UNK_ID
167 | if is_bert:
168 | unk = 100
169 | table = tf.contrib.lookup.index_table_from_file(
170 | vocabulary_file=args.vocab_file, default_value=unk)
171 | src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(input_file))
172 |
173 | if len(args.def_train_set) > 0:
174 | src_datasets = [src_dataset]
175 | base_name = os.path.basename(args.input_file) if is_training else os.path.basename(args.dev_file)
176 | path_to_base_name = args.input_file[:args.input_file.rfind('/')] if is_training else args.dev_file[:args.dev_file.rfind('/')]
177 |
178 | for set_name in args.def_train_set:
179 | base_names = base_name.split('.')
180 | base_names.insert(-1, set_name)
181 | base_name_1 = '.'.join(base_names)
182 | full_name = path_to_base_name+'/'+base_name_1
183 | src_datasets.append(tf.data.TextLineDataset(tf.gfile.Glob(full_name)))
184 |
185 | tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(output_file))
186 | if len(src_datasets) > 2:
187 | iter = iterator.get_cls_multi_def_iterator(src_datasets, tgt_dataset, table, args.batch_size, args.num_epochs,
188 | SOS if (not is_bert) else '[CLS]', EOS if (not is_bert) else '[SEP]',
189 | args.random_seed, args.num_buckets,
190 | src_max_len=args.max_len, is_training=is_training)
191 | else:
192 | iter = iterator.get_cls_def_iterator(src_datasets, tgt_dataset, table, args.batch_size, args.num_epochs,
193 | SOS if (not is_bert) else '[CLS]', EOS if (not is_bert) else '[SEP]',
194 | args.random_seed, args.num_buckets,
195 | src_max_len=args.max_len, is_training=is_training)
196 | return iter
197 |
198 | if task=='adv' or task=='ae':
199 | tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(output_file))
200 | iter = iterator.get_adv_iterator(src_dataset, tgt_dataset, table, args.batch_size, args.num_epochs,
201 | SOS if (not is_bert) else '[CLS]', EOS if (not is_bert) else '[SEP]',
202 | SEP if (not is_bert) else '[SEP]',
203 | args.random_seed, args.num_buckets,
204 | src_max_len=args.max_len, is_training=is_training)
205 | return iter
206 |
207 | if task=='clss':
208 | tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(output_file))
209 | iter = iterator.get_cls_iterator(src_dataset, tgt_dataset, table, args.batch_size, args.num_epochs,
210 | SOS if (not is_bert) else '[CLS]', EOS if (not is_bert) else '[SEP]',
211 | SEP if (not is_bert) else '[SEP]',
212 | args.random_seed, args.num_buckets,
213 | src_max_len=args.max_len, is_training=is_training)
214 | return iter
215 |
216 | if task == 'adv_counter_fitting':
217 | ae_table = tf.contrib.lookup.index_table_from_file(vocabulary_file=args.ae_vocab_file, default_value=unk)
218 | tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(output_file))
219 | iter = iterator.get_adv_cf_iterator(src_dataset, tgt_dataset, table, args.batch_size, args.num_epochs,
220 | SOS if (not is_bert) else '[CLS]', EOS if (not is_bert) else '[SEP]',
221 | SEP if (not is_bert) else '[SEP]',
222 | args.random_seed, args.num_buckets,
223 | src_max_len=args.max_len, is_training=is_training,
224 | ae_vocab_table=ae_table)
225 | return iter
226 |
227 | # if task=='ae':
228 | # if output_file is not None:
229 | # tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(output_file))
230 | # else:
231 | # tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(input_file))
232 | # iter = iterator.get_iterator(src_dataset, tgt_dataset, table, args.batch_size, args.num_epochs,
233 | # SOS if (not is_bert) else '[CLS]', EOS if (not is_bert) else '[SEP]',
234 | # args.random_seed, args.num_buckets, src_max_len=args.max_len,
235 | # tgt_max_len=args.max_len, is_training=is_training, min_len=args.tgt_min_len)
236 | # return iter
237 |
238 |
239 | def parse_generated(input_file):
240 | sampled, generated = [], []
241 | for line in open('data_gen/'+input_file+'.txt', 'r'):
242 | if line.startswith('Example '):
243 | if ' spl:' in line:
244 | comps = line.strip().split('\t')
245 | sampled.append(comps[1])
246 | elif ' nmt:' in line:
247 | comps = line.strip().split('\t')
248 | generated.append(comps[1])
249 | with open('data_gen/'+input_file+'_spl.in', 'w') as output_file:
250 | for sample in sampled:
251 | output_file.write(sample+'\n')
252 | with open('data_gen/'+input_file+'_dec.in', 'w') as output_file:
253 | for generate in generated:
254 | output_file.write(generate+'\n')
255 |
256 | with open('data_gen/'+input_file+'_acc.txt', 'w') as output_file:
257 | for generate in generated:
258 | output_file.write('yelp\t0\t\t'+generate+'\n')
259 |
260 |
261 | if __name__ == '__main__':
262 | parse_generated('cls_bi_att-hinge10-lr0001-wl5-emb3-beam')
263 |
--------------------------------------------------------------------------------
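
Illustrative sketch (not repository code) of the padding step in load_embed_json above: when the vocabulary is larger than the matrix stored in emb.json, randomly initialised rows for the extra (special) tokens are prepended before the pretrained vectors. Pure numpy, so it runs without TensorFlow.

import numpy as np

stored = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)  # rows as loaded from emb.json
vocab_size = 5                                                 # vocabulary has 3 extra tokens
np.random.seed(0)                                              # same seed as load_embed_json
extra = np.random.rand(vocab_size - len(stored), stored.shape[1]).astype(np.float32)
emb_mat = np.concatenate([extra, stored], axis=0)              # extra rows come first
print(emb_mat.shape)                                           # (5, 2)
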
/models/bert/measures.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import tensorflow as tf
3 | import json
4 | import six
5 | import models.bert.tokenization as tokenization
6 | import math
7 |
8 | def f1_score(prediction, ground_truth):
9 | print('compare: '+prediction+'\t'+ground_truth)
10 | prediction_tokens = prediction.split()
11 | ground_truth_tokens = ground_truth.split()
12 | common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
13 | num_same = sum(common.values())
14 | if num_same == 0:
15 | return 0
16 | precision = 1.0 * num_same / len(prediction_tokens)
17 | recall = 1.0 * num_same / len(ground_truth_tokens)
18 | f1 = (2 * precision * recall) / (precision + recall)
19 | return f1
20 |
21 | def exact_match(prediction, ground_truth):
22 | return prediction == ground_truth
23 |
24 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
25 | scores_for_ground_truths = []
26 | for idx, ground_truth in enumerate(ground_truths):
27 | score = metric_fn(prediction, ground_truth)
28 | scores_for_ground_truths.append(score)
29 | return max(scores_for_ground_truths)
30 |
31 | def computeF1(outputs, targets):
32 | return sum([metric_max_over_ground_truths(f1_score, o, t) for o, t in zip(outputs, targets)])/len(outputs) * 100
33 |
34 | def computeEM(outputs, targets):
35 | outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(outputs, targets)]
36 | return sum(outs)/len(outputs) * 100
37 |
38 | def _get_best_indexes(logits, n_best_size):
39 | """Get the n-best logits from a list."""
40 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
41 |
42 | best_indexes = []
43 | for i in range(len(index_and_score)):
44 | if i >= n_best_size:
45 | break
46 | best_indexes.append(index_and_score[i][0])
47 | return best_indexes
48 |
49 | def _compute_softmax(scores):
50 | """Compute softmax probability over raw logits."""
51 | if not scores:
52 | return []
53 |
54 | max_score = None
55 | for score in scores:
56 | if max_score is None or score > max_score:
57 | max_score = score
58 |
59 | exp_scores = []
60 | total_sum = 0.0
61 | for score in scores:
62 | x = math.exp(score - max_score)
63 | exp_scores.append(x)
64 | total_sum += x
65 |
66 | probs = []
67 | for score in exp_scores:
68 | probs.append(score / total_sum)
69 | return probs
70 |
71 | def get_final_text(pred_text, orig_text, do_lower_case):
72 | """Project the tokenized prediction back to the original text."""
73 |
74 | # When we created the data, we kept track of the alignment between original
75 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
76 | # now `orig_text` contains the span of our original text corresponding to the
77 | # span that we predicted.
78 | #
79 | # However, `orig_text` may contain extra characters that we don't want in
80 | # our prediction.
81 | #
82 | # For example, let's say:
83 | # pred_text = steve smith
84 | # orig_text = Steve Smith's
85 | #
86 | # We don't want to return `orig_text` because it contains the extra "'s".
87 | #
88 | # We don't want to return `pred_text` because it's already been normalized
89 | # (the SQuAD eval script also does punctuation stripping/lower casing but
90 | # our tokenizer does additional normalization like stripping accent
91 | # characters).
92 | #
93 | # What we really want to return is "Steve Smith".
94 | #
95 | # Therefore, we have to apply a semi-complicated alignment heruistic between
96 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This
97 | # can fail in certain cases in which case we just return `orig_text`.
98 |
99 | def _strip_spaces(text):
100 | ns_chars = []
101 | ns_to_s_map = collections.OrderedDict()
102 | for (i, c) in enumerate(text):
103 | if c == " ":
104 | continue
105 | ns_to_s_map[len(ns_chars)] = i
106 | ns_chars.append(c)
107 | ns_text = "".join(ns_chars)
108 | return (ns_text, ns_to_s_map)
109 |
110 | # We first tokenize `orig_text`, strip whitespace from the result
111 | # and `pred_text`, and check if they are the same length. If they are
112 | # NOT the same length, the heuristic has failed. If they are the same
113 | # length, we assume the characters are one-to-one aligned.
114 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
115 |
116 | tok_text = " ".join(tokenizer.tokenize(orig_text))
117 |
118 | start_position = tok_text.find(pred_text)
119 | if start_position == -1:
120 | return orig_text
121 | end_position = start_position + len(pred_text) - 1
122 |
123 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
124 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
125 |
126 | if len(orig_ns_text) != len(tok_ns_text):
127 | return orig_text
128 |
129 | # We then project the characters in `pred_text` back to `orig_text` using
130 | # the character-to-character alignment.
131 | tok_s_to_ns_map = {}
132 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
133 | tok_s_to_ns_map[tok_index] = i
134 |
135 | orig_start_position = None
136 | if start_position in tok_s_to_ns_map:
137 | ns_start_position = tok_s_to_ns_map[start_position]
138 | if ns_start_position in orig_ns_to_s_map:
139 | orig_start_position = orig_ns_to_s_map[ns_start_position]
140 |
141 | if orig_start_position is None:
142 | return orig_text
143 |
144 | orig_end_position = None
145 | if end_position in tok_s_to_ns_map:
146 | ns_end_position = tok_s_to_ns_map[end_position]
147 | if ns_end_position in orig_ns_to_s_map:
148 | orig_end_position = orig_ns_to_s_map[ns_end_position]
149 |
150 | if orig_end_position is None:
151 | return orig_text
152 |
153 | output_text = orig_text[orig_start_position:(orig_end_position + 1)]
154 | return output_text
155 |
156 | def write_predictions(all_examples, all_features, all_results, n_best_size,
157 | max_answer_length, do_lower_case, output_prediction_file,
158 | output_nbest_file, output_null_log_odds_file):
159 | """Write final predictions to the json file and log-odds of null if needed."""
160 | tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
161 | tf.logging.info("Writing nbest to: %s" % (output_nbest_file))
162 |
163 | example_index_to_features = collections.defaultdict(list)
164 | for feature in all_features:
165 | example_index_to_features[feature.example_index].append(feature)
166 |
167 | unique_id_to_result = {}
168 | for result in all_results:
169 | unique_id_to_result[result.unique_id] = result
170 |
171 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
172 | "PrelimPrediction",
173 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
174 |
175 | all_predictions = collections.OrderedDict()
176 | all_nbest_json = collections.OrderedDict()
177 | scores_diff_json = collections.OrderedDict()
178 |
179 | groundtruths = []
180 | predictions = []
181 | for (example_index, example) in enumerate(all_examples):
182 |
183 | features = example_index_to_features[example_index]
184 |
185 | prelim_predictions = []
186 | # keep track of the minimum score of null start+end of position 0
187 | score_null = 1000000 # large and positive
188 | min_null_feature_index = 0 # the paragraph slice with min null score
189 | null_start_logit = 0 # the start logit at the slice with min null score
190 | null_end_logit = 0 # the end logit at the slice with min null score
191 | for (feature_index, feature) in enumerate(features):
192 | result = unique_id_to_result[feature.unique_id]
193 | start_indexes = _get_best_indexes(result.start_logits, n_best_size)
194 | end_indexes = _get_best_indexes(result.end_logits, n_best_size)
195 | # if we could have irrelevant answers, get the min score of irrelevant
196 | for start_index in start_indexes:
197 | for end_index in end_indexes:
198 | # We could hypothetically create invalid predictions, e.g., predict
199 | # that the start of the span is in the question. We throw out all
200 | # invalid predictions.
201 | if start_index >= len(feature.tokens):
202 | continue
203 | if end_index >= len(feature.tokens):
204 | continue
205 | if start_index not in feature.token_to_orig_map:
206 | continue
207 | if end_index not in feature.token_to_orig_map:
208 | continue
209 | if not feature.token_is_max_context.get(start_index, False):
210 | continue
211 | if end_index < start_index:
212 | continue
213 | length = end_index - start_index + 1
214 | if length > max_answer_length:
215 | continue
216 | prelim_predictions.append(
217 | _PrelimPrediction(
218 | feature_index=feature_index,
219 | start_index=start_index,
220 | end_index=end_index,
221 | start_logit=result.start_logits[start_index],
222 | end_logit=result.end_logits[end_index]))
223 |
224 | prelim_predictions = sorted(
225 | prelim_predictions,
226 | key=lambda x: (x.start_logit + x.end_logit),
227 | reverse=True)
228 |
229 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
230 | "NbestPrediction", ["text", "start_logit", "end_logit"])
231 |
232 | seen_predictions = {}
233 | nbest = []
234 | for pred in prelim_predictions:
235 | if len(nbest) >= n_best_size:
236 | break
237 | feature = features[pred.feature_index]
238 | if pred.start_index > 0: # this is a non-null prediction
239 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
240 | orig_doc_start = feature.token_to_orig_map[pred.start_index]
241 | orig_doc_end = feature.token_to_orig_map[pred.end_index]
242 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
243 | tok_text = " ".join(tok_tokens)
244 |
245 | # De-tokenize WordPieces that have been split off.
246 | tok_text = tok_text.replace(" ##", "")
247 | tok_text = tok_text.replace("##", "")
248 |
249 | # Clean whitespace
250 | tok_text = tok_text.strip()
251 | tok_text = " ".join(tok_text.split())
252 | orig_text = " ".join(orig_tokens)
253 |
254 | final_text = get_final_text(tok_text, orig_text, do_lower_case)
255 | if final_text in seen_predictions:
256 | continue
257 |
258 | seen_predictions[final_text] = True
259 | else:
260 | final_text = ""
261 | seen_predictions[final_text] = True
262 |
263 | nbest.append(
264 | _NbestPrediction(
265 | text=final_text,
266 | start_logit=pred.start_logit,
267 | end_logit=pred.end_logit))
268 |
269 | # if we didn't include the empty option in the n-best, include it
270 | # In very rare edge cases we could have no valid predictions. So we
271 | # just create a nonce prediction in this case to avoid failure.
272 | if not nbest:
273 | nbest.append(
274 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
275 |
276 | assert len(nbest) >= 1
277 |
278 | total_scores = []
279 | best_non_null_entry = None
280 | for entry in nbest:
281 | total_scores.append(entry.start_logit + entry.end_logit)
282 | if not best_non_null_entry:
283 | if entry.text:
284 | best_non_null_entry = entry
285 |
286 | probs = _compute_softmax(total_scores)
287 |
288 | nbest_json = []
289 | for (i, entry) in enumerate(nbest):
290 | output = collections.OrderedDict()
291 | output["text"] = entry.text
292 | output["probability"] = probs[i]
293 | output["start_logit"] = entry.start_logit
294 | output["end_logit"] = entry.end_logit
295 | nbest_json.append(output)
296 |
297 | assert len(nbest_json) >= 1
298 |
299 | all_predictions[example.qas_id] = nbest_json[0]["text"]
300 |
301 | all_nbest_json[example.qas_id] = nbest_json
302 |
303 | predictions.append(nbest_json[0]["text"])
304 | groundtruth_answer = example.orig_answer_text
305 | groundtruths.append(groundtruth_answer)
306 |
307 | assert len(predictions)==len(groundtruths)
308 |
309 | f1 = computeF1(predictions, groundtruths)
310 | em = computeEM(predictions, groundtruths)
311 |
312 | tf.logging.info('Eval F1: '+str(f1) + ', EM: '+str(em))
313 |
314 | with tf.gfile.GFile(output_prediction_file, "w") as writer:
315 | writer.write(json.dumps(all_predictions, indent=4) + "\n")
316 |
317 | with tf.gfile.GFile(output_nbest_file, "w") as writer:
318 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
319 |
--------------------------------------------------------------------------------
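
Toy check (not repository code) of the token-level F1/EM helpers at the top of measures.py, assuming the models package imports cleanly under TensorFlow 1.x. Each prediction is scored against a list of acceptable ground truths and the maximum is kept.

from models.bert.measures import computeF1, computeEM

outputs = ["steve smith", "the answer"]
targets = [["steve smith 's", "steve smith"], ["an answer"]]
print(computeF1(outputs, targets))  # 75.0: F1 of 1.0 for the first pair, 0.5 for the second
print(computeEM(outputs, targets))  # 50.0: only the first prediction matches exactly
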
/models/mySeq2Seq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | RNN-based auto-encoder model.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | import tensorflow as tf
23 | import tensorflow.contrib.seq2seq as seq2seq
24 | from misc import input_data
25 | import models.utils as utils
26 |
27 | class Seq2SeqModel():
28 | def __init__(self, args, mode = None):
29 |
30 | self.mode = mode
31 | self.bidirectional = True if args.enc_type == 'bi' else False
32 | self.args = args
33 |
34 | # self.batch_size = args.batch_size
35 | self.batch_size = tf.placeholder(tf.int32, [], name='batch_size')
36 | self.beam_width = args.beam_width
37 |
38 | self._make_graph()
39 |
40 | def _make_graph(self):
41 |
42 | self._init_placeholders()
43 | self._init_decoder_train_connectors()
44 |
45 | with tf.variable_scope("my_seq2seq", reuse=tf.AUTO_REUSE) as scope:
46 | self._init_embedding()
47 |
48 | self._init_encoder()
49 | self._init_decoder()
50 |
51 | if self.mode == "Train":
52 | self._init_optimizer()
53 |
54 | def _init_placeholders(self):
55 | self.encoder_inputs = tf.placeholder(
56 | shape=(None, None),
57 | dtype=tf.int32,
58 | name='encoder_inputs'
59 | )
60 |
61 | self.encoder_inputs_length = tf.placeholder(
62 | shape=(None,),
63 | dtype=tf.int32,
64 | name='encoder_inputs_length',
65 | )
66 |
67 | self.decoder_inputs = tf.placeholder(
68 | shape=(None, None),
69 | dtype=tf.int32,
70 | name='decoder_inputs',
71 | )
72 |
73 | self.decoder_outputs = tf.placeholder(
74 | shape=(None, None),
75 | dtype=tf.int32,
76 | name='decoder_outputs',
77 | )
78 |
79 | self.decoder_targets_length = tf.placeholder(
80 | shape = (None,),
81 | dtype = tf.int32,
82 | name = 'decoder_targets_length',
83 | )
84 |
85 | self.classification_outputs = tf.placeholder(
86 | shape=(None, self.args.output_classes),
87 | dtype=tf.int32,
88 | name='classification_outputs',
89 | )
90 |
91 | def _init_decoder_train_connectors(self):
92 | with tf.name_scope('DecoderTrainFeeds'):
93 | self.decoder_train_length = self.decoder_targets_length
94 | self.loss_weights = tf.ones(
95 | [self.batch_size, tf.reduce_max(self.decoder_train_length)],
96 | dtype=tf.float32)
97 |
98 | def _init_embedding(self):
99 | self.embedding_encoder = input_data._create_pretrained_emb_from_txt(
100 | vocab_file=self.args.vocab_file, embed_file=self.args.emb_file)
101 | self.encoder_embedding_inputs = tf.nn.embedding_lookup(
102 | self.embedding_encoder,
103 | self.encoder_inputs) #[batch size, sequence len, h_dim]
104 |
105 | self.embedding_decoder = self.embedding_encoder
106 | self.decoder_embedding_inputs = tf.nn.embedding_lookup(
107 | self.embedding_decoder,
108 | self.decoder_inputs) #[batch size, sequence len, h_dim]
109 |
110 | def _init_encoder(self):
111 | with tf.variable_scope("Encoder") as scope:
112 | if self.bidirectional:
113 | fw_cell = tf.contrib.rnn.MultiRNNCell(
114 | [utils.make_cell(self.args.enc_num_units, utils.get_device_str(self.args.num_gpus)) for _ in range(self.args.enc_layers)])
115 | bw_cell = tf.contrib.rnn.MultiRNNCell(
116 | [utils.make_cell(self.args.enc_num_units, utils.get_device_str(self.args.num_gpus)) for _ in range(self.args.enc_layers)])
117 | bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
118 | fw_cell,
119 | bw_cell,
120 | self.encoder_embedding_inputs,
121 | dtype=tf.float32,
122 | sequence_length=self.encoder_inputs_length)
123 |
124 | self.encoder_outputs = tf.concat(bi_outputs, -1) #[batch size, sequence len, h_dim*2]
125 | bi_encoder_state = bi_state
126 |
127 | if self.args.enc_layers == 1:
128 | self.encoder_state = bi_encoder_state
129 | else:
130 | # alternatively concat forward and backward states
131 | encoder_state = []
132 | for layer_id in range(self.args.enc_layers):
133 | encoder_state.append(bi_encoder_state[0][layer_id]) # forward
134 | encoder_state.append(bi_encoder_state[1][layer_id]) # backward
135 | self.encoder_state = tuple(encoder_state)
136 | else:
137 | encoder_cell = tf.contrib.rnn.MultiRNNCell([utils.make_cell(self.args.enc_num_units, utils.get_device_str(self.args.num_gpus)) for _ in range(self.args.enc_layers)])
138 | self.encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(
139 | cell=encoder_cell,
140 | inputs=self.encoder_embedding_inputs,
141 | sequence_length=self.encoder_inputs_length,
142 | dtype=tf.float32
143 | ) #[batch size, sequence len, h_dim]
144 |
145 | def _init_decoder(self):
146 |
147 | def create_decoder_cell():
148 | cell = tf.contrib.rnn.MultiRNNCell([utils.make_cell(self.args.enc_num_units, utils.get_device_str(self.args.num_gpus)) for _ in range(self.args.dec_layers)])
149 |
150 | if self.args.beam_width > 0 and self.mode == "Infer":
151 | dec_start_state = seq2seq.tile_batch(self.encoder_state, self.beam_width)
152 | enc_outputs = seq2seq.tile_batch(self.encoder_outputs, self.beam_width)
153 | enc_lengths = seq2seq.tile_batch(self.encoder_inputs_length, self.beam_width)
154 | else:
155 | dec_start_state = self.encoder_state
156 | enc_outputs = self.encoder_outputs
157 | enc_lengths = self.encoder_inputs_length
158 |
159 | if self.args.attention:
160 | attention_states = enc_outputs
161 |
162 | attention_mechanism = tf.contrib.seq2seq.LuongAttention(
163 | self.args.dec_num_units,
164 | attention_states,
165 | memory_sequence_length = enc_lengths)
166 |
167 | decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
168 | cell,
169 | attention_mechanism,
170 | attention_layer_size = self.args.dec_num_units)
171 |
172 | if self.args.beam_width > 0 and self.mode == "Infer":
173 | initial_state = decoder_cell.zero_state(self.batch_size * self.beam_width, tf.float32)
174 | else:
175 | initial_state = decoder_cell.zero_state(self.batch_size, tf.float32)
176 |
177 | initial_state = initial_state.clone(cell_state = dec_start_state)
178 | else:
179 |
180 | decoder_cell = cell
181 | initial_state = dec_start_state
182 |
183 | return decoder_cell, initial_state
184 |
185 | with tf.variable_scope("Decoder") as scope:
186 |
187 | projection_layer = tf.layers.Dense(self.args.vocab_size, use_bias=False, name="projection_layer")
188 |
189 | self.encoder_state = tuple(self.encoder_state[-2:])
190 |
191 | decoder_cell, initial_state = create_decoder_cell()
192 |
193 | if self.mode == "Train":
194 | training_helper = tf.contrib.seq2seq.TrainingHelper(
195 | self.decoder_embedding_inputs,
196 | self.decoder_train_length)
197 |
198 |
199 | training_decoder = tf.contrib.seq2seq.BasicDecoder(
200 | cell = decoder_cell,
201 | helper = training_helper,
202 | initial_state = initial_state,
203 | output_layer=projection_layer)
204 |
205 | (self.decoder_outputs_train,
206 | self.decoder_state_train,
207 | final_sequence_length) = tf.contrib.seq2seq.dynamic_decode(
208 | decoder = training_decoder,
209 | impute_finished = True,
210 | scope = scope
211 | )
212 |
213 | self.decoder_logits_train = self.decoder_outputs_train.rnn_output
214 | decoder_predictions_train = tf.argmax(self.decoder_logits_train, axis=-1)
215 | self.decoder_predictions_train = tf.identity(decoder_predictions_train)
216 |
217 | elif self.mode == "Infer":
218 | start_tokens = tf.tile(tf.constant([input_data.SOS_ID], dtype=tf.int32), [self.batch_size])
219 |
220 | if self.args.beam_width > 0:
221 | inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
222 | cell = decoder_cell,
223 | embedding = self.embedding_decoder,
224 | start_tokens = tf.ones_like(self.encoder_inputs_length) * tf.constant(input_data.SOS_ID, dtype = tf.int32),
225 | end_token = tf.constant(input_data.EOS_ID, dtype = tf.int32),
226 | initial_state = initial_state,
227 | beam_width = self.beam_width,
228 | output_layer = projection_layer)
229 |
230 | self.decoder_outputs_inference, __, ___ = tf.contrib.seq2seq.dynamic_decode(
231 | decoder = inference_decoder,
232 | maximum_iterations = tf.round(tf.reduce_max(self.decoder_targets_length)) * 2,
233 | impute_finished = False,
234 | scope = scope)
235 |
236 | self.decoder_predictions_inference = tf.identity(self.decoder_outputs_inference.predicted_ids)
237 | else:
238 | inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
239 | self.embedding_decoder,
240 | start_tokens = start_tokens,
241 | end_token=input_data.EOS_ID) # EOS id
242 |
243 | inference_decoder = tf.contrib.seq2seq.BasicDecoder(
244 | cell = decoder_cell,
245 | helper = inference_helper,
246 | initial_state = initial_state,
247 | output_layer = projection_layer)
248 |
249 | self.decoder_outputs_inference, _, _ = tf.contrib.seq2seq.dynamic_decode(
250 | decoder = inference_decoder,
251 | maximum_iterations = tf.round(tf.reduce_max(self.decoder_targets_length)) * 2,
252 | impute_finished = False,
253 | scope = scope)
254 |
255 | self.decoder_predictions_inference = tf.identity(self.decoder_outputs_inference.sample_id)
256 |
257 | def _init_optimizer(self):
258 | loss_mask = tf.sequence_mask(
259 | tf.to_int32(self.decoder_targets_length),
260 | tf.reduce_max(self.decoder_targets_length),
261 | dtype = tf.float32)
262 | self.loss = tf.contrib.seq2seq.sequence_loss(
263 | self.decoder_logits_train,
264 | self.decoder_outputs,
265 | loss_mask)
266 | tf.summary.scalar('loss', self.loss)
267 | self.summary_op = tf.summary.merge_all()
268 |
269 | learning_rate = self.args.learning_rate
270 | optimizer = tf.train.AdamOptimizer(learning_rate)
271 | gradients = optimizer.compute_gradients(self.loss)
272 | capped_gradients = [(tf.clip_by_value(grad, -1.0*self.args.max_gradient_norm, self.args.max_gradient_norm), var) for grad, var in gradients if grad is not None]
273 | self.train_op = optimizer.apply_gradients(capped_gradients)
274 |
275 |
276 | # Inputs and Outputs for train and infer
277 | def make_train_inputs(self, x):
278 | return {
279 | self.encoder_inputs: x[0],
280 | self.decoder_inputs: x[1],
281 | self.decoder_outputs: x[2],
282 | self.classification_outputs: x[3],
283 | self.encoder_inputs_length: x[4],
284 | self.decoder_targets_length: x[5],
285 | self.batch_size: len(x[0])
286 | }
287 |
288 | def make_train_outputs(self, full_loss_step=True, defence=False):
289 | return [self.train_op, self.loss, self.decoder_predictions_train, self.summary_op]
290 |
291 | def make_infer_outputs(self):
292 | return self.decoder_predictions_inference
293 |
294 | def make_eval_outputs(self):
295 | return self.loss
296 |
--------------------------------------------------------------------------------
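
A minimal sketch (not repository code) of the batch layout Seq2SeqModel.make_train_inputs expects, inferred from the placeholder feeds above; the token ids follow misc.input_data (SOS_ID=1, EOS_ID=2). The commented lines assume a TF 1.x session and an args object carrying the attributes read in __init__ (vocab_file, emb_file, enc_type, enc/dec_num_units, enc/dec_layers, attention, beam_width, output_classes, vocab_size, learning_rate, max_gradient_norm, num_gpus).

import numpy as np

batch = (
    np.array([[1, 7, 8, 2]], dtype=np.int32),  # x[0] encoder_inputs
    np.array([[1, 7, 8]], dtype=np.int32),     # x[1] decoder_inputs, starts with SOS
    np.array([[7, 8, 2]], dtype=np.int32),     # x[2] decoder_outputs, ends with EOS
    np.array([[1, 0]], dtype=np.int32),        # x[3] classification_outputs (one-hot, 2 classes)
    np.array([4], dtype=np.int32),             # x[4] encoder_inputs_length
    np.array([3], dtype=np.int32),             # x[5] decoder_targets_length
)
# model = Seq2SeqModel(args, mode="Train")
# _, loss, preds, summary = sess.run(model.make_train_outputs(),
#                                    feed_dict=model.make_train_inputs(batch))
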
/models/bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenziation."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia.).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 | """Runs WordPiece tokenziation."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 | already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 | # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat in ("Cc", "Cf"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
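
A minimal sketch (not repository code) of the greedy longest-match-first WordPiece algorithm implemented above, using a toy in-memory vocab; it assumes the module's dependencies (TensorFlow 1.x, six) are installed.

from models.bert.tokenization import WordpieceTokenizer

vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
tokenizer = WordpieceTokenizer(vocab=vocab)
print(tokenizer.tokenize("unaffable"))     # ['un', '##aff', '##able']
print(tokenizer.tokenize("unbelievable"))  # ['[UNK]'] -- the toy vocab cannot cover it
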
/yelp_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Dataset preprocessing.
18 | Author: Ying Xu
19 | Date: Jul 8, 2020
20 | """
21 |
22 | from __future__ import unicode_literals
23 | from argparse import ArgumentParser
24 |
25 | import json, codecs
26 | from tqdm import tqdm
27 | from collections import Counter
28 | import spacy
29 | nlp = spacy.blank("en")
30 | GLOVE_WORD_SIZE = int(2.2e6)
31 | CF_WORD_SIZE = 65713
32 |
33 | parser = ArgumentParser()
34 | parser.add_argument('--data_dir', default='/Users/yxu132/Downloads/yelp_dataset', type=str, help='path to DATA_DIR')
35 | parser.add_argument('--embed_file', default='/Users/yxu132/pub-repos/decaNLP/embeddings/glove.840B.300d.txt', type=str, help='path to glove embedding file')
36 | parser.add_argument('--para_limit', default=50, type=int, help='maximum number of words for each paragraph')
37 | args = parser.parse_args()
38 |
39 | def parse_json():
40 | texts = []
41 | ratings = []
42 | for line in open(args.data_dir+'/yelp_academic_dataset_review.json', 'r'):
43 | # for line in open(args.data_dir + '/sample.json', 'r'):
44 | example = json.loads(line)
45 | texts.append(example['text'].replace('\n', ' ').replace('\r', ''))
46 | ratings.append(example['stars'])
47 | with open(args.data_dir+'/yelp_review.full', 'w') as output_file:
48 | output_file.write('\n'.join(texts))
49 | with open(args.data_dir+'/yelp_review.ratings', 'w') as output_file:
50 | output_file.write('\n'.join([str(rating) for rating in ratings]))
51 |
52 | def readLinesList(filename):
53 | ret = []
54 | for line in open(filename, 'r'):
55 | ret.append(line.strip())
56 | return ret
57 |
58 | def read_lines():
59 | ret = []
60 | labels = readLinesList(args.data_dir+'/yelp_review.ratings')
61 | for ind, line in tqdm(enumerate(open(args.data_dir+'/yelp_review.full', 'r'))):
62 | line = line.strip().lower()
63 | line = line.replace('\\n', ' ').replace('\\', '')
64 | line = line.replace('(', ' (').replace(')', ') ')
65 | line = line.replace('!', '! ')
66 | line = ' '.join(line.split())
67 | example = {}
68 | example['text'] = line
69 | example['label'] = labels[ind]
70 | ret.append(example)
71 | return ret
72 |
73 | def get_tokenize(sent):
74 | sent = sent.replace(
75 | "''", '" ').replace("``", '" ')
76 | doc = nlp(sent)
77 | context_tokens = [token.text for token in doc]
78 | new_sent = ' '.join(context_tokens)
79 | return new_sent, context_tokens
80 |
81 | def tokenize_sentences(sentences, para_limit=None):
82 | print('Tokenize input sentences...')
83 | word_counter = Counter()
84 | context_list, context_tokens_list = [], []
85 | labels = []
86 | for sentence in tqdm(sentences):
87 | context, context_tokens = get_tokenize(sentence['text'])
88 | if len(context_tokens) > para_limit:
89 | continue
90 | for token in context_tokens:
91 | word_counter[token] += 1
92 | context_list.append(context)
93 | context_tokens_list.append(context_tokens)
94 | labels.append(sentence['label'])
95 | return context_list, context_tokens_list, labels, word_counter
96 |
97 | def filter_against_embedding(sentences, counter, emb_file, limit=-1,
98 | size=GLOVE_WORD_SIZE, vec_size=300):
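   |     # Keep only sentences in which every token has a pretrained (GloVe) vector:
   |     # first restrict the embedding file to tokens seen in `counter`, then drop
   |     # any sentence containing a token outside that embedding vocabulary.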
99 |
100 | embedding_dict = {}
101 | filtered_elements = [k for k, v in counter.items() if v > limit]
102 | assert size is not None
103 | assert vec_size is not None
104 | with codecs.open(emb_file, "r", "utf-8") as fh:
105 | for line in tqdm(fh, total=size):
106 | array = line.split()
107 | word = "".join(array[0:-vec_size])
108 | vector = list(map(float, array[-vec_size:]))
109 | if word in counter and counter[word] > limit:
110 | embedding_dict[word] = vector
111 | print("{} / {} tokens have corresponding embedding vector".format(
112 | len(embedding_dict), len(filtered_elements)))
113 |
114 | embedding_tokens = set(embedding_dict.keys())
115 | filtered_sentences = []
116 | for sentence in sentences:
117 | tokens = sentence['text'].split()
118 | if len(set(tokens) - embedding_tokens) > 0:
119 | continue
120 | filtered_sentences.append(sentence)
121 |
122 | return filtered_sentences, embedding_dict
123 |
124 | def writeLines(llist, output_file):
125 | with codecs.open(output_file, "w", "utf-8") as output:
126 | output.write('\n'.join(llist))
127 |
128 | def get_embedding(counter, data_type, emb_file, limit=-1, size=None, vec_size=None):
129 | print("Generating {} embedding...".format(data_type))
130 | embedding_dict = {}
131 |
132 | filtered_elements = [k for k, v in counter.items() if v > limit]
133 | assert size is not None
134 | assert vec_size is not None
135 | with codecs.open(emb_file, "r", "utf-8") as fh:
136 | for line in tqdm(fh, total=size):
137 | array = line.split()
138 | word = "".join(array[0:-vec_size])
139 | vector = list(map(float, array[-vec_size:]))
140 | if word in counter and counter[word] > limit:
141 | embedding_dict[word] = vector
142 | missing_words = set(filtered_elements) - set(embedding_dict.keys())
143 | print('\n'.join(missing_words))
144 |
145 | print("{} / {} tokens have corresponding {} embedding vector".format(
146 | len(embedding_dict), len(filtered_elements), data_type))
147 |
148 | token2idx_dict = {token: idx for idx,
149 | token in enumerate(embedding_dict.keys(), 0)}
150 |
151 | idx2emb_dict = {idx: embedding_dict[token]
152 | for token, idx in token2idx_dict.items()}
153 | emb_mat = [idx2emb_dict[idx] for idx in range(len(idx2emb_dict))]
154 | return emb_mat, token2idx_dict
155 |
156 | def embed_sentences(word_counter, word_emb_file):
157 | word_emb_mat, word2idx_dict = get_embedding(
158 | word_counter, "word", emb_file=word_emb_file, size=GLOVE_WORD_SIZE, vec_size=300)
159 | return word_emb_mat, word2idx_dict
160 |
161 | def save(filename, obj, message=None):
162 | if message is not None:
163 | print("Saving {}...".format(message))
164 | with open(filename, "w") as fh:
165 | json.dump(obj, fh)
166 |
167 | import numpy as np
168 | def process():
169 |
170 | print("Step 2.1: Tokenize sentences...")
171 | sentences = read_lines()
172 | context_list, context_tokens_list, labels, word_counter = \
173 | tokenize_sentences(sentences, para_limit=args.para_limit)
174 | writeLines(context_list, args.data_dir+'/yelp.in')
175 | writeLines(labels, args.data_dir+'/yelp.out')
176 |
177 | print("\nStep 2.2: Filter dataset against glove embedding...")
178 | texts = readLinesList(args.data_dir+'/yelp.in')
179 | labels = readLinesList(args.data_dir+'/yelp.out')
180 | sentences = []
181 | for ind, text in enumerate(texts):
182 | sentence = {}
183 | sentence['text'] = text
184 | sentence['label'] = labels[ind]
185 | sentences.append(sentence)
186 | print('\nbefore filtering: '+str(len(sentences)))
187 |
188 | filtered_sentences, embed_dict = filter_against_embedding(sentences, word_counter, emb_file=args.embed_file)
189 | print('\nafter filtering: '+str(len(filtered_sentences)))
190 |
191 | texts = [sentence['text'] for sentence in filtered_sentences]
192 | labels = [sentence['label'] for sentence in filtered_sentences]
193 | writeLines(texts, args.data_dir + '/yelp_filtered.in')
194 | writeLines(labels,args.data_dir + '/yelp_filtered.out')
195 |
196 |
197 | print("\nStep 2.3: Split into train, dev and test datasets...")
198 | dev_test_percentage = 0.05
199 | sentences = []
200 | texts = readLinesList(args.data_dir+'/yelp_filtered.in')
201 | labels = readLinesList(args.data_dir+'/yelp_filtered.out')
202 | for ind, text in enumerate(texts):
203 | sentence={}
204 | sentence['text'] = text
205 | sentence['label'] = labels[ind]
206 | sentences.append(sentence)
207 | sentences = np.array(sentences)
208 |
209 | total = len(sentences)
210 | dev_test_num = int(total * dev_test_percentage)
211 | dev = sentences[:dev_test_num]
212 | test = sentences[dev_test_num: dev_test_num*2]
213 | train = sentences[dev_test_num*2: ]
214 |
215 | writeLines([sent['text'] for sent in train], args.data_dir + '/yelp_train.in')
216 | writeLines([sent['text'] for sent in dev], args.data_dir + '/yelp_dev.in')
217 | writeLines([sent['text'] for sent in test], args.data_dir + '/yelp_test.in')
218 | writeLines([sent['label'] for sent in train], args.data_dir + '/yelp_train.out')
219 | writeLines([sent['label'] for sent in dev], args.data_dir + '/yelp_dev.out')
220 | writeLines([sent['label'] for sent in test], args.data_dir + '/yelp_test.out')
221 |
222 | print("Step 2.4: Extract embeddings for filtered sentence vocabs...")
223 |
224 | sentences_tokens = [sent['text'].split() for sent in sentences]
225 | word_counter = dict()
226 | for sentence in sentences_tokens:
227 | for token in sentence:
228 | if token in word_counter:
229 | word_counter[token] = word_counter[token] + 1
230 | else:
231 | word_counter[token] = 1
232 |
233 | word_counter_new = sorted(word_counter.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
234 |
235 |     with codecs.open(args.data_dir + '/vocab_count.txt', "w", "utf-8") as vocab_output_file:
236 |         for word in word_counter_new:
237 |             vocab_output_file.write(word[0]+' '+str(word[1])+'\n')
238 |
239 | word_emb_mat, word2idx_dict = embed_sentences(word_counter, word_emb_file=args.embed_file)
240 | writeLines(word2idx_dict.keys(), args.data_dir + '/vocab.in')
241 | save(args.data_dir + '/emb.json', word_emb_mat, message="word embedding")
242 |
243 | def binarise_and_balance():
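    |     # Map star ratings to binary sentiment (1-2 stars -> negative, 4-5 -> positive,
    |     # 3-star reviews are dropped), downsample positives to the number of negatives,
    |     # shuffle, and write one-hot labels ('1.0 0.0' = negative, '0.0 1.0' = positive).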
244 | partitions = ['train', 'dev', 'test']
245 |
246 | for partition in partitions:
247 | sentences = readLinesList(args.data_dir+'/yelp_'+partition+'.in')
248 | pos_sents, neg_sents = [], []
249 | for ind, line in enumerate(open(args.data_dir+'/yelp_'+partition+'.out', 'r')):
250 | if line.strip() == '1.0' or line.strip() == '2.0':
251 | neg_sents.append(sentences[ind])
252 | elif line.strip() == '4.0' or line.strip() == '5.0':
253 | pos_sents.append(sentences[ind])
254 |
255 | np.random.seed(0)
256 | shuffled_ids = np.arange(len(pos_sents))
257 | np.random.shuffle(shuffled_ids)
258 | pos_sents = np.array(pos_sents)[shuffled_ids]
259 |
260 | sents = neg_sents + pos_sents.tolist()[:len(neg_sents)]
261 | labels = ['1.0 0.0'] * len(neg_sents) + ['0.0 1.0'] * len(neg_sents)
262 |
263 | shuffled_ids = np.arange(len(sents))
264 | np.random.shuffle(shuffled_ids)
265 | sents = np.array(sents)[shuffled_ids]
266 | labels = np.array(labels)[shuffled_ids]
267 |
268 | with open(args.data_dir+'/'+partition+'.in', 'w') as output_file:
269 | for line in sents:
270 | output_file.write(line+'\n')
271 | with open(args.data_dir+'/'+partition+'.out', 'w') as output_file:
272 | for line in labels:
273 | output_file.write(line+'\n')
274 |
275 |
276 | ###################### CF embedding ###################
277 |
278 | import codecs
279 | import os
280 |
281 | def parse_cf_emb(cf_file_path):
282 |
283 | vocab = []
284 | matrix = []
285 | for line in tqdm(open(cf_file_path, 'r'), total=CF_WORD_SIZE):
286 | comps = line.strip().split()
287 | word = ''.join(comps[0:-300])
288 | vec = comps[-300:]
289 | vocab.append(word)
290 | matrix.append(vec)
291 | writeLines(vocab, 'embeddings/counter-fitted-vectors-vocab.txt')
292 | json.dump(matrix, open('embeddings/counter-fitted-vectors-emb.json', 'w'))
293 |
294 | def transform_cf_emb():
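    |     # Build a counter-fitted (CF) vocabulary/embedding aligned with this dataset:
    |     # words from the GloVe vocab that have no CF vector are appended to the CF
    |     # vocabulary with their GloVe vectors, so every dataset word keeps an embedding.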
295 |
296 | if not os.path.exists('embeddings/counter-fitted-vectors-vocab.txt') or \
297 | not os.path.exists('embeddings/counter-fitted-vectors-emb.json'):
298 | parse_cf_emb('embeddings/counter-fitted-vectors.txt')
299 |
300 |
301 |     vocab = readLinesList(args.data_dir + '/vocab.in')  # written by process(), in the same order as emb.json
302 | cf_vocab = readLinesList('embeddings/counter-fitted-vectors-vocab.txt')
303 |
304 | print('glove_vocab_size: '+str(len(vocab)))
305 | print('cf_vocab_size: ' + str(len(cf_vocab)))
306 |
307 | with codecs.open(args.data_dir + '/emb.json', "r", "utf-8") as fh:
308 | emb = json.load(fh)
309 |
310 | with codecs.open('embeddings/counter-fitted-vectors-emb.json', "r", "utf-8") as fh:
311 | cf_emb = json.load(fh)
312 |
313 | vocab_diff = []
314 | vocab_diff_ind = []
315 | for ind, word in enumerate(vocab):
316 | if word not in cf_vocab:
317 | vocab_diff.append(word)
318 | vocab_diff_ind.append(ind)
319 |
320 | print('extend_vocab_size: ' + str(len(vocab_diff_ind)))
321 |
322 |
323 | new_cf_vocab = cf_vocab + vocab_diff
324 | new_emb = cf_emb
325 | for ind, word in enumerate(vocab_diff):
326 | new_emb.append(emb[vocab_diff_ind[ind]])
327 |
328 | print('combined_cf_vocab_size: ' + str(len(new_emb)))
329 |
330 | writeLines(new_cf_vocab, args.data_dir + '/cf_vocab.in')
331 | json.dump(new_emb, open(args.data_dir + '/cf_emb.json', 'w'))
332 |
333 | def split_pos_neg():
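    |     # Split the balanced training set into positive-only and negative-only
    |     # files (used for conditional generation), keyed on the one-hot labels.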
334 |
335 | input_sents = readLinesList(args.data_dir+'/train.in')
336 | labels = readLinesList(args.data_dir+'/train.out')
337 |
338 | pos_out_file = open(args.data_dir+'/train.pos.in', 'w')
339 | neg_out_file = open(args.data_dir+'/train.neg.in', 'w')
340 | pos_lab_file = open(args.data_dir+'/train.pos.out', 'w')
341 | neg_lab_file = open(args.data_dir+'/train.neg.out', 'w')
342 |
343 | for ind, sent in enumerate(input_sents):
344 | label = labels[ind]
345 | if label == '1.0 0.0':
346 | neg_out_file.write(sent+'\n')
347 | neg_lab_file.write(label+'\n')
348 | elif label == '0.0 1.0':
349 | pos_out_file.write(sent + '\n')
350 | pos_lab_file.write(label + '\n')
351 |
352 |
353 | if __name__ == '__main__':
354 | print("Step 1: Parse json file...")
355 | parse_json()
356 | print("\nStep 2: Data partition/GloVe embedding extraction...")
357 | process()
358 |     print("\nStep 3: Binarise and downsample...")
359 | binarise_and_balance()
360 | print("\nStep 4: Counter-fitted embedding extraction...")
361 | transform_cf_emb()
362 | print("\nStep 5: Split train set into pos/neg examples (for conditional generation only)...")
363 | split_pos_neg()
364 |
365 |
366 |
--------------------------------------------------------------------------------
/misc/acc_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Adv-Def-Text Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Author: Jey Han Lau
18 | Date: Jul 19
19 | """
20 |
21 | import torch
22 | import numpy as np
23 | from transformers import XLNetTokenizer, XLNetLMHeadModel, BertTokenizer, BertForMaskedLM
24 | from scipy.special import softmax
25 | from argparse import ArgumentParser
26 |
27 | def config():
28 | parser = ArgumentParser()
29 | # basic
30 | parser.add_argument('--file_dir', type=str, default=None, help="data directory")
31 | parser.add_argument('--ids_file', type=str, default=None, help="list of ids to eval")
32 | parser.add_argument('--id', type=str, default=None, help="single setting to evaluate")
33 |     parser.add_argument('--parsed_file', type=str, default=None, help='file of tab-separated lines alternating source and adversarial sentences')
34 | parser.add_argument('--accept_name', type=str, default='xlnet', help='bert or xlnet')
35 |
36 | args = parser.parse_args()
37 |
   |     # Load the tokenizer/LM matching --accept_name ('bert-base-uncased' assumed for the BERT path)
   |     if args.accept_name == 'bert':
   |         model_name = 'bert-base-uncased'
   |         args.tokenizer = BertTokenizer.from_pretrained(model_name)
   |         args.acpt_model = BertForMaskedLM.from_pretrained(model_name)
   |     else:
38 |         model_name = 'xlnet-large-cased'
39 |         args.tokenizer = XLNetTokenizer.from_pretrained(model_name)
40 |         args.acpt_model = XLNetLMHeadModel.from_pretrained(model_name)
41 |
42 | args.device = torch.device('cuda:0')
43 | args.acpt_model.to(args.device)
44 | args.acpt_model.eval()
45 |
46 | return args
47 |
48 | # global
49 | PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
50 | (except for Alexei and Maria) are discovered.
51 | The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
52 | remainder of the story. 1883 Western Siberia,
53 | a young Grigori Rasputin is asked by his father and a group of men to perform magic.
54 | Rasputin has a vision and denounces one of the men as a horse thief. Although his
55 | father initially slaps him for making such an accusation, Rasputin watches as the
56 | man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
57 | the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
58 | with people, even a bishop, begging for his blessing. """
59 |
60 |
61 | def readlines(file_name):
62 | ret = []
63 | for line in open(file_name, 'r'):
64 | ret.append(line.strip())
65 | return ret
66 |
67 | # ###########
68 | # # functions#
69 | # ###########
70 | # def model_score(tokenize_input, tokenize_context, model, tokenizer, device, model_name='xlnet', use_context=False):
71 | # if model_name.startswith("gpt"):
72 | # if not use_context:
73 | # # prepend the sentence with <|endoftext|> token, so that the loss is computed correctly
74 | # tensor_input = torch.tensor([[50256] + tokenizer.convert_tokens_to_ids(tokenize_input)], device=device)
75 | # loss = model(tensor_input, labels=tensor_input)
76 | # else:
77 | # tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_context + tokenize_input)], device=device)
78 | # labels = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_context + tokenize_input)], device=device)
79 | # # -1 label for context (loss not computed over these tokens)
80 | # labels[:, :len(tokenize_context)] = -1
81 | # loss = model(tensor_input, labels=labels)
82 | # return float(loss[0]) * -1.0 * len(tokenize_input)
83 | # elif model_name.startswith("bert"):
84 | # batched_indexed_tokens = []
85 | # batched_segment_ids = []
86 | # if not use_context:
87 | # tokenize_combined = ["[CLS]"] + tokenize_input + ["[SEP]"]
88 | # else:
89 | # tokenize_combined = ["[CLS]"] + tokenize_context + tokenize_input + ["[SEP]"]
90 | # for i in range(len(tokenize_input)):
91 | # # Mask a token that we will try to predict back with `BertForMaskedLM`
92 | # masked_index = i + 1 + (len(tokenize_context) if use_context else 0)
93 | # tokenize_masked = tokenize_combined.copy()
94 | # tokenize_masked[masked_index] = '[MASK]'
95 | #
96 | # # Convert token to vocabulary indices
97 | # indexed_tokens = tokenizer.convert_tokens_to_ids(tokenize_masked)
98 | # # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
99 | # segment_ids = [0] * len(tokenize_masked)
100 | #
101 | # batched_indexed_tokens.append(indexed_tokens)
102 | # batched_segment_ids.append(segment_ids)
103 | #
104 | # # Convert inputs to PyTorch tensors
105 | # tokens_tensor = torch.tensor(batched_indexed_tokens, device=device)
106 | # segment_tensor = torch.tensor(batched_segment_ids, device=device)
107 | #
108 | # # Predict all tokens
109 | # with torch.no_grad():
110 | # outputs = model(tokens_tensor, token_type_ids=segment_tensor)
111 | # predictions = outputs[0]
112 | # # go through each word and sum their logprobs
113 | # lp = 0.0
114 | # for i in range(len(tokenize_input)):
115 | # masked_index = i + 1 + (len(tokenize_context) if use_context else 0)
116 | # predicted_score = predictions[i, masked_index]
117 | # predicted_prob = softmax(predicted_score.cpu().numpy())
118 | # lp += np.log(predicted_prob[tokenizer.convert_tokens_to_ids([tokenize_combined[masked_index]])[0]])
119 | # return lp
120 | # elif model_name.startswith("xlnet"):
121 | # xlnet_bidir=True
122 | # tokenize_ptext = tokenizer.tokenize(PADDING_TEXT.lower())
123 | # if not use_context:
124 | # tokenize_input2 = tokenize_ptext + tokenize_input
125 | # else:
126 | # tokenize_input2 = tokenize_ptext + tokenize_context + tokenize_input
127 | # # go through each word and sum their logprobs
128 | # lp = 0.0
129 | # for max_word_id in range((len(tokenize_input2) - len(tokenize_input)), (len(tokenize_input2))):
130 | # sent = tokenize_input2[:]
131 | # input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(sent)], device=device)
132 | # perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
133 | # # if not bidir, mask target word + right/future words
134 | # if not xlnet_bidir:
135 | # perm_mask[:, :, max_word_id:] = 1.0
136 | # # if bidir, mask only the target word
137 | # else:
138 | # perm_mask[:, :, max_word_id] = 1.0
139 | #
140 | # target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
141 | # target_mapping[0, 0, max_word_id] = 1.0
142 | # with torch.no_grad():
143 | # outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
144 | # next_token_logits = outputs[0]
145 | #
146 | # word_id = tokenizer.convert_tokens_to_ids([tokenize_input2[max_word_id]])[0]
147 | # predicted_prob = softmax((next_token_logits[0][-1]).cpu().numpy())
148 | # lp += np.log(predicted_prob[word_id])
149 | # return lp
150 |
151 | def bert_model_score(tokenize_input, args):
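    |     # Masked-LM pseudo-log-likelihood: build one copy of the sentence per position
    |     # with that token replaced by [MASK], score all copies as a single batch, and
    |     # sum the log-probabilities assigned to the original token at each masked slot.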
152 | batched_indexed_tokens = []
153 | batched_segment_ids = []
154 | # if not args.use_context:
155 | tokenize_combined = ["[CLS]"] + tokenize_input + ["[SEP]"]
156 | # else:
157 | # tokenize_combined = ["[CLS]"] + tokenize_context + tokenize_input + ["[SEP]"]
158 | for i in range(len(tokenize_input)):
159 | # Mask a token that we will try to predict back with `BertForMaskedLM`
160 | masked_index = i + 1
161 | tokenize_masked = tokenize_combined.copy()
162 | tokenize_masked[masked_index] = '[MASK]'
163 |
164 | # Convert token to vocabulary indices
165 | indexed_tokens = args.tokenizer.convert_tokens_to_ids(tokenize_masked)
166 | # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
167 | segment_ids = [0] * len(tokenize_masked)
168 |
169 | batched_indexed_tokens.append(indexed_tokens)
170 | batched_segment_ids.append(segment_ids)
171 |
172 | # Convert inputs to PyTorch tensors
173 | tokens_tensor = torch.tensor(batched_indexed_tokens, device=args.device)
174 | segment_tensor = torch.tensor(batched_segment_ids, device=args.device)
175 |
176 | # Predict all tokens
177 | with torch.no_grad():
178 | outputs = args.acpt_model(tokens_tensor, token_type_ids=segment_tensor)
179 | predictions = outputs[0]
180 | # go through each word and sum their logprobs
181 | lp = 0.0
182 | for i in range(len(tokenize_input)):
183 | masked_index = i + 1
184 | predicted_score = predictions[i, masked_index]
185 | predicted_prob = softmax(predicted_score.cpu().numpy())
186 | lp += np.log(predicted_prob[args.tokenizer.convert_tokens_to_ids([tokenize_combined[masked_index]])[0]])
187 | return lp
188 |
189 | def xlnet_model_score(tokenize_input, args):
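    |     # XLNet pseudo-log-likelihood: prepend PADDING_TEXT so short inputs have enough
    |     # context, then for each target position set perm_mask/target_mapping to hide
    |     # only that token and accumulate the log-probability of the true token.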
190 | xlnet_bidir = True
191 | tokenize_context = None
192 | tokenize_ptext = args.tokenizer.tokenize(PADDING_TEXT.lower())
193 | # if not args.use_context:
194 | tokenize_input2 = tokenize_ptext + tokenize_input
195 | # else:
196 | # tokenize_input2 = tokenize_ptext + tokenize_context + tokenize_input
197 | # go through each word and sum their logprobs
198 | lp = 0.0
199 | for max_word_id in range((len(tokenize_input2) - len(tokenize_input)), (len(tokenize_input2))):
200 | sent = tokenize_input2[:]
201 | input_ids = torch.tensor([args.tokenizer.convert_tokens_to_ids(sent)], device=args.device)
202 | perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=args.device)
203 | # if not bidir, mask target word + right/future words
204 | if not xlnet_bidir:
205 | perm_mask[:, :, max_word_id:] = 1.0
206 | # if bidir, mask only the target word
207 | else:
208 | perm_mask[:, :, max_word_id] = 1.0
209 |
210 | target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=args.device)
211 | target_mapping[0, 0, max_word_id] = 1.0
212 | with torch.no_grad():
213 | outputs = args.acpt_model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
214 | next_token_logits = outputs[0]
215 |
216 | word_id = args.tokenizer.convert_tokens_to_ids([tokenize_input2[max_word_id]])[0]
217 | predicted_prob = softmax((next_token_logits[0][-1]).cpu().numpy())
218 | lp += np.log(predicted_prob[word_id])
219 | return lp
220 |
221 | def evaluate_accept(references, translations, args):
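    |     # For each (reference, translation) pair (here: source vs. adversarial sentence),
    |     # compute a length-normalised acceptability score for both sentences and
    |     # return the differences (translation minus reference).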
222 |
223 | scores = []
224 |
225 | for ref, translation in zip(references, translations):
226 | tokenize_input = args.tokenizer.tokenize(ref)
227 | text_len = len(tokenize_input)
228 | if args.accept_name == 'bert':
229 | # compute sentence logprob
230 | lp = bert_model_score(tokenize_input, args)
231 | elif args.accept_name == 'xlnet':
232 | lp = xlnet_model_score(tokenize_input, args)
233 | # acceptability measures
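    |         # length penalty ((5 + |s|) / 6) ** 0.8, as in GNMT (Wu et al., 2016);
    |         # dividing the summed log-probability by it normalises for sentence length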
234 | penalty = ((5 + text_len) ** 0.8 / (5 + 1) ** 0.8)
235 | score_ref = lp / penalty
236 | # print(score_ref)
237 |
238 | tokenize_input = args.tokenizer.tokenize(translation)
239 | text_len = len(tokenize_input)
240 | if args.accept_name == 'bert':
241 | # compute sentence logprob
242 | lp = bert_model_score(tokenize_input, args)
243 | elif args.accept_name == 'xlnet':
244 | lp = xlnet_model_score(tokenize_input, args)
245 | # acceptability measures
246 | penalty = ((5 + text_len) ** 0.8 / (5 + 1) ** 0.8)
247 | score_trans = lp / penalty
248 | # print(score_trans)
249 | scores.append(score_trans-score_ref)
250 |
251 | return scores
252 |
253 |
254 | # ######
255 | # # main#
256 | # ######
257 | # def main(args):
258 | # # system scores
259 | # mean_lps = []
260 | # # Load pre-trained model and tokenizer
261 | # args.model_name = 'bert-base-uncased'
262 | # if args.model_name.startswith("gpt"):
263 | # model = GPT2LMHeadModel.from_pretrained(args.model_name)
264 | # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name)
265 | # elif args.model_name.startswith("bert"):
266 | # model = BertForMaskedLM.from_pretrained(args.model_name)
267 | # tokenizer = BertTokenizer.from_pretrained(args.model_name,
268 | # do_lower_case=(True if "uncased" in args.model_name else False))
269 | # elif args.model_name.startswith("xlnet"):
270 | # tokenizer = XLNetTokenizer.from_pretrained(args.model_name)
271 | # model = XLNetLMHeadModel.from_pretrained(args.model_name)
272 | # else:
273 | # print("Supported models: gpt, bert and xlnet only.")
274 | # raise SystemExit
275 | #
276 | # # put model to device (GPU/CPU)
277 | # device = torch.device(args.device)
278 | # model.to(device)
279 | #
280 | # # eval mode; no dropout
281 | # model.eval()
282 | #
283 | # text = "400 was great ! he was very helpful and i incubates comparitive back and visit again . i really playmobile the theme and the staff was very friendly ! a herb to guppy !"
284 | #
285 | # tokenize_input = tokenizer.tokenize(text)
286 | # text_len = len(tokenize_input)
287 | #
288 | # # compute sentence logprob
289 | # lp = model_score(tokenize_input, None, model, tokenizer, device, args.model_name)
290 | #
291 | # # acceptability measures
292 | # penalty = ((5 + text_len) ** 0.8 / (5 + 1) ** 0.8)
293 | # print("score=", lp / penalty)
294 |
295 |
296 |
297 | if __name__ == "__main__":
298 | args = config()
299 |
300 | if args.ids_file is not None:
301 | ids = readlines(args.ids_file)
302 | for id in ids:
303 | comps = id.split()
304 | src_sents = readlines(args.file_dir+'/'+comps[1]+'/s2s_output/src_changed.txt')
305 | tgt_sents = readlines(args.file_dir+'/'+comps[1]+'/s2s_output/adv_changed.txt')
306 | scores = evaluate_accept(src_sents, tgt_sents, args)
307 | print(comps[0]+': '+str(sum(scores)/len(scores)))
308 | with open(args.file_dir+'/'+comps[1]+'/s2s_output/accept_changed.txt', 'w') as output_file:
309 | for score in scores:
310 | output_file.write(str(score)+'\n')
311 | elif args.id is not None:
312 | src_sents = readlines(args.file_dir + '/' + args.id + '/s2s_output/src_changed.txt')
313 | tgt_sents = readlines(args.file_dir + '/' + args.id + '/s2s_output/adv_changed.txt')
314 | scores = evaluate_accept(src_sents, tgt_sents, args)
315 | print('accept score: ' + str(sum(scores) / len(scores)))
316 | with open(args.file_dir + '/' + args.id + '/s2s_output/accept_changed.txt', 'w') as output_file:
317 | for score in scores:
318 | output_file.write(str(score) + '\n')
319 | elif args.parsed_file is not None:
320 | src_sents, tgt_sents = [], []
321 | is_src = True
322 | for line in open(args.parsed_file, 'r'):
323 | if line.strip() != '':
324 | comps = line.strip().split('\t')
325 | if is_src:
326 | src_sents.append(comps[1])
327 | else:
328 | tgt_sents.append(comps[1])
329 | is_src = (not is_src)
330 | scores = evaluate_accept(src_sents, tgt_sents, args)
331 | print('accept score: ' + str(sum(scores) / len(scores)))
332 |
333 |
--------------------------------------------------------------------------------
/models/bert/input_data.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import json
3 | import collections
4 | import six
5 | import models.bert.tokenization as tokenization
6 |
7 | class SquadExample(object):
8 |     """A single training/test example for SQuAD question answering.
9 |
10 | For examples without an answer, the start and end position are -1.
11 | """
12 |
13 | def __init__(self,
14 | qas_id,
15 | question_text,
16 | doc_tokens,
17 | orig_answer_text=None,
18 | start_position=None,
19 | end_position=None,
20 | is_impossible=False):
21 | self.qas_id = qas_id
22 | self.question_text = question_text
23 | self.doc_tokens = doc_tokens
24 | self.orig_answer_text = orig_answer_text
25 | self.start_position = start_position
26 | self.end_position = end_position
27 | self.is_impossible = is_impossible
28 |
29 | def __str__(self):
30 | return self.__repr__()
31 |
32 | def __repr__(self):
33 | s = ""
34 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
35 | s += ", question_text: %s" % (
36 | tokenization.printable_text(self.question_text))
37 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
38 | if self.start_position:
39 | s += ", start_position: %d" % (self.start_position)
40 |         if self.end_position:
41 |             s += ", end_position: %d" % (self.end_position)
42 |         if self.is_impossible:
43 |             s += ", is_impossible: %r" % (self.is_impossible)
44 | return s
45 |
46 |
47 | class InputFeatures(object):
48 | """A single set of features of data."""
49 |
50 | def __init__(self,
51 | unique_id,
52 | example_index,
53 | doc_span_index,
54 | tokens,
55 | token_to_orig_map,
56 | token_is_max_context,
57 | input_ids,
58 | input_mask,
59 | segment_ids,
60 | start_position=None,
61 | end_position=None,
62 | is_impossible=None):
63 | self.unique_id = unique_id
64 | self.example_index = example_index
65 | self.doc_span_index = doc_span_index
66 | self.tokens = tokens
67 | self.token_to_orig_map = token_to_orig_map
68 | self.token_is_max_context = token_is_max_context
69 | self.input_ids = input_ids
70 | self.input_mask = input_mask
71 | self.segment_ids = segment_ids
72 | self.start_position = start_position
73 | self.end_position = end_position
74 | self.is_impossible = is_impossible
75 |
76 |
77 | class FeatureWriter(object):
78 |     """Writes InputFeatures to a TFRecord file as tf.train.Examples."""
79 |
80 | def __init__(self, filename, is_training):
81 | self.filename = filename
82 | self.is_training = is_training
83 | self.num_features = 0
84 | self._writer = tf.python_io.TFRecordWriter(filename)
85 |
86 | def process_feature(self, feature):
87 |         """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
88 | self.num_features += 1
89 |
90 | def create_int_feature(values):
91 | feature = tf.train.Feature(
92 | int64_list=tf.train.Int64List(value=list(values)))
93 | return feature
94 |
95 | features = collections.OrderedDict()
96 | features["unique_ids"] = create_int_feature([feature.unique_id])
97 | features["input_ids"] = create_int_feature(feature.input_ids)
98 | features["input_mask"] = create_int_feature(feature.input_mask)
99 | features["segment_ids"] = create_int_feature(feature.segment_ids)
100 |
101 | if self.is_training:
102 | features["start_positions"] = create_int_feature([feature.start_position])
103 | features["end_positions"] = create_int_feature([feature.end_position])
104 | impossible = 0
105 | if feature.is_impossible:
106 | impossible = 1
107 | features["is_impossible"] = create_int_feature([impossible])
108 |
109 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
110 | self._writer.write(tf_example.SerializeToString())
111 |
112 | def close(self):
113 | self._writer.close()
114 |
115 | def read_squad_examples(input_file, is_training):
116 | """Read a SQuAD json file into a list of SquadExample."""
117 | with tf.gfile.Open(input_file, "r") as reader:
118 | input_data = json.load(reader)["data"]
119 |
120 | def is_whitespace(c):
121 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
122 | return True
123 | return False
124 |
125 | examples = []
126 | for entry in input_data:
127 | # if len(examples) > 20:
128 | # break
129 | for paragraph in entry["paragraphs"]:
130 | # if len(examples) > 20:
131 | # break
132 | paragraph_text = paragraph["context"]
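    |       # Whitespace-tokenize the context into doc_tokens and record, for every
    |       # character, the index of the word it belongs to (char_to_word_offset),
    |       # so answer character offsets can be mapped to word positions below.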
133 | doc_tokens = []
134 | char_to_word_offset = []
135 | prev_is_whitespace = True
136 | for c in paragraph_text:
137 | if is_whitespace(c):
138 | prev_is_whitespace = True
139 | else:
140 | if prev_is_whitespace:
141 | doc_tokens.append(c)
142 | else:
143 | doc_tokens[-1] += c
144 | prev_is_whitespace = False
145 | char_to_word_offset.append(len(doc_tokens) - 1)
146 |
147 | for qa in paragraph["qas"]:
148 | # if len(examples) > 20:
149 | # break
150 | qas_id = qa["id"]
151 | question_text = qa["question"]
152 | start_position = None
153 | end_position = None
154 | orig_answer_text = None
155 | is_impossible = False
156 | if is_training:
157 | # if FLAGS.version_2_with_negative:
158 | # is_impossible = bert["is_impossible"]
159 | # if (len(bert["answers"]) != 1) and (not is_impossible):
160 | # raise ValueError(
161 | # "For training, each question should have exactly 1 answer.")
162 | if not is_impossible:
163 | answer = qa["answers"][0]
164 | orig_answer_text = answer["text"]
165 | answer_offset = answer["answer_start"]
166 | answer_length = len(orig_answer_text)
167 | start_position = char_to_word_offset[answer_offset]
168 | end_position = char_to_word_offset[answer_offset + answer_length -
169 | 1]
170 | # Only add answers where the text can be exactly recovered from the
171 | # document. If this CAN'T happen it's likely due to weird Unicode
172 | # stuff so we will just skip the example.
173 | #
174 | # Note that this means for training mode, every example is NOT
175 | # guaranteed to be preserved.
176 | actual_text = " ".join(
177 | doc_tokens[start_position:(end_position + 1)])
178 | cleaned_answer_text = " ".join(
179 | tokenization.whitespace_tokenize(orig_answer_text))
180 | if actual_text.find(cleaned_answer_text) == -1:
181 | tf.logging.warning("Could not find answer: '%s' vs. '%s'",
182 | actual_text, cleaned_answer_text)
183 | continue
184 | else:
185 | start_position = -1
186 | end_position = -1
187 | orig_answer_text = ""
188 | else:
189 | orig_answer_text=[]
190 | for qa_single in qa["answers"]:
191 | orig_answer_text.append(qa_single['text'])
192 |
193 | example = SquadExample(
194 | qas_id=qas_id,
195 | question_text=question_text,
196 | doc_tokens=doc_tokens,
197 | orig_answer_text=orig_answer_text,
198 | start_position=start_position,
199 | end_position=end_position,
200 | is_impossible=is_impossible)
201 | examples.append(example)
202 |
203 | return examples
204 |
205 |
206 | def _check_is_max_context(doc_spans, cur_span_index, position):
207 | """Check if this is the 'max context' doc span for the token."""
208 |
209 | # Because of the sliding window approach taken to scoring documents, a single
210 |   # token can appear in multiple document spans. E.g.
211 | # Doc: the man went to the store and bought a gallon of milk
212 | # Span A: the man went to the
213 | # Span B: to the store and bought
214 | # Span C: and bought a gallon of
215 | # ...
216 | #
217 | # Now the word 'bought' will have two scores from spans B and C. We only
218 | # want to consider the score with "maximum context", which we define as
219 | # the *minimum* of its left and right context (the *sum* of left and
220 | # right context will always be the same, of course).
221 | #
222 | # In the example the maximum context for 'bought' would be span C since
223 | # it has 1 left context and 3 right context, while span B has 4 left context
224 | # and 0 right context.
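    |   # Working through the example above (5-token spans): 'bought' is doc token 7;
    |   # in Span B (start=3) its score is min(4, 0) + 0.01 * 5 = 0.05, while in
    |   # Span C (start=6) it is min(1, 3) + 0.01 * 5 = 1.05, so Span C wins.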
225 | best_score = None
226 | best_span_index = None
227 | for (span_index, doc_span) in enumerate(doc_spans):
228 | end = doc_span.start + doc_span.length - 1
229 | if position < doc_span.start:
230 | continue
231 | if position > end:
232 | continue
233 | num_left_context = position - doc_span.start
234 | num_right_context = end - position
235 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
236 | if best_score is None or score > best_score:
237 | best_score = score
238 | best_span_index = span_index
239 |
240 | return cur_span_index == best_span_index
241 |
242 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
243 | orig_answer_text):
244 | """Returns tokenized answer spans that better match the annotated answer."""
245 |
246 | # The SQuAD annotations are character based. We first project them to
247 | # whitespace-tokenized words. But then after WordPiece tokenization, we can
248 | # often find a "better match". For example:
249 | #
250 | # Question: What year was John Smith born?
251 | # Context: The leader was John Smith (1895-1943).
252 | # Answer: 1895
253 | #
254 | # The original whitespace-tokenized answer will be "(1895-1943).". However
255 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
256 | # the exact answer, 1895.
257 | #
258 | # However, this is not always possible. Consider the following:
259 | #
260 |   # Question: What country is the top exporter of electronics?
261 |   # Context: The Japanese electronics industry is the largest in the world.
262 | # Answer: Japan
263 | #
264 | # In this case, the annotator chose "Japan" as a character sub-span of
265 | # the word "Japanese". Since our WordPiece tokenizer does not split
266 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
267 | # in SQuAD, but does happen.
268 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
269 |
270 | for new_start in range(input_start, input_end + 1):
271 | for new_end in range(input_end, new_start - 1, -1):
272 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
273 | if text_span == tok_answer_text:
274 | return (new_start, new_end)
275 |
276 | return (input_start, input_end)
277 |
278 | def convert_examples_to_features(examples, tokenizer, max_seq_length,
279 | doc_stride, max_query_length, is_training,
280 | output_fn):
281 | """Loads a data file into a list of `InputBatch`s."""
282 |
283 | unique_id = 1000000000
284 |
285 | for (example_index, example) in enumerate(examples):
286 | query_tokens = tokenizer.tokenize(example.question_text)
287 |
288 | if len(query_tokens) > max_query_length:
289 | query_tokens = query_tokens[0:max_query_length]
290 |
291 | tok_to_orig_index = []
292 | orig_to_tok_index = []
293 | all_doc_tokens = []
294 | for (i, token) in enumerate(example.doc_tokens):
295 | orig_to_tok_index.append(len(all_doc_tokens))
296 | sub_tokens = tokenizer.tokenize(token)
297 | for sub_token in sub_tokens:
298 | tok_to_orig_index.append(i)
299 | all_doc_tokens.append(sub_token)
300 |
301 | tok_start_position = None
302 | tok_end_position = None
303 | if is_training and example.is_impossible:
304 | tok_start_position = -1
305 | tok_end_position = -1
306 | if is_training and not example.is_impossible:
307 | tok_start_position = orig_to_tok_index[example.start_position]
308 | if example.end_position < len(example.doc_tokens) - 1:
309 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
310 | else:
311 | tok_end_position = len(all_doc_tokens) - 1
312 | (tok_start_position, tok_end_position) = _improve_answer_span(
313 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
314 | example.orig_answer_text)
315 |
316 | # The -3 accounts for [CLS], [SEP] and [SEP]
317 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
318 |
319 | # We can have documents that are longer than the maximum sequence length.
320 | # To deal with this we do a sliding window approach, where we take chunks
321 |     # of up to our max length with a stride of `doc_stride`.
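    |     # Illustrative example: with max_tokens_for_doc=100 and doc_stride=64, a
    |     # 180-token document produces spans (start=0, length=100),
    |     # (start=64, length=100) and (start=128, length=52).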
322 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
323 | "DocSpan", ["start", "length"])
324 | doc_spans = []
325 | start_offset = 0
326 | while start_offset < len(all_doc_tokens):
327 | length = len(all_doc_tokens) - start_offset
328 | if length > max_tokens_for_doc:
329 | length = max_tokens_for_doc
330 | doc_spans.append(_DocSpan(start=start_offset, length=length))
331 | if start_offset + length == len(all_doc_tokens):
332 | break
333 | start_offset += min(length, doc_stride)
334 |
335 | for (doc_span_index, doc_span) in enumerate(doc_spans):
336 | tokens = []
337 | token_to_orig_map = {}
338 | token_is_max_context = {}
339 | segment_ids = []
340 | tokens.append("[CLS]")
341 | segment_ids.append(0)
342 | for token in query_tokens:
343 | tokens.append(token)
344 | segment_ids.append(0)
345 | tokens.append("[SEP]")
346 | segment_ids.append(0)
347 |
348 | for i in range(doc_span.length):
349 | split_token_index = doc_span.start + i
350 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
351 |
352 | is_max_context = _check_is_max_context(doc_spans, doc_span_index,
353 | split_token_index)
354 | token_is_max_context[len(tokens)] = is_max_context
355 | tokens.append(all_doc_tokens[split_token_index])
356 | segment_ids.append(1)
357 | tokens.append("[SEP]")
358 | segment_ids.append(1)
359 |
360 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
361 |
362 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
363 | # tokens are attended to.
364 | input_mask = [1] * len(input_ids)
365 |
366 | # Zero-pad up to the sequence length.
367 | while len(input_ids) < max_seq_length:
368 | input_ids.append(0)
369 | input_mask.append(0)
370 | segment_ids.append(0)
371 |
372 | assert len(input_ids) == max_seq_length
373 | assert len(input_mask) == max_seq_length
374 | assert len(segment_ids) == max_seq_length
375 |
376 | start_position = None
377 | end_position = None
378 | if is_training and not example.is_impossible:
379 | # For training, if our document chunk does not contain an annotation
380 | # we throw it out, since there is nothing to predict.
381 | doc_start = doc_span.start
382 | doc_end = doc_span.start + doc_span.length - 1
383 | out_of_span = False
384 | if not (tok_start_position >= doc_start and
385 | tok_end_position <= doc_end):
386 | out_of_span = True
387 | if out_of_span:
388 | start_position = 0
389 | end_position = 0
390 | else:
391 | doc_offset = len(query_tokens) + 2
392 | start_position = tok_start_position - doc_start + doc_offset
393 | end_position = tok_end_position - doc_start + doc_offset
394 |
395 | if is_training and example.is_impossible:
396 | start_position = 0
397 | end_position = 0
398 |
399 | if example_index < 20:
400 | tf.logging.info("*** Example ***")
401 | tf.logging.info("unique_id: %s" % (unique_id))
402 | tf.logging.info("example_index: %s" % (example_index))
403 | tf.logging.info("doc_span_index: %s" % (doc_span_index))
404 | tf.logging.info("tokens: %s" % " ".join(
405 | [tokenization.printable_text(x) for x in tokens]))
406 | tf.logging.info("token_to_orig_map: %s" % " ".join(
407 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
408 | tf.logging.info("token_is_max_context: %s" % " ".join([
409 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
410 | ]))
411 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
412 | tf.logging.info(
413 | "input_mask: %s" % " ".join([str(x) for x in input_mask]))
414 | tf.logging.info(
415 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
416 | if is_training and example.is_impossible:
417 | tf.logging.info("impossible example")
418 | if is_training and not example.is_impossible:
419 | answer_text = " ".join(tokens[start_position:(end_position + 1)])
420 | tf.logging.info("start_position: %d" % (start_position))
421 | tf.logging.info("end_position: %d" % (end_position))
422 | tf.logging.info(
423 | "answer: %s" % (tokenization.printable_text(answer_text)))
424 |
425 | feature = InputFeatures(
426 | unique_id=unique_id,
427 | example_index=example_index,
428 | doc_span_index=doc_span_index,
429 | tokens=tokens,
430 | token_to_orig_map=token_to_orig_map,
431 | token_is_max_context=token_is_max_context,
432 | input_ids=input_ids,
433 | input_mask=input_mask,
434 | segment_ids=segment_ids,
435 | start_position=start_position,
436 | end_position=end_position,
437 | is_impossible=example.is_impossible)
438 |
439 | # Run callback
440 | output_fn(feature)
441 |
442 | unique_id += 1
443 |
444 | def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
445 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
446 |
447 | name_to_features = {
448 | "unique_ids": tf.FixedLenFeature([], tf.int64),
449 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
450 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
451 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
452 | }
453 |
454 | if is_training:
455 | name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
456 | name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)
457 |
458 | def _decode_record(record, name_to_features):
459 | """Decodes a record to a TensorFlow example."""
460 | example = tf.parse_single_example(record, name_to_features)
461 |
462 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
463 | # So cast all int64 to int32.
464 | for name in list(example.keys()):
465 | t = example[name]
466 | if t.dtype == tf.int64:
467 | t = tf.to_int32(t)
468 | example[name] = t
469 |
470 | return example
471 |
472 | def input_fn(params):
473 | """The actual input function."""
474 | batch_size = params["batch_size"]
475 |
476 | # For training, we want a lot of parallel reading and shuffling.
477 | # For eval, we want no shuffling and parallel reading doesn't matter.
478 | d = tf.data.TFRecordDataset(input_file)
479 | if is_training:
480 | d = d.repeat()
481 | d = d.shuffle(buffer_size=100)
482 |
483 | d = d.apply(
484 | tf.contrib.data.map_and_batch(
485 | lambda record: _decode_record(record, name_to_features),
486 | batch_size=batch_size,
487 | drop_remainder=drop_remainder))
488 |
489 | return d
490 |
491 | return input_fn
492 |
493 |
--------------------------------------------------------------------------------