├── __init__.py ├── data_loader ├── __init__.py ├── load_wiki.py ├── load_sated.py ├── load_reddit.py └── load_cornell_movie.py ├── LICENSE ├── README.md ├── reddit_lm.py ├── auditing.py ├── dialogue.py ├── dialogue_ranks.py ├── sated_nmt.py ├── reddit_lm_ranks.py ├── sated_nmt_ranks.py └── helper.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Congzheng Song 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Auditing Data Provenance in Text-Generation Models 2 | This repository contains example experiments for the paper 3 | Auditing Data Provenance in Text-Generation Models (https://arxiv.org/pdf/1811.00513.pdf/). 4 | 5 | ### Train text-generation models 6 | The first step is to train the target and shadow text-generation models. 7 | To train the language model, run the function train_reddit_lm in reddit_lm.py. 8 | To train the NMT model, run the function train_sated_nmt in sated_nmt.py. 9 | To train the dialogue model, run the function train_cornell_movie in dialogue.py. 10 | 11 | To train multiple shadow models, set the argument exp_id=0,1,...,n-1 in the above functions, where n is the number of shadow models. 12 | Set cross_domain=True to use cross-domain datasets for the shadow models. An example script for training the target model 13 | and 10 shadow models on the reddit language model with 300 users' data: 14 | ```python 15 | train_reddit_lm(exp_id=None, cross_domain=False, num_users=300) # target 16 | for i in range(10): 17 | train_reddit_lm(exp_id=i, cross_domain=True, num_users=300) # shadow i 18 | ``` 19 | 20 | ### Collect predicted ranks 21 | The next step is to collect the predicted ranks on the models we just trained. Use the functions get_target_ranks and 22 | get_shadow_ranks in reddit_lm_ranks.py, sated_nmt_ranks.py and dialogue_ranks.py to collect the ranks.
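Each "rank" is the position of the true next word in the model's sorted output distribution (rank 0 means the model gave the true word the highest probability). Below is a minimal sketch of how a single rank is computed from one softmax output, using the same `(-p).argsort().argsort()` trick that appears in dialogue_ranks.py; the probability vector is made-up toy data for illustration only:
```python
import numpy as np

p = np.array([0.1, 0.5, 0.05, 0.3, 0.05])  # toy softmax output over a 5-word vocabulary
t = 3                                       # index of the true next word

word_ranks = (-p).argsort().argsort()       # rank of every word: 0 = most probable
print(word_ranks[t])                        # 1, since only word 1 has a higher probability
```
Per-user arrays of these ranks are what the rank-collection scripts save as rank_u*.npz files, which auditing.py later loads.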
An example 23 | script for collecting ranks on the reddit language model: 24 | ```python 25 | get_target_ranks(num_users=300) # target 26 | for i in range(10): 27 | get_shadow_ranks(exp_id=i, cross_domain=True, num_users=300) # shadow i 28 | ``` 29 | 30 | ### Auditing user membership 31 | Finally, we can perform auditing on the collected ranks. 32 | Use the function user_mi_attack in auditing.py. For example, the script for auditing the reddit language model: 33 | ```python 34 | user_mi_attack(data_name='reddit', num_exp=10, num_users=300, cross_domain=True) # 10 shadow models, 300 users 35 | ``` 36 | 37 | data_name can be 'reddit', 'sated' or 'dialogs'. 38 | -------------------------------------------------------------------------------- /data_loader/load_wiki.py: -------------------------------------------------------------------------------- 1 | from load_reddit import build_vocab 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | 6 | WIKI_PATH_DIR = '/hdd/song/nlp/wikitext-103/' 7 | WIKI_TRAIN_PATH = WIKI_PATH_DIR + 'wiki.train.tokens' 8 | WIKI_DEV_PATH = WIKI_PATH_DIR + 'wiki.valid.tokens' 9 | WIKI_TEST_PATH = WIKI_PATH_DIR + 'wiki.test.tokens' 10 | 11 | 12 | def load_wiki_lines(filename=WIKI_TRAIN_PATH, num_lines=100): 13 | data = [] 14 | with open(filename, "r") as f: 15 | for line in f: 16 | line = line.replace('\n', '').lower().lstrip().rstrip() 17 | if not line: 18 | continue 19 | arr = line.split(' ') 20 | if arr[0] == '=': 21 | continue 22 | data.append(arr + ['']) 23 | if len(data) >= num_lines > 0: 24 | break 25 | print num_lines, len(data) 26 | return data 27 | 28 | 29 | def load_wiki_test_data(): 30 | dev = load_wiki_lines(WIKI_DEV_PATH, num_lines=0) 31 | test = load_wiki_lines(WIKI_TEST_PATH, num_lines=0) 32 | 33 | return dev + test 34 | 35 | 36 | def load_wiki_by_users(num_users=200, num_data_per_user=100, num_words=5000, vocabs=None): 37 | train_data = load_wiki_lines(num_lines=2 * num_users * num_data_per_user) 38 | print "Splitting data to {} users, each has {} texts".format(num_users * 2, num_data_per_user) 39 | 40 | all_users = np.arange(num_users * 2) 41 | np.random.seed(None) 42 | train_users = np.random.choice(all_users, size=num_users, replace=False) 43 | print train_users[:10] 44 | 45 | user_comments = defaultdict(list) 46 | all_words = [] 47 | 48 | for u in train_users: 49 | data = train_data[u * num_data_per_user: (u + 1) * num_data_per_user] 50 | for words in data: 51 | all_words += words 52 | user_comments[str(u)] = data 53 | 54 | if vocabs is None: 55 | vocabs = build_vocab(all_words, num_words + 1) 56 | 57 | all_words = [] 58 | for user in user_comments: 59 | comments = user_comments[user] 60 | for i in range(len(comments)): 61 | comment = comments[i] 62 | for j in range(len(comment)): 63 | word = comment[j] 64 | if word not in vocabs: 65 | comment[j] = '' 66 | all_words += comment 67 | 68 | vocabs = build_vocab(all_words, None) 69 | return user_comments, vocabs 70 | 71 | 72 | if __name__ == '__main__': 73 | load_wiki_by_users() -------------------------------------------------------------------------------- /reddit_lm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import keras.backend as K 4 | import numpy as np 5 | from keras import Model 6 | from keras.layers import Input, Embedding, CuDNNLSTM, CuDNNGRU, Dropout, Dense 7 | from keras.optimizers import Adam 8 | 9 | from data_loader.load_reddit import read_top_user_comments, read_test_comments 10 | from data_loader.load_wiki import 
load_wiki_by_users, load_wiki_test_data 11 | from helper import DenseTransposeTied, flatten_data, iterate_minibatches, words_to_indices 12 | 13 | MODEL_PATH = '/hdd/song/nlp/reddit/model/' 14 | RESULT_PATH = '/hdd/song/nlp/reddit/result/' 15 | 16 | 17 | def process_test_data(data, vocabs): 18 | for t in data: 19 | for i in range(len(t)): 20 | if t[i] not in vocabs: 21 | t[i] = '' 22 | 23 | 24 | def build_lm_model(emb_h=128, h=128, nh=1, V=5000, maxlen=35, drop_p=0.25, tied=False, rnn_fn='lstm'): 25 | input_layer = Input((maxlen,)) 26 | emb_layer = Embedding(V, emb_h, mask_zero=False) 27 | emb_output = emb_layer(input_layer) 28 | 29 | if rnn_fn == 'lstm': 30 | rnn = CuDNNLSTM 31 | elif rnn_fn == 'gru': 32 | rnn = CuDNNGRU 33 | else: 34 | raise ValueError(rnn_fn) 35 | 36 | if drop_p > 0.: 37 | emb_output = Dropout(drop_p)(emb_output) 38 | 39 | lstm_layer = rnn(h, return_sequences=True)(emb_output) 40 | if drop_p > 0.: 41 | lstm_layer = Dropout(drop_p)(lstm_layer) 42 | 43 | for _ in range(nh - 1): 44 | lstm_layer = rnn(h, return_sequences=True)(lstm_layer) 45 | if drop_p > 0.: 46 | lstm_layer = Dropout(drop_p)(lstm_layer) 47 | 48 | if tied: 49 | if emb_h != h: 50 | raise ValueError('When using the tied flag, nhid must be equal to emsize') 51 | output = DenseTransposeTied(V, tied_to=emb_layer, activation='linear')(lstm_layer) 52 | else: 53 | output = Dense(V, activation='linear')(lstm_layer) 54 | model = Model(inputs=[input_layer], outputs=[output]) 55 | return model 56 | 57 | 58 | def train_reddit_lm(num_users=300, num_words=5000, num_epochs=30, maxlen=35, batch_size=20, exp_id=0, 59 | h=128, emb_h=256, lr=1e-3, drop_p=0.25, tied=False, nh=1, loo=None, sample_user=False, 60 | cross_domain=False, print_every=1000, rnn_fn='lstm'): 61 | if cross_domain: 62 | loo = None 63 | sample_user = True 64 | user_comments, vocabs = load_wiki_by_users(num_users=num_users, num_words=num_words) 65 | else: 66 | user_comments, vocabs = read_top_user_comments(num_users, num_words, sample_user=sample_user) 67 | 68 | train_data = [] 69 | users = sorted(user_comments.keys()) 70 | 71 | for i, user in enumerate(users): 72 | if loo is not None and i == loo: 73 | print "Leaving {} out".format(i) 74 | continue 75 | train_data += user_comments[user] 76 | 77 | train_data = words_to_indices(train_data, vocabs) 78 | train_data = flatten_data(train_data) 79 | 80 | if cross_domain: 81 | test_data = load_wiki_test_data() 82 | else: 83 | test_data = read_test_comments() 84 | 85 | process_test_data(test_data, vocabs) 86 | test_data = words_to_indices(test_data, vocabs) 87 | test_data = flatten_data(test_data) 88 | 89 | n_data = (len(train_data) - 1) // maxlen 90 | X_train = train_data[:-1][:n_data * maxlen].reshape(-1, maxlen) 91 | y_train = train_data[1:][:n_data * maxlen].reshape(-1, maxlen) 92 | print X_train.shape 93 | 94 | n_test_data = (len(test_data) - 1) // maxlen 95 | X_test = test_data[:-1][:n_test_data * maxlen].reshape(-1, maxlen) 96 | y_test = test_data[1:][:n_test_data * maxlen].reshape(-1, maxlen) 97 | print X_test.shape 98 | 99 | model = build_lm_model(emb_h=emb_h, h=h, nh=nh, drop_p=drop_p, V=len(vocabs), tied=tied, maxlen=maxlen, 100 | rnn_fn=rnn_fn) 101 | 102 | input_var = K.placeholder((None, maxlen)) 103 | target_var = K.placeholder((None, maxlen)) 104 | 105 | prediction = model(input_var) 106 | 107 | loss = K.sparse_categorical_crossentropy(target_var, prediction, from_logits=True) 108 | loss = K.mean(K.sum(loss, axis=-1)) 109 | 110 | optimizer = Adam(lr=lr, clipnorm=5) 111 | 112 | updates = 
optimizer.get_updates(loss, model.trainable_weights) 113 | train_fn = K.function([input_var, target_var, K.learning_phase()], [loss], updates=updates) 114 | 115 | pred_fn = K.function([input_var, target_var, K.learning_phase()], [prediction, loss]) 116 | 117 | iteration = 1 118 | for epoch in range(num_epochs): 119 | train_batches = 0. 120 | train_loss = 0. 121 | train_iters = 0. 122 | 123 | for batch in iterate_minibatches(X_train, y_train, batch_size, shuffle=True): 124 | inputs, targets = batch 125 | err = train_fn([inputs, targets, 1])[0] 126 | train_batches += 1 127 | train_loss += err 128 | train_iters += maxlen 129 | 130 | iteration += 1 131 | if iteration % print_every == 0: 132 | test_acc = 0. 133 | test_n = 0. 134 | test_iters = 0. 135 | test_loss = 0. 136 | test_batches = 0. 137 | 138 | for batch in iterate_minibatches(X_test, y_test, batch_size, shuffle=False): 139 | inputs, targets = batch 140 | 141 | preds, err = pred_fn([inputs, targets, 0]) 142 | test_loss += err 143 | test_iters += maxlen 144 | test_batches += 1 145 | 146 | preds = preds.argmax(axis=-1) 147 | test_acc += np.sum(preds.flatten() == targets.flatten()) 148 | test_n += len(targets.flatten()) 149 | 150 | sys.stderr.write("Epoch {}, iteration {}, train loss={:.3f}, train perp={:.3f}, " 151 | "test loss={:.3f}, test perp={:.3f}, " 152 | "test acc={:.3f}%\n".format(epoch, iteration, 153 | train_loss / train_batches, 154 | np.exp(train_loss / train_iters), 155 | test_loss / test_batches, 156 | np.exp(test_loss / test_iters), 157 | test_acc / test_n * 100)) 158 | 159 | if cross_domain: 160 | fname = 'wiki_lm{}'.format('' if loo is None else loo) 161 | else: 162 | fname = 'reddit_lm{}'.format('' if loo is None else loo) 163 | 164 | if sample_user: 165 | fname += '_shadow_exp{}_{}'.format(exp_id, rnn_fn) 166 | np.savez(MODEL_PATH + 'shadow_users{}_{}_{}_{}.npz'.format(exp_id, rnn_fn, num_users, 167 | 'cd' if cross_domain else ''), users) 168 | 169 | model.save(MODEL_PATH + '{}_{}.h5'.format(fname, num_users)) 170 | 171 | 172 | if __name__ == '__main__': 173 | train_reddit_lm() 174 | -------------------------------------------------------------------------------- /auditing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from sklearn.preprocessing import StandardScaler, Normalizer 5 | from sklearn.metrics import classification_report, accuracy_score, roc_auc_score, precision_recall_fscore_support 6 | from sklearn.svm import LinearSVC 7 | 8 | from sated_nmt import OUTPUT_PATH as SATED_OUTPUT_PATH 9 | from reddit_lm import RESULT_PATH as REDDIT_OUTPUT_PATH 10 | from dialogue import OUTPUT_PATH as CORNELL_OUTPUT_PATH 11 | 12 | 13 | def histogram_feats(ranks, bins=100, top_words=5000, num_words=5000, relative=False): 14 | if top_words < num_words: 15 | if bins == top_words: 16 | bins += 1 17 | top_words += 1 18 | 19 | range = (-num_words, top_words) if relative else (0, top_words) 20 | feats, _ = np.histogram(ranks, bins=bins, normed=False, range=range) 21 | return feats 22 | 23 | 24 | def sample_with_ratio(a, b, heldout_ratio=0.5): 25 | if heldout_ratio == 0.: 26 | return a 27 | if heldout_ratio == 1.: 28 | return b 29 | 30 | if not isinstance(a, list): 31 | a = a.tolist() 32 | b = b.tolist() 33 | 34 | l1 = len(a) 35 | l2 = len(b) 36 | ratio = float(l2) / (l1 + l2) 37 | if heldout_ratio > ratio: 38 | # remove from a 39 | n = int(l2 / heldout_ratio) 40 | rest_l1 = n - l2 41 | return a[:rest_l1] + b 42 | elif heldout_ratio < ratio: 43 | # remove 
from b 44 | n = int(l1 / (1 - heldout_ratio)) 45 | rest_l2 = n - l1 46 | return a + b[:rest_l2] 47 | else: 48 | return a + b 49 | 50 | 51 | def load_ranks(save_dir, num_users=5000, cross_domain=False): 52 | ranks = [] 53 | labels = [] 54 | y = [] 55 | for i in range(num_users): 56 | save_path = save_dir + 'rank_u{}_y1{}.npz'.format(i, '_cd' if cross_domain else '') 57 | if os.path.exists(save_path): 58 | f = np.load(save_path) 59 | train_rs, train_ls = f['arr_0'], f['arr_1'] 60 | ranks.append(train_rs) 61 | labels.append(train_ls) 62 | y.append(1) 63 | 64 | save_path = save_dir + 'rank_u{}_y0{}.npz'.format(i, '_cd' if cross_domain else '') 65 | if os.path.exists(save_path): 66 | f = np.load(save_path) 67 | test_rs, test_ls = f['arr_0'], f['arr_1'] 68 | ranks.append(test_rs) 69 | labels.append(test_ls) 70 | y.append(0) 71 | 72 | return ranks, labels, np.asarray(y) 73 | 74 | 75 | def get_indices_by_labels(sent_labels): 76 | sent_label_sum = [-np.sum(labels) for labels in sent_labels] 77 | return np.argsort(sent_label_sum) 78 | 79 | 80 | def ranks_to_feats(ranks, labels=None, prop=1.0, dim=100, num_words=5000, top_words=5000, shuffle=False, 81 | rare=False, relative=False, user_data_ratio=0., heldout_ratio=0., num_users=300): 82 | if relative or rare: 83 | assert labels is not None 84 | X = [] 85 | 86 | for i, user_ranks in enumerate(ranks): 87 | indices = np.arange(len(user_ranks)) 88 | if relative or rare: 89 | user_labels = labels[i] 90 | assert len(user_labels) == len(user_ranks) 91 | else: 92 | user_labels = None 93 | 94 | r = [] 95 | 96 | if 0. < user_data_ratio < 1. and i < num_users: 97 | l = len(user_ranks) 98 | for idx in range(l): 99 | user_ranks[idx] = np.clip(user_ranks[idx], 0, top_words) 100 | if relative: 101 | assert len(user_ranks[idx]) == len(user_labels[idx]) 102 | user_ranks[idx] = user_ranks[idx] - user_labels[idx] 103 | 104 | train_l = int(l * user_data_ratio) 105 | train_ranks = user_ranks[:train_l] 106 | heldout_ranks = user_ranks[train_l:] 107 | for rank in sample_with_ratio(train_ranks, heldout_ranks, heldout_ratio): 108 | r.append(rank) 109 | else: 110 | if shuffle: 111 | np.random.seed(None) 112 | np.random.shuffle(indices) 113 | 114 | if rare: 115 | indices = get_indices_by_labels(user_labels) 116 | 117 | n = int(len(indices) * prop) + 1 if isinstance(prop, float) else prop 118 | for idx in indices[:n]: 119 | user_ranks[idx] = np.clip(user_ranks[idx], 0, top_words) 120 | if relative: 121 | assert len(user_ranks[idx]) == len(user_labels[idx]) 122 | r.append(user_ranks[idx] - user_labels[idx]) 123 | else: 124 | r.append(user_ranks[idx]) 125 | 126 | # print i, r 127 | if isinstance(r[0], int): 128 | print i 129 | else: 130 | r = np.concatenate(r) 131 | 132 | feats = histogram_feats(r, bins=dim, num_words=num_words, top_words=top_words, relative=relative) 133 | X.append(feats) 134 | 135 | return np.vstack(X) 136 | 137 | 138 | def user_mi_attack(data_name='reddit', num_exp=5, num_users=5000, dim=100, prop=1.0, user_data_ratio=0., 139 | heldout_ratio=0., num_words=5000, top_words=5000, relative=False, rare=False, norm=True, 140 | scale=True, cross_domain=False, rerun=False): 141 | 142 | if data_name == 'reddit': 143 | result_path = REDDIT_OUTPUT_PATH 144 | elif data_name == 'sated': 145 | result_path = SATED_OUTPUT_PATH 146 | elif data_name == 'dialogs': 147 | result_path = CORNELL_OUTPUT_PATH 148 | else: 149 | raise ValueError(data_name) 150 | 151 | if dim > top_words: 152 | dim = top_words 153 | 154 | audit_save_path = result_path + 
'mi_data_dim{}_prop{}_{}{}.npz'.format( 155 | dim, prop, num_users, '_cd' if cross_domain else '') 156 | 157 | if not rerun and os.path.exists(audit_save_path): 158 | f = np.load(audit_save_path) 159 | X_train, y_train, X_test, y_test = [f['arr_{}'.format(i)] for i in range(4)] 160 | else: 161 | save_dir = result_path + 'target_{}{}/'.format(num_users, '_dr' if 0. < user_data_ratio < 1. else '') 162 | ranks, labels, y_test = load_ranks(save_dir, num_users) 163 | X_test = ranks_to_feats(ranks, prop=prop, dim=dim, top_words=top_words, user_data_ratio=user_data_ratio, 164 | num_words=num_words, labels=labels, rare=rare, relative=relative, 165 | heldout_ratio=heldout_ratio) 166 | 167 | X_train, y_train = [], [] 168 | for exp_id in range(num_exp): 169 | save_dir = result_path + 'shadow_exp{}_{}/'.format(exp_id, num_users) 170 | ranks, labels, y = load_ranks(save_dir, num_users, cross_domain=cross_domain) 171 | feats = ranks_to_feats(ranks, prop=prop, dim=dim, top_words=top_words, relative=relative, 172 | num_words=num_words, labels=labels) 173 | X_train.append(feats) 174 | y_train.append(y) 175 | 176 | X_train = np.vstack(X_train) 177 | y_train = np.concatenate(y_train) 178 | np.savez(audit_save_path, X_train, y_train, X_test, y_test) 179 | 180 | print X_train.shape, y_train.shape 181 | 182 | if norm: 183 | normalizer = Normalizer(norm='l2') 184 | X_train = normalizer.transform(X_train) 185 | X_test = normalizer.transform(X_test) 186 | 187 | if scale: 188 | scaler = StandardScaler() 189 | X_train = scaler.fit_transform(X_train) 190 | X_test = scaler.transform(X_test) 191 | 192 | clf = LinearSVC(verbose=1) 193 | clf.fit(X_train, y_train) 194 | 195 | y_pred = clf.predict(X_test) 196 | y_score = clf.decision_function(X_test) 197 | 198 | print classification_report(y_pred=y_pred, y_true=y_test) 199 | 200 | acc = accuracy_score(y_test, y_pred) 201 | auc = roc_auc_score(y_test, y_score) 202 | pres, recs, _, _ = precision_recall_fscore_support(y_test, y_pred) 203 | pre = pres[1] 204 | rec = recs[1] 205 | 206 | print 'precision={}, recall={}, acc={}, auc={}'.format(pre, rec, acc, auc) 207 | return acc, auc, pre, rec 208 | -------------------------------------------------------------------------------- /data_loader/load_sated.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter, defaultdict 3 | from itertools import chain 4 | 5 | 6 | SATED_PATH = '/hdd/song/nlp/sated-release-0.9.0/en-fr/' 7 | SATED_TRAIN_ENG = SATED_PATH + 'train.en' 8 | SATED_TRAIN_FR = SATED_PATH + 'train.fr' 9 | SATED_TRAIN_USER = SATED_PATH + 'train.usr' 10 | SATED_DEV_ENG = SATED_PATH + 'dev.en' 11 | SATED_DEV_FR = SATED_PATH + 'dev.fr' 12 | SATED_DEV_USER = SATED_PATH + 'dev.usr' 13 | SATED_TEST_ENG = SATED_PATH + 'test.en' 14 | SATED_TEST_FR = SATED_PATH + 'test.fr' 15 | SATED_TEST_USER = SATED_PATH + 'test.usr' 16 | EUROPARL_PATH = '/hdd/song/nlp/europarl/' 17 | EUROPARL_DEEN_DE = EUROPARL_PATH + 'europarl.de-en.de.aligned.tok' 18 | EUROPARL_DEEN_EN = EUROPARL_PATH + 'europarl.de-en.en.aligned.tok' 19 | EUROPARL_FREN_FR = EUROPARL_PATH + 'europarl.fr-en.fr.aligned.tok' 20 | EUROPARL_FREN_EN = EUROPARL_PATH + 'europarl.fr-en.en.aligned.tok' 21 | 22 | 23 | def load_users(p=SATED_TRAIN_USER): 24 | users = [] 25 | with open(p, 'rb') as f: 26 | for line in f: 27 | users.append(line.replace('\n', '')) 28 | return users 29 | 30 | 31 | def load_texts(p=SATED_TRAIN_ENG): 32 | texts = [] 33 | with open(p, 'rb') as f: 34 | for line in f: 35 | arr = 
[''] + line.replace('\n', '').split(' ') + [''] 36 | words = [] 37 | for w in arr: 38 | words.append(w) 39 | texts.append(words) 40 | 41 | return texts 42 | 43 | 44 | def process_texts(texts, vocabs): 45 | for t in texts: 46 | for i, w in enumerate(t): 47 | if w not in vocabs: 48 | t[i] = '' 49 | 50 | 51 | def process_vocabs(vocabs, num_words=10000): 52 | counter = Counter(vocabs) 53 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 54 | print('Loaded {} vocabs'.format(len(count_pairs))) 55 | 56 | if num_words is not None: 57 | count_pairs = count_pairs[:num_words - 1] 58 | 59 | print count_pairs[:50] 60 | 61 | words, _ = list(zip(*count_pairs)) 62 | word_to_id = dict(zip(words, np.arange(len(words)))) 63 | return word_to_id 64 | 65 | 66 | def load_sated_data(num_words=10000): 67 | train_src_texts = load_texts(SATED_TRAIN_ENG) 68 | train_trg_texts = load_texts(SATED_TRAIN_FR) 69 | 70 | dev_src_texts = load_texts(SATED_DEV_ENG) 71 | dev_trg_texts = load_texts(SATED_DEV_FR) 72 | 73 | test_src_texts = load_texts(SATED_TEST_ENG) 74 | test_trg_texts = load_texts(SATED_TEST_FR) 75 | 76 | src_words = list(chain(*train_src_texts)) 77 | trg_words = list(chain(*train_trg_texts)) 78 | 79 | src_vocabs = process_vocabs(src_words, num_words) 80 | trg_vocabs = process_vocabs(trg_words, num_words) 81 | 82 | process_texts(train_src_texts, src_vocabs) 83 | process_texts(train_trg_texts, trg_vocabs) 84 | 85 | process_texts(dev_src_texts, src_vocabs) 86 | process_texts(dev_trg_texts, trg_vocabs) 87 | 88 | process_texts(test_src_texts, src_vocabs) 89 | process_texts(test_trg_texts, trg_vocabs) 90 | 91 | src_words = list(chain(*train_src_texts)) 92 | trg_words = list(chain(*train_trg_texts)) 93 | 94 | src_vocabs = process_vocabs(src_words, None) 95 | trg_vocabs = process_vocabs(trg_words, None) 96 | 97 | return train_src_texts, train_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts, \ 98 | src_vocabs, trg_vocabs 99 | 100 | 101 | def load_sated_data_by_user(num_users=100, num_words=10000, test_on_user=False, sample_user=False, 102 | seed=12345, user_data_ratio=0.): 103 | src_users = load_users(SATED_TRAIN_USER) 104 | train_src_texts = load_texts(SATED_TRAIN_ENG) 105 | train_trg_texts = load_texts(SATED_TRAIN_FR) 106 | 107 | dev_src_texts = load_texts(SATED_DEV_ENG) 108 | dev_trg_texts = load_texts(SATED_DEV_FR) 109 | 110 | test_src_texts = load_texts(SATED_TEST_ENG) 111 | test_trg_texts = load_texts(SATED_TEST_FR) 112 | 113 | user_counter = Counter(src_users) 114 | all_users = [tup[0] for tup in user_counter.most_common()] 115 | # print len(all_users) 116 | np.random.seed(seed) 117 | np.random.shuffle(all_users) 118 | np.random.seed(None) 119 | 120 | train_users = set(all_users[:num_users]) 121 | test_users = set(all_users[num_users: num_users * 2]) 122 | 123 | if sample_user: 124 | attacker_users = all_users[num_users * 2: num_users * 4] 125 | # np.random.seed(None) 126 | train_users = np.random.choice(attacker_users, size=num_users, replace=False) 127 | print len(train_users) 128 | print train_users[:10] 129 | 130 | user_src_texts = defaultdict(list) 131 | user_trg_texts = defaultdict(list) 132 | 133 | test_user_src_texts = defaultdict(list) 134 | test_user_trg_texts = defaultdict(list) 135 | 136 | for u, s, t in zip(src_users, train_src_texts, train_trg_texts): 137 | if u in train_users: 138 | user_src_texts[u].append(s) 139 | user_trg_texts[u].append(t) 140 | if test_on_user and u in test_users: 141 | test_user_src_texts[u].append(s) 142 | 
test_user_trg_texts[u].append(t) 143 | 144 | if 0. < user_data_ratio < 1.: 145 | # held out some fraction of data for testing 146 | for u in user_src_texts: 147 | l = len(user_src_texts[u]) 148 | # print l 149 | l = int(l * user_data_ratio) 150 | user_src_texts[u] = user_src_texts[u][:l] 151 | user_trg_texts[u] = user_trg_texts[u][:l] 152 | 153 | src_words = [] 154 | trg_words = [] 155 | for u in train_users: 156 | src_words += list(chain(*user_src_texts[u])) 157 | trg_words += list(chain(*user_trg_texts[u])) 158 | 159 | src_vocabs = process_vocabs(src_words, num_words) 160 | trg_vocabs = process_vocabs(trg_words, num_words) 161 | 162 | for u in train_users: 163 | process_texts(user_src_texts[u], src_vocabs) 164 | process_texts(user_trg_texts[u], trg_vocabs) 165 | 166 | if test_on_user: 167 | for u in test_users: 168 | process_texts(test_user_src_texts[u], src_vocabs) 169 | process_texts(test_user_trg_texts[u], trg_vocabs) 170 | 171 | process_texts(dev_src_texts, src_vocabs) 172 | process_texts(dev_trg_texts, trg_vocabs) 173 | 174 | process_texts(test_src_texts, src_vocabs) 175 | process_texts(test_trg_texts, trg_vocabs) 176 | 177 | src_words = [] 178 | trg_words = [] 179 | for u in train_users: 180 | src_words += list(chain(*user_src_texts[u])) 181 | trg_words += list(chain(*user_trg_texts[u])) 182 | 183 | src_vocabs = process_vocabs(src_words, None) 184 | trg_vocabs = process_vocabs(trg_words, None) 185 | 186 | if test_on_user: 187 | return user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs 188 | else: 189 | return user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts,\ 190 | src_vocabs, trg_vocabs 191 | 192 | 193 | def read_europarl_file(filename, num_lines=80000): 194 | texts = [] 195 | with open(filename, 'rb') as f: 196 | for line in f: 197 | arr = [''] + line.lower().replace('\n', '').split(' ') + [''] 198 | texts.append(arr) 199 | if len(texts) > num_lines: 200 | break 201 | return texts 202 | 203 | 204 | def load_europarl_by_user(num_users=200, num_data_per_user=150, num_words=5000, test_size=5000): 205 | src_texts = read_europarl_file(EUROPARL_FREN_EN, num_users * num_data_per_user * 2 + test_size) 206 | trg_texts = read_europarl_file(EUROPARL_FREN_FR, num_users * num_data_per_user * 2 + test_size) 207 | 208 | test_src_texts = src_texts[-test_size:] 209 | test_trg_texts = trg_texts[-test_size:] 210 | src_texts = src_texts[:-test_size] 211 | trg_texts = trg_texts[:-test_size] 212 | 213 | all_users = np.arange(num_users * 2) 214 | np.random.seed(None) 215 | train_users = np.random.choice(all_users, size=num_users, replace=False) 216 | 217 | user_src_texts = defaultdict(list) 218 | user_trg_texts = defaultdict(list) 219 | 220 | for u in train_users: 221 | user_src_texts[u] = src_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 222 | user_trg_texts[u] = trg_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 223 | 224 | src_words = [] 225 | trg_words = [] 226 | for u in train_users: 227 | src_words += list(chain(*user_src_texts[u])) 228 | trg_words += list(chain(*user_trg_texts[u])) 229 | 230 | src_vocabs = process_vocabs(src_words, num_words) 231 | trg_vocabs = process_vocabs(trg_words, num_words) 232 | 233 | for u in train_users: 234 | process_texts(user_src_texts[u], src_vocabs) 235 | process_texts(user_trg_texts[u], trg_vocabs) 236 | 237 | process_texts(test_src_texts, src_vocabs) 238 | process_texts(test_trg_texts, trg_vocabs) 239 | 240 | src_words = [] 241 | trg_words = [] 242 | 
for u in train_users: 243 | src_words += list(chain(*user_src_texts[u])) 244 | trg_words += list(chain(*user_trg_texts[u])) 245 | 246 | src_vocabs = process_vocabs(src_words, None) 247 | trg_vocabs = process_vocabs(trg_words, None) 248 | 249 | return user_src_texts, user_trg_texts, test_src_texts, test_trg_texts, test_src_texts, test_trg_texts,\ 250 | src_vocabs, trg_vocabs 251 | 252 | 253 | if __name__ == '__main__': 254 | load_sated_data_by_user(num_users=300, sample_user=False) 255 | -------------------------------------------------------------------------------- /data_loader/load_reddit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs 3 | import numpy as np 4 | 5 | from collections import defaultdict, Counter 6 | from nltk.tokenize import word_tokenize 7 | 8 | 9 | REDDIT_PATH = '/hdd/song/nlp/reddit/' 10 | REDDIT_USER_PATH = REDDIT_PATH + 'shard_by_author/' 11 | REDDIT_PROCESSED_PATH = REDDIT_PATH + 'shard_by_author_processed/' 12 | REDDIT_USER_COUNT_PATH = REDDIT_PATH + 'author_count' 13 | REDDIT_TEST_PATH = REDDIT_PATH + 'test_data.json' 14 | PTB_DATA_DIR = './ptb/simple-examples/data/' 15 | 16 | 17 | def translate(t): 18 | t = t.replace(u'\u2018', '\'') 19 | t = t.replace(u'\u2019', '\'') 20 | t = t.replace(u'\u201c', '\"') 21 | t = t.replace(u'\u201d', '\"') 22 | t = t.replace(u'\u2013', '-') 23 | t = t.replace(u'\u2014', '-') 24 | 25 | t = t.replace(u'\u2026', '') 26 | t = t.replace(u'\ufffd', '') 27 | t = t.replace(u'\ufe0f', '') 28 | t = t.replace(u'\u035c', '') 29 | t = t.replace(u'\u0296', '') 30 | t = t.replace(u'\u270a', '') 31 | t = t.replace(u'*', '') 32 | t = t.replace(u'~', '') 33 | 34 | t = t.replace(u'\ufb00', 'ff') 35 | 36 | return t 37 | 38 | 39 | def preprocess(t): 40 | words = t.split(' ') 41 | for i in range(len(words)): 42 | if 'http' in words[i]: 43 | words[i] = '/url/' 44 | return ' '.join(words) 45 | 46 | 47 | def remove_puncs(words): 48 | new_words = [] 49 | for w in words: 50 | flag = False 51 | for c in w: 52 | if c.isalnum(): 53 | flag = True 54 | break 55 | if flag: 56 | new_words.append(w) 57 | return new_words 58 | 59 | 60 | def write_processed_comments(): 61 | for user in os.listdir(REDDIT_USER_PATH): 62 | filename = os.path.join(REDDIT_USER_PATH, user) 63 | cnt = 0 64 | new_lines = [] 65 | with codecs.open(filename, encoding='utf-8') as f: 66 | for line in f: 67 | text = line[1:-2].decode('unicode_escape').lower() 68 | text = translate(text) 69 | text = preprocess(text) 70 | words = word_tokenize(text) 71 | words = remove_puncs(words) 72 | if len(words) < 3: 73 | continue 74 | cnt += 1 75 | new_line = ' '.join(words) 76 | # print new_line 77 | new_lines.append(new_line) 78 | print user, cnt 79 | 80 | with open(REDDIT_PROCESSED_PATH + user, 'wb') as f: 81 | for line in new_lines: 82 | f.write(line.encode('utf8') + '\n') 83 | # quit() 84 | 85 | 86 | def read_top_users(num_users=100, random=True, min_count=200): 87 | users = [] 88 | cnts = [] 89 | with codecs.open(REDDIT_USER_COUNT_PATH, encoding='utf-8') as f: 90 | for line in f: 91 | user, cnt = line.split('\t') 92 | cnt = int(cnt) 93 | if cnt < min_count: 94 | continue 95 | users.append(user) 96 | cnts.append(cnt) 97 | 98 | print len(users), sum(cnts) 99 | 100 | cnts = np.asarray(cnts) 101 | if random: 102 | np.random.seed(12345) 103 | top_indices = np.arange(len(cnts)) 104 | np.random.shuffle(top_indices) 105 | top_indices = top_indices[:num_users] 106 | np.random.seed(None) 107 | else: 108 | top_indices = 
np.argsort(-cnts)[:num_users] 109 | top_users = np.asarray(users)[top_indices] 110 | 111 | print('Loading {} comments for {} users'.format(cnts[top_indices].sum(), num_users)) 112 | 113 | return top_users 114 | 115 | 116 | def read_random_users(num_users=100, num_top_users=100): 117 | users = [] 118 | cnts = [] 119 | with codecs.open(REDDIT_USER_COUNT_PATH, encoding='utf-8') as f: 120 | for line in f: 121 | user, cnt = line.split('\t') 122 | cnt = int(cnt) 123 | users.append(user) 124 | cnts.append(cnt) 125 | 126 | cnts = np.asarray(cnts) 127 | indices = np.argsort(-cnts)[num_top_users:] 128 | users = np.asarray(users)[indices] 129 | random_users = np.random.choice(users, num_users, replace=False) 130 | return random_users 131 | 132 | 133 | def build_vocab(data, num_words=20000): 134 | counter = Counter(data) 135 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 136 | print('Loaded {} vocabs'.format(len(count_pairs))) 137 | 138 | if num_words is not None: 139 | count_pairs = count_pairs[:num_words - 1] 140 | 141 | words, _ = list(zip(*count_pairs)) 142 | word_to_id = dict(zip(words, np.arange(len(words)))) 143 | return word_to_id 144 | 145 | 146 | def read_comments_by_user(user, vocabs, max_token=1600): 147 | filename = os.path.join(REDDIT_PROCESSED_PATH, user) 148 | user_comments = [] 149 | user_num_tokens = 0 150 | with codecs.open(filename, encoding='utf-8') as f: 151 | for line in f: 152 | data = line.replace('/url/', '').replace('\n', '').split() + [''] 153 | if len(data) == 1: 154 | continue 155 | 156 | user_comments.append(data) 157 | user_num_tokens += len(data) 158 | if user_num_tokens >= max_token + 1: 159 | break 160 | 161 | for i in range(len(user_comments)): 162 | comment = user_comments[i] 163 | for j in range(len(comment)): 164 | word = comment[j] 165 | if word not in vocabs: 166 | comment[j] = '' 167 | return user_comments 168 | 169 | 170 | def read_top_user_comments(num_users=200, num_words=5000, max_token=None, vocabs=None, top_users=None, 171 | sample_user=False, load_raw=False): 172 | if sample_user and top_users is None: 173 | top_users = read_top_users(num_users * 4)[num_users * 2: num_users * 4] 174 | print len(top_users) 175 | 176 | np.random.seed(None) 177 | top_users = np.random.choice(top_users, size=num_users, replace=False) 178 | print len(top_users) 179 | elif top_users is None: 180 | top_users = read_top_users(num_users) 181 | 182 | user_comments = defaultdict(list) 183 | user_num_tokens = Counter() 184 | all_words = [] 185 | for user in top_users: 186 | filename = os.path.join(REDDIT_PROCESSED_PATH, user) 187 | with codecs.open(filename, encoding='utf-8') as f: 188 | for line in f: 189 | data = line.replace('/url/', '').replace('\n', '').split() + [''] 190 | 191 | if len(data) == 1: 192 | continue 193 | 194 | user_comments[user].append(data) 195 | user_num_tokens[user] += len(data) 196 | all_words += data 197 | if max_token is not None and user_num_tokens[user] >= max_token + 1: 198 | break 199 | if load_raw: 200 | return user_comments 201 | 202 | if vocabs is None: 203 | vocabs = build_vocab(all_words, num_words) 204 | 205 | all_words = [] 206 | for user in user_comments: 207 | comments = user_comments[user] 208 | for i in range(len(comments)): 209 | comment = comments[i] 210 | for j in range(len(comment)): 211 | word = comment[j] 212 | if word not in vocabs: 213 | comment[j] = '' 214 | all_words += comment 215 | 216 | vocabs = build_vocab(all_words, None) 217 | return user_comments, vocabs 218 | 219 | 220 | def 
read_test_comments_by_user(num_top_users=100): 221 | test_users = read_random_users(num_top_users, num_top_users) 222 | user_comments = defaultdict(list) 223 | for user in test_users: 224 | filename = os.path.join(REDDIT_PROCESSED_PATH, user) 225 | with codecs.open(filename, encoding='utf-8') as f: 226 | for line in f: 227 | data = line.replace('/url/', '').replace('\n', '').split() + [''] 228 | if len(data) == 1: 229 | continue 230 | 231 | user_comments[user].append(data) 232 | return user_comments 233 | 234 | 235 | def read_test_comments(): 236 | test_comments = [] 237 | with codecs.open(REDDIT_TEST_PATH, encoding='utf-8') as f: 238 | for line in f: 239 | text = line[1:-2].decode('unicode_escape').lower() 240 | text = translate(text) 241 | text = preprocess(text) 242 | text = text.replace('/url/', '').replace('\n', '') 243 | words = word_tokenize(text) 244 | words = remove_puncs(words) + [''] 245 | if len(words) > 2: 246 | test_comments.append(words) 247 | 248 | print('Loaded {} test data'.format(len(test_comments))) 249 | return test_comments 250 | 251 | 252 | def read_ptb_file(filename): 253 | data = [] 254 | with open(filename, "r") as f: 255 | for line in f: 256 | data.append(line.decode('utf-8').replace('\n', '').split()) 257 | return data 258 | 259 | 260 | def read_ptb_data_by_user(num_users=100, num_words=5000, vocabs=None): 261 | train_path = os.path.join(PTB_DATA_DIR, 'ptb.train.txt') 262 | train_data = read_ptb_file(train_path) 263 | 264 | l = len(train_data) 265 | num_data_per_user = l // (num_users * 2) 266 | print "Splitting data to {} users, each has {} texts".format(num_users * 2, num_data_per_user) 267 | 268 | all_users = np.arange(num_users * 2) 269 | np.random.seed(None) 270 | train_users = np.random.choice(all_users, size=num_users, replace=False) 271 | print train_users[:10] 272 | 273 | user_comments = defaultdict(list) 274 | all_words = [] 275 | 276 | for u in train_users: 277 | data = train_data[u * num_data_per_user: (u + 1) * num_data_per_user] 278 | for words in data: 279 | all_words += words 280 | user_comments[str(u)] = data 281 | 282 | if vocabs is None: 283 | vocabs = build_vocab(all_words, num_words + 1) 284 | 285 | all_words = [] 286 | for user in user_comments: 287 | comments = user_comments[user] 288 | for i in range(len(comments)): 289 | comment = comments[i] 290 | for j in range(len(comment)): 291 | word = comment[j] 292 | if word not in vocabs: 293 | comment[j] = '' 294 | all_words += comment 295 | 296 | vocabs = build_vocab(all_words, None) 297 | return user_comments, vocabs 298 | 299 | 300 | def read_ptb_test_data(): 301 | test_path = os.path.join(PTB_DATA_DIR, 'ptb.test.txt') 302 | test_data = read_ptb_file(test_path) 303 | return test_data 304 | 305 | 306 | if __name__ == '__main__': 307 | read_top_user_comments(num_users=5000, num_words=5000, sample_user=True) 308 | -------------------------------------------------------------------------------- /dialogue.py: -------------------------------------------------------------------------------- 1 | from keras import Model 2 | from keras.layers import Input, Embedding, LSTM, Dropout, Dense, CuDNNLSTM, CuDNNGRU 3 | from helper import DenseTransposeTied 4 | from keras.optimizers import Adam 5 | 6 | import keras.backend as K 7 | import copy 8 | 9 | from collections import defaultdict 10 | from data_loader.load_cornell_movie import load_ubuntu_by_user, load_cornell_movie_by_user 11 | from sated_nmt import beam_search, bleu_score 12 | 13 | import pprint 14 | import numpy as np 15 | 16 | MODEL_PATH = 
'/hdd/song/nlp/cornell_movie_dialogs_corpus/model/' 17 | OUTPUT_PATH = '/hdd/song/nlp/cornell_movie_dialogs_corpus/output/' 18 | 19 | 20 | def group_texts_by_len(src_texts, trg_texts, bs=20): 21 | print("Bucketing batches") 22 | # Bucket samples by source sentence length 23 | buckets = defaultdict(list) 24 | batches = [] 25 | for src, trg in zip(src_texts, trg_texts): 26 | buckets[len(src)].append((src, trg)) 27 | 28 | for src_len, bucket in buckets.items(): 29 | np.random.shuffle(bucket) 30 | num_batches = int(np.ceil(len(bucket) * 1.0 / bs)) 31 | for i in range(num_batches): 32 | cur_batch_size = bs if i < num_batches - 1 else len(bucket) - bs * i 33 | batches.append(([bucket[i * bs + j][0] for j in range(cur_batch_size)], 34 | [bucket[i * bs + j][1] for j in range(cur_batch_size)])) 35 | return batches 36 | 37 | 38 | def build_dialogue_model(Vs, Vt, demb=128, h=128, drop_p=0.5, tied=True, mask=True, training=None, rnn_fn='lstm'): 39 | if rnn_fn == 'lstm': 40 | rnn = LSTM if mask else CuDNNLSTM 41 | elif rnn_fn == 'gru': 42 | rnn = LSTM if mask else CuDNNGRU 43 | else: 44 | raise ValueError(rnn_fn) 45 | 46 | # build encoder 47 | encoder_input = Input((None,), dtype='float32', name='encoder_input') 48 | if mask: 49 | encoder_emb_layer = Embedding(Vs + 1, demb, mask_zero=True, name='encoder_emb') 50 | else: 51 | encoder_emb_layer = Embedding(Vs, demb, mask_zero=False, name='encoder_emb') 52 | 53 | encoder_emb = encoder_emb_layer(encoder_input) 54 | 55 | if drop_p > 0.: 56 | encoder_emb = Dropout(drop_p)(encoder_emb, training=training) 57 | 58 | encoder_rnn = rnn(h, return_sequences=True, return_state=True, name='encoder_rnn') 59 | encoder_rtn = encoder_rnn(encoder_emb) 60 | # # encoder_outputs, encoder_h, encoder_c = encoder_rnn(encoder_emb) 61 | # encoder_outputs = encoder_rtn[0] 62 | encoder_states = encoder_rtn[1:] 63 | 64 | # build decoder 65 | decoder_input = Input((None,), dtype='float32', name='decoder_input') 66 | if mask: 67 | decoder_emb_layer = Embedding(Vt + 1, demb, mask_zero=True, name='decoder_emb') 68 | else: 69 | decoder_emb_layer = Embedding(Vt, demb, mask_zero=False, name='decoder_emb') 70 | 71 | decoder_emb = decoder_emb_layer(decoder_input) 72 | 73 | if drop_p > 0.: 74 | decoder_emb = Dropout(drop_p)(decoder_emb, training=training) 75 | 76 | decoder_rnn = rnn(h, return_sequences=True, name='decoder_rnn') 77 | decoder_outputs = decoder_rnn(decoder_emb, initial_state=encoder_states) 78 | 79 | if drop_p > 0.: 80 | decoder_outputs = Dropout(drop_p)(decoder_outputs, training=training) 81 | 82 | if tied: 83 | final_outputs = DenseTransposeTied(Vt, tied_to=decoder_emb_layer, 84 | activation='linear', name='outputs')(decoder_outputs) 85 | else: 86 | final_outputs = Dense(Vt, activation='linear', name='outputs')(decoder_outputs) 87 | 88 | model = Model(inputs=[encoder_input, decoder_input], outputs=[final_outputs]) 89 | return model 90 | 91 | 92 | def build_inference_decoder(mask=False, demb=128, h=128, Vt=5000, tied=True): 93 | rnn = LSTM if mask else CuDNNLSTM 94 | 95 | # build decoder 96 | decoder_input = Input(batch_shape=(None, None), dtype='float32', name='decoder_input') 97 | encoder_outputs = Input(batch_shape=(None, None, h), dtype='float32', name='encoder_outputs') 98 | encoder_h = Input(batch_shape=(None, h), dtype='float32', name='encoder_h') 99 | encoder_c = Input(batch_shape=(None, h), dtype='float32', name='encoder_c') 100 | 101 | if mask: 102 | decoder_emb_layer = Embedding(Vt + 1, demb, mask_zero=True, 103 | name='decoder_emb') 104 | else: 105 | 
decoder_emb_layer = Embedding(Vt, demb, mask_zero=False, 106 | name='decoder_emb') 107 | 108 | decoder_emb = decoder_emb_layer(decoder_input) 109 | 110 | decoder_rnn = rnn(h, return_sequences=True, name='decoder_rnn') 111 | decoder_outputs = decoder_rnn(decoder_emb, initial_state=[encoder_h, encoder_c]) 112 | 113 | if tied: 114 | final_outputs = DenseTransposeTied(Vt, name='outputs', 115 | tied_to=decoder_emb_layer, activation='linear')(decoder_outputs) 116 | else: 117 | final_outputs = Dense(Vt, activation='linear', name='outputs')(decoder_outputs) 118 | 119 | inputs = [decoder_input, encoder_outputs, encoder_h, encoder_c] 120 | model = Model(inputs=inputs, outputs=[final_outputs]) 121 | return model 122 | 123 | 124 | def words_to_indices(data, vocab, mask=True): 125 | if mask: 126 | return [[vocab[w] + 1 for w in t] for t in data] 127 | else: 128 | return [[vocab[w] for w in t] for t in data] 129 | 130 | 131 | def pad_texts(texts, eos, mask=True): 132 | maxlen = max(len(t) for t in texts) 133 | for t in texts: 134 | while len(t) < maxlen: 135 | if mask: 136 | t.insert(0, 0) 137 | else: 138 | t.append(eos) 139 | return np.asarray(texts, dtype='float32') 140 | 141 | 142 | def train_cornell_movie(loo=0, num_users=200, num_words=5000, num_epochs=20, sample_user=False, exp_id=0, emb_h=128, 143 | lr=0.001, batch_size=32, mask=False, drop_p=0.5, h=128, user_data_ratio=0., cross_domain=False, 144 | ablation=False, tied=True, rnn_fn='gru'): 145 | if cross_domain: 146 | sample_user = True 147 | loo = None 148 | user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts, \ 149 | src_vocabs, trg_vocabs = load_ubuntu_by_user(num_users, num_words=num_words) 150 | else: 151 | user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts, \ 152 | src_vocabs, trg_vocabs = load_cornell_movie_by_user(num_users, num_words, user_data_ratio=user_data_ratio, 153 | sample_user=sample_user) 154 | train_src_texts, train_trg_texts = [], [] 155 | 156 | users = sorted(user_src_texts.keys()) 157 | 158 | for i, user in enumerate(users): 159 | if loo is not None and i == loo: 160 | print "Leave user {} out".format(user) 161 | continue 162 | train_src_texts += user_src_texts[user] 163 | train_trg_texts += user_trg_texts[user] 164 | 165 | train_src_texts = words_to_indices(train_src_texts, src_vocabs, mask=mask) 166 | train_trg_texts = words_to_indices(train_trg_texts, trg_vocabs, mask=mask) 167 | dev_src_texts = words_to_indices(dev_src_texts, src_vocabs, mask=mask) 168 | dev_trg_texts = words_to_indices(dev_trg_texts, trg_vocabs, mask=mask) 169 | 170 | print "Num train data {}, num test data {}".format(len(train_src_texts), len(dev_src_texts)) 171 | 172 | Vs = len(src_vocabs) 173 | Vt = len(trg_vocabs) 174 | print Vs, Vt 175 | 176 | model = build_dialogue_model(Vs=Vs, Vt=Vt, mask=mask, drop_p=drop_p, demb=emb_h, h=h, tied=tied, rnn_fn=rnn_fn) 177 | src_input_var, trg_input_var = model.inputs 178 | prediction = model.output 179 | 180 | trg_label_var = K.placeholder((None, None), dtype='float32') 181 | 182 | loss = K.sparse_categorical_crossentropy(trg_label_var, prediction, from_logits=True) 183 | loss = K.mean(K.sum(loss, axis=-1)) 184 | 185 | optimizer = Adam(lr=lr) 186 | 187 | updates = optimizer.get_updates(loss, model.trainable_weights) 188 | train_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [loss], updates=updates) 189 | pred_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], 
[loss]) 190 | 191 | # pad batches to same length 192 | batches = [] 193 | padded_train_src_texts = copy.deepcopy(train_src_texts) 194 | padded_train_trg_texts = copy.deepcopy(train_trg_texts) 195 | for batch in group_texts_by_len(padded_train_src_texts, padded_train_trg_texts, bs=batch_size): 196 | src_input, trg_input = batch 197 | src_input = pad_texts(src_input, src_vocabs[''], mask=mask) 198 | trg_input = pad_texts(trg_input, trg_vocabs[''], mask=mask) 199 | batches.append((src_input, trg_input)) 200 | 201 | for epoch in range(num_epochs): 202 | np.random.shuffle(batches) 203 | 204 | for batch in batches: 205 | src_input, trg_input = batch 206 | _ = train_fn([src_input, trg_input[:, :-1], trg_input[:, 1:], 1])[0] 207 | 208 | train_loss, train_it = get_perp(train_src_texts, train_trg_texts, pred_fn, shuffle=True, prop=0.5) 209 | test_loss, test_it = get_perp(dev_src_texts, dev_trg_texts, pred_fn) 210 | 211 | print "Epoch {}, train loss={:.3f}, train perp={:.3f}, test loss={:.3f}, test perp={:.3f}".format( 212 | epoch, train_loss / len(train_src_texts) / 0.5, 213 | np.exp(train_loss / train_it), test_loss / len(dev_src_texts), 214 | np.exp(test_loss / test_it)) 215 | 216 | if cross_domain: 217 | fname = 'ubuntu_dialog' 218 | else: 219 | fname = 'cornell_movie_dialog{}'.format('' if loo is None else loo) 220 | 221 | if ablation: 222 | fname = 'ablation_' + fname 223 | 224 | if 0. < user_data_ratio < 1.: 225 | fname += '_dr{}'.format(user_data_ratio) 226 | 227 | if sample_user: 228 | fname += '_shadow_exp{}_{}'.format(exp_id, rnn_fn) 229 | np.savez(MODEL_PATH + 'shadow_users{}_{}_{}_{}.npz'.format(exp_id, rnn_fn, num_users, 230 | 'cd' if cross_domain else ''), users) 231 | 232 | model.save(MODEL_PATH + '{}_{}.h5'.format(fname, num_users)) 233 | 234 | 235 | def get_perp(user_src_data, user_trg_data, pred_fn, prop=1.0, shuffle=False): 236 | loss = 0. 237 | iters = 0. 
238 | 239 | indices = np.arange(len(user_src_data)) 240 | n = int(prop * len(indices)) 241 | 242 | if shuffle: 243 | np.random.shuffle(indices) 244 | 245 | for idx in indices[:n]: 246 | src_text = np.asarray(user_src_data[idx], dtype=np.float32).reshape(1, -1) 247 | trg_text = np.asarray(user_trg_data[idx], dtype=np.float32) 248 | trg_input = trg_text[:-1].reshape(1, -1) 249 | trg_label = trg_text[1:].reshape(1, -1) 250 | 251 | err = pred_fn([src_text, trg_input, trg_label, 0])[0] 252 | 253 | loss += err 254 | iters += trg_label.shape[1] 255 | 256 | return loss, iters 257 | 258 | 259 | if __name__ == '__main__': 260 | train_cornell_movie(loo=None, num_users=300, sample_user=False, num_epochs=30, drop_p=0.5, h=128, emb_h=128) 261 | -------------------------------------------------------------------------------- /dialogue_ranks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, defaultdict 3 | from itertools import chain 4 | 5 | import keras.backend as K 6 | import numpy as np 7 | from sklearn.ensemble import RandomForestClassifier 8 | from sklearn.metrics import roc_auc_score, accuracy_score, classification_report 9 | from sklearn.preprocessing import StandardScaler 10 | 11 | from dialogue import build_dialogue_model, words_to_indices, MODEL_PATH, OUTPUT_PATH 12 | from helper import flatten_data 13 | from data_loader.load_cornell_movie import load_extracted_cornell_movie, process_texts, process_vocabs, \ 14 | load_cornell_movie_by_user, load_extracted_ubuntu 15 | from sated_nmt_ranks import save_users_rank_results 16 | 17 | 18 | def load_cross_domain_shadow_user_data(train_users, num_users=100, num_words=5000, num_data_per_user=200): 19 | src_texts, trg_texts = load_extracted_ubuntu(num_users * num_data_per_user * 2) 20 | 21 | all_users = np.arange(num_users * 2) 22 | test_users = np.setdiff1d(all_users, train_users) 23 | 24 | user_src_texts = defaultdict(list) 25 | user_trg_texts = defaultdict(list) 26 | 27 | test_user_src_texts = defaultdict(list) 28 | test_user_trg_texts = defaultdict(list) 29 | 30 | for u in train_users: 31 | user_src_texts[u] = src_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 32 | user_trg_texts[u] = trg_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 33 | 34 | for u in test_users: 35 | test_user_src_texts[u] = src_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 36 | test_user_trg_texts[u] = trg_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 37 | 38 | src_words = [] 39 | trg_words = [] 40 | for u in train_users: 41 | src_words += list(chain(*user_src_texts[u])) 42 | trg_words += list(chain(*user_trg_texts[u])) 43 | 44 | src_vocabs = process_vocabs(src_words, num_words) 45 | trg_vocabs = process_vocabs(trg_words, num_words) 46 | 47 | for u in train_users: 48 | process_texts(user_src_texts[u], src_vocabs) 49 | process_texts(user_trg_texts[u], trg_vocabs) 50 | 51 | for u in test_users: 52 | process_texts(test_user_src_texts[u], src_vocabs) 53 | process_texts(test_user_trg_texts[u], trg_vocabs) 54 | 55 | src_words = [] 56 | trg_words = [] 57 | 58 | for u in train_users: 59 | src_words += list(chain(*user_src_texts[u])) 60 | trg_words += list(chain(*user_trg_texts[u])) 61 | 62 | src_vocabs = process_vocabs(src_words, None) 63 | trg_vocabs = process_vocabs(trg_words, None) 64 | 65 | return user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs 66 | 67 | 68 | def load_shadow_user_data(train_users, 
num_users=100, num_words=10000, min_count=20): 69 | train_data, dev_data, test_data = load_extracted_cornell_movie(dev_size=5000, test_size=5000) 70 | train_src_texts, train_trg_texts, src_users, _ = train_data 71 | 72 | user_counter = Counter(src_users) 73 | all_users = np.asarray([tup[0] for tup in user_counter.most_common() if tup[1] >= min_count]) 74 | print 'Loaded {} users'.format(len(all_users)) 75 | 76 | np.random.seed(12345) 77 | np.random.shuffle(all_users) 78 | np.random.seed(None) 79 | 80 | attacker_users = all_users[num_users * 2: num_users * 4] 81 | test_users = np.setdiff1d(attacker_users, train_users) 82 | 83 | user_src_texts = defaultdict(list) 84 | user_trg_texts = defaultdict(list) 85 | 86 | test_user_src_texts = defaultdict(list) 87 | test_user_trg_texts = defaultdict(list) 88 | 89 | for u, s, t in zip(src_users, train_src_texts, train_trg_texts): 90 | if u in train_users: 91 | user_src_texts[u].append(s) 92 | user_trg_texts[u].append(t) 93 | if u in test_users: 94 | test_user_src_texts[u].append(s) 95 | test_user_trg_texts[u].append(t) 96 | 97 | src_words = [] 98 | trg_words = [] 99 | for u in train_users: 100 | src_words += list(chain(*user_src_texts[u])) 101 | trg_words += list(chain(*user_trg_texts[u])) 102 | 103 | src_vocabs = process_vocabs(src_words, num_words) 104 | trg_vocabs = process_vocabs(trg_words, num_words) 105 | 106 | for u in train_users: 107 | process_texts(user_src_texts[u], src_vocabs) 108 | process_texts(user_trg_texts[u], trg_vocabs) 109 | 110 | for u in test_users: 111 | process_texts(test_user_src_texts[u], src_vocabs) 112 | process_texts(test_user_trg_texts[u], trg_vocabs) 113 | 114 | src_words = [] 115 | trg_words = [] 116 | 117 | for u in train_users: 118 | src_words += list(chain(*user_src_texts[u])) 119 | trg_words += list(chain(*user_trg_texts[u])) 120 | 121 | src_vocabs = process_vocabs(src_words, None) 122 | trg_vocabs = process_vocabs(trg_words, None) 123 | 124 | return user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs 125 | 126 | 127 | def get_ranks(user_src_data, user_trg_data, pred_fn): 128 | indices = np.arange(len(user_src_data)) 129 | 130 | ranks = [] 131 | for idx in indices: 132 | src_text = np.asarray(user_src_data[idx], dtype=np.float32).reshape(1, -1) 133 | trg_text = np.asarray(user_trg_data[idx], dtype=np.float32) 134 | trg_input = trg_text[:-1].reshape(1, -1) 135 | trg_label = trg_text[1:].reshape(1, -1) 136 | 137 | prob = pred_fn([src_text, trg_input, trg_label, 0])[0][0] 138 | sent_ranks = [] 139 | for p, t in zip(prob, trg_label.flatten()): 140 | t = int(t) 141 | rank = (-p).argsort().argsort()[t] 142 | sent_ranks.append(rank) 143 | ranks.append(sent_ranks) 144 | return ranks 145 | 146 | 147 | def get_shadow_ranks(exp_id=0, num_users=200, num_words=5000, mask=False, cross_domain=False, rnn_fn='lstm', 148 | h=128, emb_h=128, rerun=False): 149 | shadow_user_path = 'shadow_users{}_{}_{}_{}.npz'.format(exp_id, rnn_fn, num_users, 'cd' if cross_domain else '') 150 | shadow_train_users = np.load(MODEL_PATH + shadow_user_path)['arr_0'] 151 | shadow_train_users = list(shadow_train_users) 152 | print shadow_user_path, shadow_train_users 153 | 154 | save_dir = OUTPUT_PATH + 'shadow_exp{}_{}/'.format(exp_id, num_users) 155 | if not os.path.exists(save_dir): 156 | os.mkdir(save_dir) 157 | 158 | if cross_domain: 159 | user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \ 160 | = load_cross_domain_shadow_user_data(shadow_train_users, num_users, 
num_words) 161 | else: 162 | user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \ 163 | = load_shadow_user_data(shadow_train_users, num_users, num_words) 164 | shadow_test_users = sorted(test_user_src_texts.keys()) 165 | 166 | model_path = '{}_shadow_exp{}_{}_{}.h5'.format('ubuntu_dialog' if cross_domain else 'cornell_movie_dialog', 167 | exp_id, rnn_fn, num_users) 168 | 169 | model = build_dialogue_model(Vs=num_words, Vt=num_words, mask=mask, drop_p=0., h=h, demb=emb_h, rnn_fn=rnn_fn) 170 | model.load_weights(MODEL_PATH + model_path) 171 | 172 | src_input_var, trg_input_var = model.inputs 173 | prediction = model.output 174 | trg_label_var = K.placeholder((None, None), dtype='float32') 175 | 176 | prediction = K.softmax(prediction) 177 | prob_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [prediction]) 178 | 179 | save_users_rank_results(users=shadow_train_users, rerun=rerun, 180 | user_src_texts=user_src_texts, user_trg_texts=user_trg_texts, 181 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=cross_domain, 182 | prob_fn=prob_fn, save_dir=save_dir, member_label=1) 183 | save_users_rank_results(users=shadow_test_users, rerun=rerun, 184 | user_src_texts=test_user_src_texts, user_trg_texts=test_user_trg_texts, 185 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=cross_domain, 186 | prob_fn=prob_fn, save_dir=save_dir, member_label=0) 187 | 188 | 189 | def load_train_users_heldout_data(train_users, src_vocabs, trg_vocabs, user_data_ratio=0.5): 190 | train_data, dev_data, test_data = load_extracted_cornell_movie(dev_size=5000, test_size=5000) 191 | train_src_texts, train_trg_texts, src_users, _ = train_data 192 | 193 | user_src_texts = defaultdict(list) 194 | user_trg_texts = defaultdict(list) 195 | 196 | for u, s, t in zip(src_users, train_src_texts, train_trg_texts): 197 | if u in train_users: 198 | user_src_texts[u].append(s) 199 | user_trg_texts[u].append(t) 200 | 201 | assert 0. < user_data_ratio < 1. 202 | # held out some fraction of data for testing 203 | for u in user_src_texts: 204 | l = len(user_src_texts[u]) 205 | l = int(l * user_data_ratio) 206 | user_src_texts[u] = user_src_texts[u][l:] 207 | user_trg_texts[u] = user_trg_texts[u][l:] 208 | 209 | for u in train_users: 210 | process_texts(user_src_texts[u], src_vocabs) 211 | process_texts(user_trg_texts[u], trg_vocabs) 212 | 213 | return user_src_texts, user_trg_texts 214 | 215 | 216 | def get_target_ranks(num_users=200, num_words=5000, mask=False, user_data_ratio=0., save_probs=False): 217 | user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \ 218 | = load_cornell_movie_by_user(num_users, num_words, test_on_user=True, user_data_ratio=user_data_ratio) 219 | 220 | train_users = sorted(user_src_texts.keys()) 221 | test_users = sorted(test_user_src_texts.keys()) 222 | 223 | save_dir = OUTPUT_PATH + 'target_{}{}/'.format(num_users, '_dr' if 0. < user_data_ratio < 1. else '') 224 | if not os.path.exists(save_dir): 225 | os.mkdir(save_dir) 226 | 227 | model_path = 'cornell_movie_dialog' 228 | 229 | if 0. 
< user_data_ratio < 1.: 230 | model_path += '_dr{}'.format(user_data_ratio) 231 | heldout_src_texts, heldout_trg_texts = load_train_users_heldout_data(train_users, src_vocabs, trg_vocabs) 232 | for u in train_users: 233 | user_src_texts[u] += heldout_src_texts[u] 234 | user_trg_texts[u] += heldout_trg_texts[u] 235 | 236 | model = build_dialogue_model(Vs=num_words, Vt=num_words, mask=mask, drop_p=0.) 237 | model.load_weights(MODEL_PATH + '{}_{}.h5'.format(model_path, num_users)) 238 | 239 | src_input_var, trg_input_var = model.inputs 240 | prediction = model.output 241 | trg_label_var = K.placeholder((None, None), dtype='float32') 242 | 243 | prediction = K.softmax(prediction) 244 | prob_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [prediction]) 245 | 246 | save_users_rank_results(users=train_users, save_probs=save_probs, 247 | user_src_texts=user_src_texts, user_trg_texts=user_trg_texts, 248 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=False, 249 | prob_fn=prob_fn, save_dir=save_dir, member_label=1) 250 | save_users_rank_results(users=test_users, save_probs=save_probs, 251 | user_src_texts=test_user_src_texts, user_trg_texts=test_user_trg_texts, 252 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=False, 253 | prob_fn=prob_fn, save_dir=save_dir, member_label=0) 254 | 255 | 256 | if __name__ == '__main__': 257 | get_target_ranks(num_users=300, save_probs=False) 258 | -------------------------------------------------------------------------------- /sated_nmt.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import keras.backend as K 4 | import numpy as np 5 | from keras import Model 6 | from keras.layers import Input, Embedding, LSTM, Dropout, Dense, CuDNNLSTM, Add, CuDNNGRU 7 | from keras.optimizers import Adam 8 | from keras.regularizers import l2 9 | 10 | from data_loader.load_sated import load_europarl_by_user, load_sated_data_by_user 11 | from helper import DenseTransposeTied, Attention 12 | 13 | MODEL_PATH = '/hdd/song/nlp/sated-release-0.9.0/model/' 14 | OUTPUT_PATH = '/hdd/song/nlp/sated-release-0.9.0/output/' 15 | 16 | 17 | def group_texts_by_len(src_texts, trg_texts, bs=20): 18 | print("Bucketing batches") 19 | # Bucket samples by source sentence length 20 | buckets = defaultdict(list) 21 | batches = [] 22 | for src, trg in zip(src_texts, trg_texts): 23 | buckets[len(src)].append((src, trg)) 24 | 25 | for src_len, bucket in buckets.items(): 26 | np.random.shuffle(bucket) 27 | num_batches = int(np.ceil(len(bucket) * 1.0 / bs)) 28 | for i in range(num_batches): 29 | cur_batch_size = bs if i < num_batches - 1 else len(bucket) - bs * i 30 | batches.append(([bucket[i * bs + j][0] for j in range(cur_batch_size)], 31 | [bucket[i * bs + j][1] for j in range(cur_batch_size)])) 32 | return batches 33 | 34 | 35 | def build_nmt_model(Vs, Vt, demb=128, h=128, drop_p=0.5, tied=True, mask=True, attn=True, l2_ratio=1e-4, 36 | training=None, rnn_fn='lstm'): 37 | if rnn_fn == 'lstm': 38 | rnn = LSTM if mask else CuDNNLSTM 39 | elif rnn_fn == 'gru': 40 | rnn = LSTM if mask else CuDNNGRU 41 | else: 42 | raise ValueError(rnn_fn) 43 | 44 | # build encoder 45 | encoder_input = Input((None,), dtype='float32', name='encoder_input') 46 | if mask: 47 | encoder_emb_layer = Embedding(Vs + 1, demb, mask_zero=True, embeddings_regularizer=l2(l2_ratio), 48 | name='encoder_emb') 49 | else: 50 | encoder_emb_layer = Embedding(Vs, demb, mask_zero=False, 
embeddings_regularizer=l2(l2_ratio), 51 | name='encoder_emb') 52 | 53 | encoder_emb = encoder_emb_layer(encoder_input) 54 | 55 | if drop_p > 0.: 56 | encoder_emb = Dropout(drop_p)(encoder_emb, training=training) 57 | 58 | encoder_rnn = rnn(h, return_sequences=True, return_state=True, kernel_regularizer=l2(l2_ratio), name='encoder_rnn') 59 | encoder_rtn = encoder_rnn(encoder_emb) 60 | # encoder_outputs, encoder_h, encoder_c = encoder_rnn(encoder_emb) 61 | encoder_outputs = encoder_rtn[0] 62 | encoder_states = encoder_rtn[1:] 63 | 64 | # build decoder 65 | decoder_input = Input((None,), dtype='float32', name='decoder_input') 66 | if mask: 67 | decoder_emb_layer = Embedding(Vt + 1, demb, mask_zero=True, embeddings_regularizer=l2(l2_ratio), 68 | name='decoder_emb') 69 | else: 70 | decoder_emb_layer = Embedding(Vt, demb, mask_zero=False, embeddings_regularizer=l2(l2_ratio), 71 | name='decoder_emb') 72 | 73 | decoder_emb = decoder_emb_layer(decoder_input) 74 | 75 | if drop_p > 0.: 76 | decoder_emb = Dropout(drop_p)(decoder_emb, training=training) 77 | 78 | decoder_rnn = rnn(h, return_sequences=True, kernel_regularizer=l2(l2_ratio), name='decoder_rnn') 79 | decoder_outputs = decoder_rnn(decoder_emb, initial_state=encoder_states) 80 | 81 | if drop_p > 0.: 82 | decoder_outputs = Dropout(drop_p)(decoder_outputs, training=training) 83 | 84 | if tied: 85 | final_outputs = DenseTransposeTied(Vt, kernel_regularizer=l2(l2_ratio), name='outputs', 86 | tied_to=decoder_emb_layer, activation='linear')(decoder_outputs) 87 | else: 88 | final_outputs = Dense(Vt, activation='linear', kernel_regularizer=l2(l2_ratio), name='outputs')(decoder_outputs) 89 | 90 | if attn: 91 | contexts = Attention(units=h, kernel_regularizer=l2(l2_ratio), name='attention', 92 | use_bias=False)([encoder_outputs, decoder_outputs]) 93 | if drop_p > 0.: 94 | contexts = Dropout(drop_p)(contexts, training=training) 95 | 96 | contexts_outputs = Dense(Vt, activation='linear', use_bias=False, name='context_outputs', 97 | kernel_regularizer=l2(l2_ratio))(contexts) 98 | 99 | final_outputs = Add(name='final_outputs')([final_outputs, contexts_outputs]) 100 | 101 | model = Model(inputs=[encoder_input, decoder_input], outputs=[final_outputs]) 102 | return model 103 | 104 | 105 | def build_inference_decoder(mask=False, demb=128, h=128, Vt=5000, tied=True, attn=True): 106 | rnn = LSTM if mask else CuDNNLSTM 107 | 108 | # build decoder 109 | decoder_input = Input(batch_shape=(None, None), dtype='float32', name='decoder_input') 110 | encoder_outputs = Input(batch_shape=(None, None, h), dtype='float32', name='encoder_outputs') 111 | encoder_h = Input(batch_shape=(None, h), dtype='float32', name='encoder_h') 112 | encoder_c = Input(batch_shape=(None, h), dtype='float32', name='encoder_c') 113 | 114 | if mask: 115 | decoder_emb_layer = Embedding(Vt + 1, demb, mask_zero=True, 116 | name='decoder_emb') 117 | else: 118 | decoder_emb_layer = Embedding(Vt, demb, mask_zero=False, 119 | name='decoder_emb') 120 | 121 | decoder_emb = decoder_emb_layer(decoder_input) 122 | 123 | decoder_rnn = rnn(h, return_sequences=True, name='decoder_rnn') 124 | decoder_outputs = decoder_rnn(decoder_emb, initial_state=[encoder_h, encoder_c]) 125 | 126 | if tied: 127 | final_outputs = DenseTransposeTied(Vt, name='outputs', 128 | tied_to=decoder_emb_layer, activation='linear')(decoder_outputs) 129 | else: 130 | final_outputs = Dense(Vt, activation='linear', name='outputs')(decoder_outputs) 131 | 132 | if attn: 133 | contexts = Attention(units=h, use_bias=False, 
name='attention')([encoder_outputs, decoder_outputs]) 134 | contexts_outputs = Dense(Vt, activation='linear', use_bias=False, name='context_outputs')(contexts) 135 | final_outputs = Add(name='final_outputs')([final_outputs, contexts_outputs]) 136 | 137 | inputs = [decoder_input, encoder_outputs, encoder_h, encoder_c] 138 | model = Model(inputs=inputs, outputs=[final_outputs]) 139 | return model 140 | 141 | 142 | def words_to_indices(data, vocab, mask=True): 143 | if mask: 144 | return [[vocab[w] + 1 for w in t] for t in data] 145 | else: 146 | return [[vocab[w] for w in t] for t in data] 147 | 148 | 149 | def pad_texts(texts, eos, mask=True): 150 | maxlen = max(len(t) for t in texts) 151 | for t in texts: 152 | while len(t) < maxlen: 153 | if mask: 154 | t.insert(0, 0) 155 | else: 156 | t.append(eos) 157 | return np.asarray(texts, dtype='float32') 158 | 159 | 160 | def get_perp(user_src_data, user_trg_data, pred_fn, prop=1.0, shuffle=False): 161 | loss = 0. 162 | iters = 0. 163 | 164 | indices = np.arange(len(user_src_data)) 165 | n = int(prop * len(indices)) 166 | 167 | if shuffle: 168 | np.random.shuffle(indices) 169 | 170 | for idx in indices[:n]: 171 | src_text = np.asarray(user_src_data[idx], dtype=np.float32).reshape(1, -1) 172 | trg_text = np.asarray(user_trg_data[idx], dtype=np.float32) 173 | trg_input = trg_text[:-1].reshape(1, -1) 174 | trg_label = trg_text[1:].reshape(1, -1) 175 | 176 | err = pred_fn([src_text, trg_input, trg_label, 0])[0] 177 | 178 | loss += err 179 | iters += trg_label.shape[1] 180 | 181 | return loss, iters 182 | 183 | 184 | def train_sated_nmt(loo=0, num_users=200, num_words=5000, num_epochs=20, h=128, emb_h=128, l2_ratio=1e-4, exp_id=0, 185 | lr=0.001, batch_size=32, mask=False, drop_p=0.5, cross_domain=False, tied=False, ablation=False, 186 | sample_user=False, user_data_ratio=0., rnn_fn='lstm'): 187 | if cross_domain: 188 | sample_user = True 189 | user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts,\ 190 | src_vocabs, trg_vocabs = load_europarl_by_user(num_users=num_users, num_words=num_words) 191 | else: 192 | user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts,\ 193 | src_vocabs, trg_vocabs = load_sated_data_by_user(num_users, num_words, sample_user=sample_user, 194 | user_data_ratio=user_data_ratio) 195 | train_src_texts, train_trg_texts = [], [] 196 | 197 | users = sorted(user_src_texts.keys()) 198 | 199 | for i, user in enumerate(users): 200 | if loo is not None and i == loo: 201 | print "Leave user {} out".format(user) 202 | continue 203 | train_src_texts += user_src_texts[user] 204 | train_trg_texts += user_trg_texts[user] 205 | 206 | train_src_texts = words_to_indices(train_src_texts, src_vocabs, mask=mask) 207 | train_trg_texts = words_to_indices(train_trg_texts, trg_vocabs, mask=mask) 208 | dev_src_texts = words_to_indices(dev_src_texts, src_vocabs, mask=mask) 209 | dev_trg_texts = words_to_indices(dev_trg_texts, trg_vocabs, mask=mask) 210 | 211 | print "Num train data {}, num test data {}".format(len(train_src_texts), len(dev_src_texts)) 212 | 213 | Vs = len(src_vocabs) 214 | Vt = len(trg_vocabs) 215 | print Vs, Vt 216 | 217 | model = build_nmt_model(Vs=Vs, Vt=Vt, mask=mask, drop_p=drop_p, h=h, demb=emb_h, tied=tied, l2_ratio=l2_ratio, 218 | rnn_fn=rnn_fn) 219 | src_input_var, trg_input_var = model.inputs 220 | prediction = model.output 221 | 222 | trg_label_var = K.placeholder((None, None), dtype='float32') 223 | 224 | loss = 
K.sparse_categorical_crossentropy(trg_label_var, prediction, from_logits=True) 225 | loss = K.mean(K.sum(loss, axis=-1)) 226 | 227 | optimizer = Adam(lr=lr, clipnorm=5.) 228 | 229 | updates = optimizer.get_updates(loss, model.trainable_weights) 230 | train_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [loss], updates=updates) 231 | pred_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [loss]) 232 | 233 | # pad batches to same length 234 | train_prop = 0.2 235 | batches = [] 236 | for batch in group_texts_by_len(train_src_texts, train_trg_texts, bs=batch_size): 237 | src_input, trg_input = batch 238 | src_input = pad_texts(src_input, src_vocabs[''], mask=mask) 239 | trg_input = pad_texts(trg_input, trg_vocabs[''], mask=mask) 240 | batches.append((src_input, trg_input)) 241 | 242 | for epoch in range(num_epochs): 243 | np.random.shuffle(batches) 244 | 245 | for batch in batches: 246 | src_input, trg_input = batch 247 | _ = train_fn([src_input, trg_input[:, :-1], trg_input[:, 1:], 1])[0] 248 | 249 | train_loss, train_it = get_perp(train_src_texts, train_trg_texts, pred_fn, shuffle=True, prop=train_prop) 250 | test_loss, test_it = get_perp(dev_src_texts, dev_trg_texts, pred_fn) 251 | 252 | print "Epoch {}, train loss={:.3f}, train perp={:.3f}, test loss={:.3f}, test perp={:.3f}".format( 253 | epoch, 254 | train_loss / len(train_src_texts) / train_prop, 255 | np.exp(train_loss / train_it), 256 | test_loss / len(dev_src_texts), 257 | np.exp(test_loss / test_it)) 258 | 259 | if cross_domain: 260 | fname = 'europal_nmt{}'.format('' if loo is None else loo) 261 | else: 262 | fname = 'sated_nmt{}'.format('' if loo is None else loo) 263 | 264 | if ablation: 265 | fname = 'ablation_' + fname 266 | 267 | if 0. 
< user_data_ratio < 1.: 268 | fname += '_dr{}'.format(user_data_ratio) 269 | 270 | if sample_user: 271 | fname += '_shadow_exp{}_{}'.format(exp_id, rnn_fn) 272 | np.savez(MODEL_PATH + 'shadow_users{}_{}_{}_{}.npz'.format(exp_id, rnn_fn, num_users, 273 | 'cd' if cross_domain else ''), users) 274 | 275 | model.save(MODEL_PATH + '{}_{}.h5'.format(fname, num_users)) 276 | K.clear_session() 277 | 278 | 279 | if __name__ == '__main__': 280 | epochs = 30 281 | train_sated_nmt(loo=None, sample_user=False, cross_domain=False, h=128, emb_h=128, 282 | num_epochs=30, num_users=300, drop_p=0.5, rnn_fn='lstm') 283 | 284 | -------------------------------------------------------------------------------- /data_loader/load_cornell_movie.py: -------------------------------------------------------------------------------- 1 | import io 2 | import ast 3 | import numpy as np 4 | import os 5 | 6 | from nltk.tokenize import word_tokenize 7 | from collections import Counter, defaultdict 8 | from itertools import chain 9 | from load_sated import process_vocabs, process_texts 10 | 11 | 12 | DATA_PATH = '/hdd/song/nlp/cornell_movie_dialogs_corpus/' 13 | MOVIE_LINES_PATH = DATA_PATH + 'movie_lines.txt' 14 | MOVIE_CONVERSATIONS_PATH = DATA_PATH + 'movie_conversations.txt' 15 | SEPARATOR = ' +++$+++ ' 16 | 17 | MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"] 18 | MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"] 19 | 20 | 21 | def load_lines(filename, fields=MOVIE_LINES_FIELDS): 22 | lines = {} 23 | 24 | with io.open(filename, 'r', encoding='iso-8859-1') as f: 25 | for line in f: 26 | values = line.replace('\n', '').split(SEPARATOR) 27 | 28 | # Extract fields 29 | line_obj = {} 30 | for i, field in enumerate(fields): 31 | line_obj[field] = values[i] 32 | 33 | lines[line_obj['lineID']] = line_obj 34 | 35 | return lines 36 | 37 | 38 | def load_conversations(lines, filename, fields=MOVIE_CONVERSATIONS_FIELDS): 39 | conversations = [] 40 | 41 | with io.open(filename, 'r', encoding='iso-8859-1') as f: 42 | for line in f: 43 | values = line.replace('\n', '').split(SEPARATOR) 44 | 45 | # Extract fields 46 | conv_obj = {} 47 | for i, field in enumerate(fields): 48 | conv_obj[field] = values[i] 49 | 50 | # Convert string to list (conv_obj["utteranceIDs"] == "['L598485', 'L598486', ...]") 51 | line_ids = ast.literal_eval(conv_obj["utteranceIDs"]) 52 | 53 | # Reassemble lines 54 | conv_obj["lines"] = [] 55 | for line_id in line_ids: 56 | conv_obj["lines"].append(lines[line_id]) 57 | 58 | conversations.append(conv_obj) 59 | 60 | return conversations 61 | 62 | 63 | def count_character_lines(): 64 | lines = load_lines(MOVIE_LINES_PATH) 65 | character_counter = Counter() 66 | for line_id in lines: 67 | line = lines[line_id] 68 | character = line["characterID"] 69 | character_counter[character] += 1 70 | 71 | print len(character_counter) 72 | print character_counter.most_common(100) 73 | 74 | 75 | def save_extracted_cornell_movie(): 76 | lines = load_lines(MOVIE_LINES_PATH) 77 | conversations = load_conversations(lines, MOVIE_CONVERSATIONS_PATH) 78 | src_texts = [] 79 | trg_texts = [] 80 | src_chars = [] 81 | trg_chars = [] 82 | 83 | for conv_obj in conversations: 84 | conv_line_objs = conv_obj["lines"] 85 | n_lines = len(conv_line_objs) 86 | for i in range(n_lines - 1): 87 | st = word_tokenize(conv_line_objs[i]["text"]) 88 | tt = word_tokenize(conv_line_objs[i + 1]["text"]) 89 | 90 | src_texts.append('\t'.join(st)) 91 | trg_texts.append('\t'.join(tt)) 92 | 93 | 
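# Each adjacent pair of utterances in a conversation yields one training
# example: line i is the source turn and line i+1 the target reply.  The
# character IDs appended below record who spoke each side so the corpus can
# later be regrouped per user; the writer at the end of this function joins
# tokens with tabs and fields with the ' +++$+++ ' separator.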
src_chars.append(conv_line_objs[i]["characterID"]) 94 | trg_chars.append(conv_line_objs[i + 1]["characterID"]) 95 | 96 | print len(src_texts), len(trg_texts) 97 | 98 | with io.open(DATA_PATH + 'extracted_src_trg.txt', 'w', encoding='iso-8859-1') as f: 99 | for st, sc, tt, tc in zip(src_texts, src_chars, trg_texts, trg_chars): 100 | f.write(SEPARATOR.join([st, sc, tt, tc]) + '\n') 101 | 102 | 103 | def load_extracted_cornell_movie(dev_size=10000, test_size=10000): 104 | src_texts = [] 105 | trg_texts = [] 106 | src_chars = [] 107 | trg_chars = [] 108 | 109 | with io.open(DATA_PATH + 'extracted_src_trg.txt', 'r', encoding='iso-8859-1') as f: 110 | for line in f: 111 | st, sc, tt, tc = line.lower().replace('\n', '').split(SEPARATOR) 112 | st = [''] + st.split('\t') + [''] 113 | tt = [''] + tt.split('\t') + [''] 114 | 115 | src_texts.append(st) 116 | trg_texts.append(tt) 117 | 118 | src_chars.append(sc) 119 | trg_chars.append(tc) 120 | 121 | src_texts = np.asarray(src_texts) 122 | trg_texts = np.asarray(trg_texts) 123 | 124 | np.random.seed(12345) 125 | indices = np.arange(len(src_texts)) 126 | train_indices = np.random.choice(indices, len(indices) - dev_size - test_size, replace=False) 127 | test_indices = np.setdiff1d(indices, train_indices) 128 | 129 | dev_indices = test_indices[:dev_size] 130 | test_indices = test_indices[dev_size:] 131 | 132 | train_src_texts, dev_src_texts, test_src_texts = \ 133 | src_texts[train_indices], src_texts[dev_indices], src_texts[test_indices] 134 | train_trg_texts, dev_trg_texts, test_trg_texts = \ 135 | trg_texts[train_indices], trg_texts[dev_indices], trg_texts[test_indices] 136 | 137 | train_src_chars = np.asarray(src_chars)[train_indices] 138 | train_trg_chars = np.asarray(trg_chars)[train_indices] 139 | 140 | train_data = (train_src_texts, train_trg_texts, train_src_chars, train_trg_chars) 141 | dev_data = (dev_src_texts, dev_trg_texts) 142 | test_data = (test_src_texts, test_trg_texts) 143 | return train_data, dev_data, test_data 144 | 145 | 146 | def load_cornell_movie(num_words=10000): 147 | train_data, dev_data, test_data = load_extracted_cornell_movie() 148 | train_src_texts, train_trg_texts, _, _ = train_data 149 | dev_src_texts, dev_trg_texts = dev_data 150 | test_src_texts, test_trg_texts = test_data 151 | 152 | src_words = list(chain(*train_src_texts)) 153 | trg_words = list(chain(*train_trg_texts)) 154 | 155 | src_vocabs = process_vocabs(src_words, num_words) 156 | trg_vocabs = process_vocabs(trg_words, num_words) 157 | 158 | process_texts(train_src_texts, src_vocabs) 159 | process_texts(train_trg_texts, trg_vocabs) 160 | 161 | process_texts(dev_src_texts, src_vocabs) 162 | process_texts(dev_trg_texts, trg_vocabs) 163 | 164 | process_texts(test_src_texts, src_vocabs) 165 | process_texts(test_trg_texts, trg_vocabs) 166 | 167 | src_words = list(chain(*train_src_texts)) 168 | trg_words = list(chain(*train_trg_texts)) 169 | 170 | src_vocabs = process_vocabs(src_words, None) 171 | trg_vocabs = process_vocabs(trg_words, None) 172 | return train_src_texts, train_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts, \ 173 | src_vocabs, trg_vocabs 174 | 175 | 176 | def load_cornell_movie_by_user(num_users=100, num_words=5000, test_on_user=False, sample_user=False, min_count=20, 177 | user_data_ratio=0.): 178 | train_data, dev_data, test_data = load_extracted_cornell_movie(dev_size=5000, test_size=5000) 179 | train_src_texts, train_trg_texts, src_users, _ = train_data 180 | dev_src_texts, dev_trg_texts = dev_data 181 | test_src_texts, 
test_trg_texts = test_data 182 | 183 | user_counter = Counter(src_users) 184 | all_users = np.asarray([tup[0] for tup in user_counter.most_common() if tup[1] >= min_count]) 185 | print 'Loaded {} users'.format(len(all_users)) 186 | 187 | np.random.seed(12345) 188 | np.random.shuffle(all_users) 189 | np.random.seed(None) 190 | 191 | train_users = set(all_users[:num_users]) 192 | test_users = all_users[num_users:num_users * 2] 193 | 194 | if sample_user: 195 | attacker_users = all_users[num_users * 2: num_users * 4] 196 | np.random.seed(None) 197 | train_users = np.random.choice(attacker_users, size=num_users, replace=False) 198 | print train_users[:10] 199 | 200 | user_src_texts = defaultdict(list) 201 | user_trg_texts = defaultdict(list) 202 | 203 | test_user_src_texts = defaultdict(list) 204 | test_user_trg_texts = defaultdict(list) 205 | 206 | for u, s, t in zip(src_users, train_src_texts, train_trg_texts): 207 | if u in train_users: 208 | user_src_texts[u].append(s) 209 | user_trg_texts[u].append(t) 210 | if test_on_user and u in test_users: 211 | test_user_src_texts[u].append(s) 212 | test_user_trg_texts[u].append(t) 213 | 214 | if 0. < user_data_ratio < 1.: 215 | # held out some fraction of data for testing 216 | for u in user_src_texts: 217 | l = len(user_src_texts[u]) 218 | # print l 219 | l = int(l * user_data_ratio) 220 | user_src_texts[u] = user_src_texts[u][:l] 221 | user_trg_texts[u] = user_trg_texts[u][:l] 222 | 223 | src_words = [] 224 | trg_words = [] 225 | for u in train_users: 226 | src_words += list(chain(*user_src_texts[u])) 227 | trg_words += list(chain(*user_trg_texts[u])) 228 | 229 | src_vocabs = process_vocabs(src_words, num_words) 230 | trg_vocabs = process_vocabs(trg_words, num_words) 231 | 232 | for u in train_users: 233 | process_texts(user_src_texts[u], src_vocabs) 234 | process_texts(user_trg_texts[u], trg_vocabs) 235 | 236 | if test_on_user: 237 | for u in test_users: 238 | process_texts(test_user_src_texts[u], src_vocabs) 239 | process_texts(test_user_trg_texts[u], trg_vocabs) 240 | 241 | process_texts(dev_src_texts, src_vocabs) 242 | process_texts(dev_trg_texts, trg_vocabs) 243 | 244 | process_texts(test_src_texts, src_vocabs) 245 | process_texts(test_trg_texts, trg_vocabs) 246 | 247 | src_words = [] 248 | trg_words = [] 249 | 250 | for u in train_users: 251 | src_words += list(chain(*user_src_texts[u])) 252 | trg_words += list(chain(*user_trg_texts[u])) 253 | 254 | src_vocabs = process_vocabs(src_words, None) 255 | trg_vocabs = process_vocabs(trg_words, None) 256 | 257 | if test_on_user: 258 | return user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs 259 | else: 260 | return user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts, \ 261 | src_vocabs, trg_vocabs 262 | 263 | 264 | UBUNTU_PATH = '/hdd/song/nlp/ubuntu/' 265 | 266 | 267 | def load_ubuntu_lines(filename): 268 | lines = [] 269 | with open(filename, 'rb') as f: 270 | for line in f: 271 | l = line[line.rindex("\t") + 1:].strip() # Strip metadata (timestamps, speaker names) 272 | lines.append(l.lower()) 273 | return lines 274 | 275 | 276 | def load_raw_ubuntu(max_dir=15): 277 | conversations = [] 278 | n = 0 279 | dialogs_path = os.path.join(UBUNTU_PATH, 'dialogs') 280 | for sub in os.listdir(dialogs_path)[:max_dir]: 281 | subdir = os.path.join(dialogs_path, sub) 282 | for filename in os.listdir(subdir): 283 | filename = os.path.join(subdir, filename) 284 | if filename.endswith('tsv'): 285 | lines = 
load_ubuntu_lines(filename) 286 | conversations.append(lines) 287 | n += len(lines) 288 | return conversations 289 | 290 | 291 | def preprocess(t): 292 | words = t.split(' ') 293 | for i in range(len(words)): 294 | if 'http' in words[i] or 'www' in words[i]: 295 | words[i] = '/url/' 296 | return ' '.join(words).decode("utf8") 297 | 298 | 299 | def save_extracted_ubuntu(): 300 | src_texts = [] 301 | trg_texts = [] 302 | 303 | conversations = load_raw_ubuntu() 304 | for lines in conversations: 305 | n_lines = len(lines) 306 | for i in range(n_lines - 1): 307 | st = word_tokenize(preprocess(lines[i])) 308 | tt = word_tokenize(preprocess(lines[i + 1])) 309 | 310 | src_texts.append('\t'.join(st)) 311 | trg_texts.append('\t'.join(tt)) 312 | 313 | print len(src_texts), len(trg_texts) 314 | 315 | with open(UBUNTU_PATH + 'extracted_src_trg.txt', 'wb') as f: 316 | for st, tt in zip(src_texts, trg_texts): 317 | f.write(SEPARATOR.join([st, tt]).encode('utf-8') + '\n') 318 | 319 | 320 | def load_extracted_ubuntu(num_lines): 321 | src_texts = [] 322 | trg_texts = [] 323 | 324 | with open(UBUNTU_PATH + 'extracted_src_trg.txt', 'rb') as f: 325 | for line in f: 326 | st, tt = line.lower().replace('\n', '').split(SEPARATOR) 327 | st = [''] + st.split('\t') + [''] 328 | tt = [''] + tt.split('\t') + [''] 329 | 330 | src_texts.append(st) 331 | trg_texts.append(tt) 332 | if len(src_texts) > num_lines: 333 | break 334 | 335 | return src_texts, trg_texts 336 | 337 | 338 | def load_ubuntu_by_user(num_users=200, num_words=5000, num_data_per_user=200, test_size=5000): 339 | src_texts, trg_texts = load_extracted_ubuntu(num_users * num_data_per_user * 2 + test_size) 340 | test_src_texts = src_texts[-test_size:] 341 | test_trg_texts = trg_texts[-test_size:] 342 | src_texts = src_texts[:-test_size] 343 | trg_texts = trg_texts[:-test_size] 344 | 345 | all_users = np.arange(num_users * 2) 346 | np.random.seed(None) 347 | train_users = np.random.choice(all_users, size=num_users, replace=False) 348 | 349 | user_src_texts = defaultdict(list) 350 | user_trg_texts = defaultdict(list) 351 | 352 | for u in train_users: 353 | user_src_texts[u] = src_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 354 | user_trg_texts[u] = trg_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 355 | 356 | src_words = [] 357 | trg_words = [] 358 | for u in train_users: 359 | src_words += list(chain(*user_src_texts[u])) 360 | trg_words += list(chain(*user_trg_texts[u])) 361 | 362 | src_vocabs = process_vocabs(src_words, num_words) 363 | trg_vocabs = process_vocabs(trg_words, num_words) 364 | 365 | for u in train_users: 366 | process_texts(user_src_texts[u], src_vocabs) 367 | process_texts(user_trg_texts[u], trg_vocabs) 368 | 369 | process_texts(test_src_texts, src_vocabs) 370 | process_texts(test_trg_texts, trg_vocabs) 371 | 372 | src_words = [] 373 | trg_words = [] 374 | for u in train_users: 375 | src_words += list(chain(*user_src_texts[u])) 376 | trg_words += list(chain(*user_trg_texts[u])) 377 | 378 | src_vocabs = process_vocabs(src_words, None) 379 | trg_vocabs = process_vocabs(trg_words, None) 380 | 381 | return user_src_texts, user_trg_texts, test_src_texts, test_trg_texts, test_src_texts, test_trg_texts,\ 382 | src_vocabs, trg_vocabs 383 | 384 | 385 | # if __name__ == '__main__': 386 | # # load_cornell_movie_by_user(num_users=200, sample_user=False, user_data_ratio=0.5) 387 | # load_ubuntu_by_users() -------------------------------------------------------------------------------- /reddit_lm_ranks.py: 
-------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import sys 4 | from collections import defaultdict 5 | 6 | import keras.backend as K 7 | import numpy as np 8 | from nltk.tokenize import word_tokenize 9 | 10 | from data_loader.load_reddit import read_top_users, REDDIT_PROCESSED_PATH, build_vocab, read_top_user_comments, \ 11 | remove_puncs 12 | from data_loader.load_wiki import WIKI_TRAIN_PATH, load_wiki_lines 13 | from reddit_lm import words_to_indices, MODEL_PATH, RESULT_PATH, build_lm_model 14 | 15 | 16 | def load_cross_domain_shadow_user_data(train_users, num_users=100, num_data_per_user=100, num_words=5000): 17 | train_data = load_wiki_lines(WIKI_TRAIN_PATH, num_lines=2 * num_users * num_data_per_user) 18 | 19 | # l = len(train_data) 20 | # num_data_per_user = l // (num_users * 2) 21 | print "Splitting data to {} users, each has {} texts".format(num_users * 2, num_data_per_user) 22 | 23 | all_users = np.arange(num_users * 2) 24 | train_users = [int(u) for u in train_users] 25 | test_users = np.setdiff1d(all_users, train_users) 26 | print len(test_users) 27 | 28 | train_user_comments = defaultdict(list) 29 | test_user_comments = defaultdict(list) 30 | 31 | all_words = [] 32 | 33 | for u in train_users: 34 | data = train_data[u * num_data_per_user: (u + 1) * num_data_per_user] 35 | for words in data: 36 | all_words += words 37 | train_user_comments[str(u)] = data 38 | 39 | for u in test_users: 40 | data = train_data[u * num_data_per_user: (u + 1) * num_data_per_user] 41 | test_user_comments[str(u)] = data 42 | 43 | vocabs = build_vocab(all_words, num_words + 1) 44 | 45 | all_words = [] 46 | for user in train_user_comments: 47 | comments = train_user_comments[user] 48 | for i in range(len(comments)): 49 | comment = comments[i] 50 | for j in range(len(comment)): 51 | word = comment[j] 52 | if word not in vocabs: 53 | comment[j] = '' 54 | all_words += comment 55 | 56 | vocabs = build_vocab(all_words, None) 57 | 58 | for user in test_user_comments: 59 | comments = test_user_comments[user] 60 | for i in range(len(comments)): 61 | comment = comments[i] 62 | for j in range(len(comment)): 63 | word = comment[j] 64 | if word not in vocabs: 65 | comment[j] = '' 66 | 67 | return train_user_comments, test_user_comments, vocabs 68 | 69 | 70 | def load_shadow_user_data(train_users, num_users=100, num_words=10000): 71 | all_users = read_top_users(num_users * 4) 72 | attacker_users = all_users[num_users * 2: num_users * 4] 73 | test_users = np.setdiff1d(attacker_users, train_users) 74 | 75 | train_user_comments = defaultdict(list) 76 | test_user_comments = defaultdict(list) 77 | all_words = [] 78 | for user in train_users: 79 | filename = os.path.join(REDDIT_PROCESSED_PATH, user) 80 | with codecs.open(filename, encoding='utf-8') as f: 81 | for line in f: 82 | data = line.replace('/url/', '').replace('\n', '').split() + [''] 83 | if len(data) == 1: 84 | continue 85 | 86 | train_user_comments[user].append(data) 87 | all_words += data 88 | 89 | for user in test_users: 90 | filename = os.path.join(REDDIT_PROCESSED_PATH, user) 91 | with codecs.open(filename, encoding='utf-8') as f: 92 | for line in f: 93 | data = line.replace('/url/', '').replace('\n', '').split() + [''] 94 | if len(data) == 1: 95 | continue 96 | 97 | test_user_comments[user].append(data) 98 | 99 | vocabs = build_vocab(all_words, num_words) 100 | 101 | all_words = [] 102 | for user in train_user_comments: 103 | comments = train_user_comments[user] 104 | for i in 
range(len(comments)): 105 | comment = comments[i] 106 | for j in range(len(comment)): 107 | word = comment[j] 108 | if word not in vocabs: 109 | comment[j] = '' 110 | all_words += comment 111 | 112 | vocabs = build_vocab(all_words, None) 113 | 114 | for user in test_user_comments: 115 | comments = test_user_comments[user] 116 | for i in range(len(comments)): 117 | comment = comments[i] 118 | for j in range(len(comment)): 119 | word = comment[j] 120 | if word not in vocabs: 121 | comment[j] = '' 122 | 123 | return train_user_comments, test_user_comments, vocabs 124 | 125 | 126 | def group_texts_by_len(texts, bs=20): 127 | # Bucket samples by source sentence length 128 | buckets = defaultdict(list) 129 | batches = [] 130 | for t in texts: 131 | if len(t) < 2: 132 | continue 133 | buckets[len(t)].append(t) 134 | 135 | for l, bucket in buckets.items(): 136 | num_batches = int(np.ceil(len(bucket) * 1.0 / bs)) 137 | for i in range(num_batches): 138 | cur_batch_size = bs if i < num_batches - 1 else len(bucket) - bs * i 139 | batches.append([bucket[i * bs + j] for j in range(cur_batch_size)]) 140 | 141 | return batches 142 | 143 | 144 | def group_texts_by_maxlen(user_data, maxlen, bs): 145 | flatten_data = np.asarray([w for t in user_data for w in t]).astype(np.int32) 146 | if len(flatten_data) - 1 < maxlen: 147 | maxlen = len(flatten_data) - 1 148 | 149 | n_data = (len(flatten_data) - 1) // maxlen 150 | 151 | inputs_data = flatten_data[:-1][:n_data * maxlen].reshape(n_data, maxlen) 152 | targets_data = flatten_data[1:][:n_data * maxlen].reshape(n_data, maxlen) 153 | 154 | cumsum = 0 155 | lengthes = [0] 156 | for i, t in enumerate(user_data): 157 | if i == 0: 158 | cumsum += len(t) - 1 159 | else: 160 | cumsum += len(t) 161 | lengthes.append(cumsum) 162 | if cumsum >= n_data * maxlen: 163 | lengthes[-1] = n_data * maxlen 164 | break 165 | 166 | n_batches = n_data // bs + 1 167 | batches = [] 168 | for i in range(n_batches): 169 | if i * bs >= len(inputs_data): 170 | break 171 | inputs = inputs_data[i * bs: (i + 1) * bs] 172 | targets = targets_data[i * bs: (i + 1) * bs] 173 | batches.append((inputs, targets)) 174 | 175 | return batches, lengthes 176 | 177 | 178 | def get_bigram_probs(texts, pred_fn): 179 | bs, l = texts.shape 180 | for i in range(l): 181 | pass 182 | 183 | 184 | def get_ranks_labels_by_batch(user_data, pred_fn, maxlen=35, bs=20, save_probs=False, trg_prob_only=False): 185 | batches, lengthes = group_texts_by_maxlen(user_data, maxlen, bs) 186 | _ranks = [] 187 | _labels = [] 188 | _probs = [] 189 | 190 | for inputs, targets in batches: 191 | bs, l = targets.shape 192 | # print inputs.shape, targets.shape 193 | probs = pred_fn([inputs, targets, 0])[0] 194 | # all_ranks = np.argsort(-probs, axis=-1).argsort(axis=-1) # 20 x len x words 195 | # all_ranks = all_ranks.reshape(bs * l, -1) 196 | probs = probs.reshape(bs * l, -1) 197 | if save_probs: 198 | if trg_prob_only: 199 | probs = probs[np.arange(bs * l), targets.flatten().astype(int)] 200 | _probs.append(probs) 201 | else: 202 | all_ranks = rank_lists(-probs) 203 | targets_ranks = all_ranks[np.arange(bs * l), targets.flatten().astype(int)] 204 | targets_ranks = targets_ranks.reshape(bs, l) 205 | assert targets.shape == targets_ranks.shape 206 | # ranks += [r for r in targets_ranks] 207 | # labels += [t for t in targets] 208 | _ranks.append(targets_ranks.flatten()) 209 | _labels.append(targets.flatten()) 210 | 211 | if save_probs: 212 | all_probs = np.concatenate(_probs) if trg_prob_only else np.vstack(_probs) 213 | else: 214 | 
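# The batches above yield one flat stream of ranks and labels; the code below
# concatenates them and then cuts the stream back into one sequence per
# comment using the cumulative offsets stored in `lengthes`.  A rank of 0
# means the true next token was the model's top-1 prediction.
# Minimal illustration of the rank computation (illustrative values only):
#   probs = np.array([[0.1, 0.6, 0.3]])   # predicted distribution for one step
#   rank_lists(-probs)                    # -> [[2, 0, 1]]
#   # so if the true token id is 1, its rank is 0, i.e. the model's top choice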
all_ranks = np.concatenate(_ranks) 215 | all_labels = np.concatenate(_labels) 216 | assert lengthes[-1] == len(all_ranks) == len(all_labels) 217 | 218 | ranks = [] 219 | labels = [] 220 | probs = [] 221 | 222 | for b, e in zip(lengthes[:-1], lengthes[1:]): 223 | if save_probs: 224 | probs.append(all_probs[b: e]) 225 | else: 226 | ranks.append(all_ranks[b: e]) 227 | labels.append(all_labels[b: e]) 228 | 229 | if save_probs: 230 | return probs 231 | else: 232 | return ranks, labels 233 | 234 | 235 | def rank_lists(lists): 236 | # ranks = np.empty_like(lists) 237 | # for i, l in enumerate(lists): 238 | # ranks[i] = ss.rankdata(l, method='min') - 1 239 | 240 | temp = np.argsort(lists, axis=-1) 241 | ranks = np.empty_like(temp) 242 | ranks[np.arange(len(temp))[:, None], temp] = np.arange(temp.shape[1]) 243 | return ranks 244 | 245 | 246 | def save_users_rank_results(users, user_comments, vocabs, prob_fn, save_dir, member_label=1, 247 | cross_domain=False, save_probs=False, trg_prob_only=False, rerun=False): 248 | for i, u in enumerate(users): 249 | save_path = save_dir + 'rank_u{}_y{}{}.npz'.format(i, member_label, '_cd' if cross_domain else '') 250 | prob_path = save_dir + 'prob_u{}_y{}{}.npz'.format(i, member_label, '_cd' if cross_domain else '') 251 | 252 | if os.path.exists(save_path) and not save_probs and not rerun: 253 | continue 254 | 255 | user_data = words_to_indices(user_comments[u], vocabs) 256 | rtn = get_ranks_labels_by_batch(user_data, prob_fn, save_probs=save_probs, trg_prob_only=trg_prob_only) 257 | 258 | if save_probs: 259 | probs = rtn 260 | np.savez(prob_path, probs) 261 | else: 262 | ranks, labels = rtn[0], rtn[1] 263 | np.savez(save_path, ranks, labels) 264 | 265 | if (i + 1) % 500 == 0: 266 | sys.stderr.write('Finishing saving ranks for {} users\n'.format(i + 1)) 267 | 268 | 269 | def get_shadow_ranks(exp_id=0, num_users=100, num_words=5000, cross_domain=False, h=128, emb_h=256, rerun=False, 270 | rnn_fn='lstm'): 271 | shadow_user_path = 'shadow_users{}_{}_{}_{}.npz'.format(exp_id, rnn_fn, num_users, 'cd' if cross_domain else '') 272 | shadow_train_users = np.load(MODEL_PATH + shadow_user_path)['arr_0'] 273 | shadow_train_users = list(shadow_train_users) 274 | 275 | print len(shadow_train_users) 276 | print shadow_user_path 277 | 278 | save_dir = RESULT_PATH + 'shadow_exp{}_{}/'.format(exp_id, num_users) 279 | if not os.path.exists(save_dir): 280 | os.mkdir(save_dir) 281 | 282 | if cross_domain: 283 | train_user_comments, test_user_comments, vocabs =\ 284 | load_cross_domain_shadow_user_data(shadow_train_users, num_users=num_users, num_words=num_words) 285 | else: 286 | train_user_comments, test_user_comments, vocabs =\ 287 | load_shadow_user_data(shadow_train_users, num_users, num_words) 288 | shadow_test_users = sorted(test_user_comments.keys()) 289 | 290 | if cross_domain: 291 | model_path = 'wiki_lm_shadow_exp{}_{}_{}.h5'.format(exp_id, rnn_fn, num_users) 292 | else: 293 | model_path = 'reddit_lm_shadow_exp{}_{}_{}.h5'.format(exp_id, rnn_fn, num_users) 294 | 295 | model = build_lm_model(V=num_words, drop_p=0., h=h, emb_h=emb_h, rnn_fn=rnn_fn) 296 | model.load_weights(MODEL_PATH + model_path) 297 | 298 | input_var = K.placeholder((None, None)) 299 | prediction = model(input_var) 300 | label_var = K.placeholder((None, None), dtype='float32') 301 | prediction = K.softmax(prediction) 302 | prob_fn = K.function([input_var, label_var, K.learning_phase()], [prediction]) 303 | 304 | save_users_rank_results(shadow_train_users, train_user_comments, cross_domain=cross_domain, 
rerun=rerun, 305 | vocabs=vocabs, prob_fn=prob_fn, save_dir=save_dir, member_label=1) 306 | save_users_rank_results(shadow_test_users, test_user_comments, cross_domain=cross_domain, rerun=rerun, 307 | vocabs=vocabs, prob_fn=prob_fn, save_dir=save_dir, member_label=0) 308 | 309 | 310 | def get_target_ranks(num_users=100, num_words=5000, h=128, emb_h=256, rerun=False): 311 | users = read_top_users(num_users * 2) 312 | train_users = users[:num_users] 313 | test_users = users[num_users:] 314 | train_user_comments, vocabs = read_top_user_comments(num_users, num_words, top_users=train_users) 315 | test_user_comments, _ = read_top_user_comments(num_users, num_words, top_users=test_users, vocabs=vocabs) 316 | 317 | save_dir = RESULT_PATH + 'target_{}/'.format(num_users) 318 | if not os.path.exists(save_dir): 319 | os.mkdir(save_dir) 320 | 321 | model_path = 'reddit_lm_{}.h5'.format(num_users) 322 | model = build_lm_model(V=num_words, drop_p=0., h=h, emb_h=emb_h) 323 | model.load_weights(MODEL_PATH + model_path) 324 | 325 | input_var = K.placeholder((None, None)) 326 | prediction = model(input_var) 327 | label_var = K.placeholder((None, None), dtype='float32') 328 | prediction = K.softmax(prediction) 329 | prob_fn = K.function([input_var, label_var, K.learning_phase()], [prediction]) 330 | 331 | save_users_rank_results(train_users, train_user_comments, cross_domain=False, rerun=rerun, 332 | vocabs=vocabs, prob_fn=prob_fn, save_dir=save_dir, member_label=1) 333 | save_users_rank_results(test_users, test_user_comments, cross_domain=False, rerun=rerun, 334 | vocabs=vocabs, prob_fn=prob_fn, save_dir=save_dir, member_label=0) 335 | 336 | 337 | def read_translated_comments(users, vocabs, multi_step=False, trans='yandex'): 338 | user_comments = defaultdict(list) 339 | for user in users: 340 | filename = './translate/{}_{}_{}.txt'.format(user, trans, 'multi' if multi_step else 'two') 341 | with codecs.open(filename, encoding='utf-8') as f: 342 | for line in f: 343 | data = line.replace('\n', '') 344 | data = word_tokenize(data) 345 | data = remove_puncs(data) 346 | 347 | if len(data) == 1: 348 | print user, data 349 | continue 350 | 351 | user_comments[user].append(data + ['']) 352 | 353 | for user in user_comments: 354 | comments = user_comments[user] 355 | for i in range(len(comments)): 356 | comment = comments[i] 357 | for j in range(len(comment)): 358 | word = comment[j] 359 | if word not in vocabs: 360 | comment[j] = '' 361 | 362 | return user_comments 363 | 364 | 365 | def get_translated_rank(num_users=100, num_words=5000, h=128, emb_h=128): 366 | users = read_top_users(num_users * 2, min_count=0) 367 | train_users = users[:num_users] 368 | test_users = users[num_users:] 369 | 370 | _, vocabs = read_top_user_comments(num_users, num_words, top_users=train_users) 371 | train_user_comments = read_translated_comments(train_users, vocabs) 372 | test_user_comments = read_translated_comments(test_users, vocabs) 373 | 374 | save_dir = RESULT_PATH + 'target_yandex_{}/'.format(num_users) 375 | if not os.path.exists(save_dir): 376 | os.mkdir(save_dir) 377 | 378 | model_path = 'reddit_lm_{}.h5'.format(num_users) 379 | model = build_lm_model(V=num_words, drop_p=0., h=h, emb_h=emb_h) 380 | model.load_weights(MODEL_PATH + model_path) 381 | 382 | input_var = K.placeholder((None, None)) 383 | prediction = model(input_var) 384 | label_var = K.placeholder((None, None), dtype='float32') 385 | prediction = K.softmax(prediction) 386 | prob_fn = K.function([input_var, label_var, K.learning_phase()], [prediction]) 387 | 388 
| save_users_rank_results(train_users, train_user_comments, cross_domain=False, 389 | vocabs=vocabs, prob_fn=prob_fn, save_dir=save_dir, member_label=1) 390 | save_users_rank_results(test_users, test_user_comments, cross_domain=False, 391 | vocabs=vocabs, prob_fn=prob_fn, save_dir=save_dir, member_label=0) 392 | -------------------------------------------------------------------------------- /sated_nmt_ranks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from collections import Counter, defaultdict 4 | from itertools import chain 5 | 6 | import keras.backend as K 7 | import numpy as np 8 | import scipy.stats as ss 9 | from sklearn.metrics import roc_auc_score, accuracy_score, classification_report 10 | from sklearn.preprocessing import Normalizer, StandardScaler 11 | from sklearn.svm import SVC 12 | 13 | from helper import flatten_data 14 | from data_loader.load_sated import process_texts, process_vocabs, load_texts, load_users, load_sated_data_by_user, \ 15 | SATED_TRAIN_USER, SATED_TRAIN_FR, SATED_TRAIN_ENG, read_europarl_file, EUROPARL_FREN_FR, EUROPARL_FREN_EN 16 | from sated_nmt import build_nmt_model, words_to_indices, MODEL_PATH, OUTPUT_PATH 17 | 18 | 19 | def load_cross_domain_shadow_user_data(train_users, num_users=100, num_words=10000, num_data_per_user=150, seed=12345): 20 | src_texts = read_europarl_file(EUROPARL_FREN_EN, num_users * num_data_per_user * 2) 21 | trg_texts = read_europarl_file(EUROPARL_FREN_FR, num_users * num_data_per_user * 2) 22 | 23 | all_users = np.arange(num_users * 2) 24 | test_users = np.setdiff1d(all_users, train_users) 25 | 26 | user_src_texts = defaultdict(list) 27 | user_trg_texts = defaultdict(list) 28 | 29 | test_user_src_texts = defaultdict(list) 30 | test_user_trg_texts = defaultdict(list) 31 | 32 | for u in train_users: 33 | user_src_texts[u] = src_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 34 | user_trg_texts[u] = trg_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 35 | 36 | for u in test_users: 37 | test_user_src_texts[u] = src_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 38 | test_user_trg_texts[u] = trg_texts[u * num_data_per_user: (u + 1) * num_data_per_user] 39 | 40 | src_words = [] 41 | trg_words = [] 42 | for u in train_users: 43 | src_words += list(chain(*user_src_texts[u])) 44 | trg_words += list(chain(*user_trg_texts[u])) 45 | 46 | src_vocabs = process_vocabs(src_words, num_words) 47 | trg_vocabs = process_vocabs(trg_words, num_words) 48 | 49 | for u in train_users: 50 | process_texts(user_src_texts[u], src_vocabs) 51 | process_texts(user_trg_texts[u], trg_vocabs) 52 | 53 | for u in test_users: 54 | process_texts(test_user_src_texts[u], src_vocabs) 55 | process_texts(test_user_trg_texts[u], trg_vocabs) 56 | 57 | src_words = [] 58 | trg_words = [] 59 | 60 | for u in train_users: 61 | src_words += list(chain(*user_src_texts[u])) 62 | trg_words += list(chain(*user_trg_texts[u])) 63 | 64 | src_vocabs = process_vocabs(src_words, None) 65 | trg_vocabs = process_vocabs(trg_words, None) 66 | 67 | return user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs 68 | 69 | 70 | def load_train_users_heldout_data(train_users, src_vocabs, trg_vocabs, user_data_ratio=0.5): 71 | src_users = load_users(SATED_TRAIN_USER) 72 | train_src_texts = load_texts(SATED_TRAIN_ENG) 73 | train_trg_texts = load_texts(SATED_TRAIN_FR) 74 | 75 | user_src_texts = defaultdict(list) 76 | user_trg_texts = defaultdict(list) 77 | 
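# Regroup the SATED parallel corpus by user, keeping only users from the
# target model's training set.  The `[l:]` slice further down keeps the tail
# of each user's sentences -- presumably the portion left unused when the
# model was trained with the same `user_data_ratio` (cf. the `[:l]` slice in
# `load_cornell_movie_by_user`).  For example, with user_data_ratio=0.5 and
# 40 sentences for a user, sentences 20-39 are kept here.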
78 | for u, s, t in zip(src_users, train_src_texts, train_trg_texts): 79 | if u in train_users: 80 | user_src_texts[u].append(s) 81 | user_trg_texts[u].append(t) 82 | 83 | assert 0. < user_data_ratio < 1. 84 | # held out some fraction of data for testing 85 | for u in user_src_texts: 86 | l = len(user_src_texts[u]) 87 | l = int(l * user_data_ratio) 88 | user_src_texts[u] = user_src_texts[u][l:] 89 | user_trg_texts[u] = user_trg_texts[u][l:] 90 | 91 | for u in train_users: 92 | process_texts(user_src_texts[u], src_vocabs) 93 | process_texts(user_trg_texts[u], trg_vocabs) 94 | 95 | return user_src_texts, user_trg_texts 96 | 97 | 98 | def load_shadow_user_data(train_users, num_users=100, num_words=10000, seed=12345): 99 | src_users = load_users(SATED_TRAIN_USER) 100 | train_src_texts = load_texts(SATED_TRAIN_ENG) 101 | train_trg_texts = load_texts(SATED_TRAIN_FR) 102 | 103 | user_counter = Counter(src_users) 104 | all_users = [tup[0] for tup in user_counter.most_common()] 105 | np.random.seed(seed) 106 | np.random.shuffle(all_users) 107 | np.random.seed(None) 108 | 109 | attacker_users = all_users[num_users * 2: num_users * 4] 110 | test_users = np.setdiff1d(attacker_users, train_users) 111 | print len(train_users), len(test_users) 112 | 113 | user_src_texts = defaultdict(list) 114 | user_trg_texts = defaultdict(list) 115 | 116 | test_user_src_texts = defaultdict(list) 117 | test_user_trg_texts = defaultdict(list) 118 | 119 | for u, s, t in zip(src_users, train_src_texts, train_trg_texts): 120 | if u in train_users: 121 | user_src_texts[u].append(s) 122 | user_trg_texts[u].append(t) 123 | if u in test_users: 124 | test_user_src_texts[u].append(s) 125 | test_user_trg_texts[u].append(t) 126 | 127 | src_words = [] 128 | trg_words = [] 129 | for u in train_users: 130 | src_words += list(chain(*user_src_texts[u])) 131 | trg_words += list(chain(*user_trg_texts[u])) 132 | 133 | src_vocabs = process_vocabs(src_words, num_words) 134 | trg_vocabs = process_vocabs(trg_words, num_words) 135 | 136 | for u in train_users: 137 | process_texts(user_src_texts[u], src_vocabs) 138 | process_texts(user_trg_texts[u], trg_vocabs) 139 | 140 | for u in test_users: 141 | process_texts(test_user_src_texts[u], src_vocabs) 142 | process_texts(test_user_trg_texts[u], trg_vocabs) 143 | 144 | src_words = [] 145 | trg_words = [] 146 | 147 | for u in train_users: 148 | src_words += list(chain(*user_src_texts[u])) 149 | trg_words += list(chain(*user_trg_texts[u])) 150 | 151 | src_vocabs = process_vocabs(src_words, None) 152 | trg_vocabs = process_vocabs(trg_words, None) 153 | 154 | return user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs 155 | 156 | 157 | def rank_lists(lists): 158 | ranks = np.empty_like(lists) 159 | for i, l in enumerate(lists): 160 | ranks[i] = ss.rankdata(l, method='min') - 1 161 | return ranks 162 | 163 | 164 | def get_ranks(user_src_data, user_trg_data, pred_fn, save_probs=False): 165 | indices = np.arange(len(user_src_data)) 166 | 167 | ranks = [] 168 | labels = [] 169 | probs = [] 170 | for idx in indices: 171 | src_text = np.asarray(user_src_data[idx], dtype=np.float32).reshape(1, -1) 172 | trg_text = np.asarray(user_trg_data[idx], dtype=np.float32) 173 | trg_input = trg_text[:-1].reshape(1, -1) 174 | trg_label = trg_text[1:].reshape(1, -1) 175 | 176 | prob = pred_fn([src_text, trg_input, trg_label, 0])[0][0] 177 | if save_probs: 178 | probs.append(prob) 179 | 180 | # all_ranks = np.argsort(-prob, axis=-1).argsort(axis=-1) 181 | 182 | all_ranks = 
rank_lists(-prob) 183 | sent_ranks = all_ranks[np.arange(len(all_ranks)), trg_label.flatten().astype(int)] 184 | 185 | ranks.append(sent_ranks) 186 | labels.append(trg_label.flatten()) 187 | 188 | if save_probs: 189 | return probs 190 | 191 | return ranks, labels 192 | 193 | 194 | def save_users_rank_results(users, user_src_texts, user_trg_texts, src_vocabs, trg_vocabs, prob_fn, save_dir, 195 | member_label=1, cross_domain=False, save_probs=False, mask=False, rerun=False): 196 | for i, u in enumerate(users): 197 | save_path = save_dir + 'rank_u{}_y{}{}.npz'.format(i, member_label, '_cd' if cross_domain else '') 198 | prob_path = save_dir + 'prob_u{}_y{}{}.npz'.format(i, member_label, '_cd' if cross_domain else '') 199 | 200 | if os.path.exists(save_path) and not save_probs and not rerun: 201 | continue 202 | 203 | user_src_data = words_to_indices(user_src_texts[u], src_vocabs, mask=mask) 204 | user_trg_data = words_to_indices(user_trg_texts[u], trg_vocabs, mask=mask) 205 | 206 | rtn = get_ranks(user_src_data, user_trg_data, prob_fn, save_probs=save_probs) 207 | 208 | if save_probs: 209 | probs = rtn 210 | np.savez(prob_path, probs) 211 | else: 212 | ranks, labels = rtn[0], rtn[1] 213 | np.savez(save_path, ranks, labels) 214 | 215 | if (i + 1) % 500 == 0: 216 | sys.stderr.write('Finishing saving ranks for {} users'.format(i + 1)) 217 | 218 | 219 | def histogram_feats(ranks, bins=100, num_words=5000): 220 | feats, _ = np.histogram(ranks, bins=bins, normed=False, range=(0, num_words)) 221 | return feats 222 | 223 | 224 | def get_shadow_ranks(exp_id=0, num_users=200, num_words=5000, mask=False, h=128, emb_h=128, save_probs=False, 225 | tied=False, cross_domain=False, rnn_fn='lstm', rerun=False): 226 | shadow_user_path = 'shadow_users{}_{}_{}_{}.npz'.format(exp_id, rnn_fn, num_users, 'cd' if cross_domain else '') 227 | shadow_train_users = np.load(MODEL_PATH + shadow_user_path)['arr_0'] 228 | shadow_train_users = list(shadow_train_users) 229 | 230 | print shadow_user_path 231 | 232 | save_dir = OUTPUT_PATH + 'shadow_exp{}_{}/'.format(exp_id, num_users) 233 | if not os.path.exists(save_dir): 234 | os.mkdir(save_dir) 235 | 236 | if cross_domain: 237 | user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \ 238 | = load_cross_domain_shadow_user_data(shadow_train_users, num_users, num_words) 239 | else: 240 | user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \ 241 | = load_shadow_user_data(shadow_train_users, num_users, num_words) 242 | shadow_test_users = sorted(test_user_src_texts.keys()) 243 | 244 | model_path = '{}_shadow_exp{}_{}_{}.h5'.format('europal_nmt' if cross_domain else 'sated_nmt', 245 | exp_id, rnn_fn, num_users) 246 | 247 | model = build_nmt_model(Vs=num_words, Vt=num_words, mask=mask, drop_p=0., h=h, demb=emb_h, tied=tied, rnn_fn=rnn_fn) 248 | model.load_weights(MODEL_PATH + model_path) 249 | 250 | src_input_var, trg_input_var = model.inputs 251 | prediction = model.output 252 | trg_label_var = K.placeholder((None, None), dtype='float32') 253 | 254 | prediction = K.softmax(prediction) 255 | prob_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [prediction]) 256 | 257 | save_users_rank_results(users=shadow_train_users, save_probs=save_probs, rerun=rerun, mask=mask, 258 | user_src_texts=user_src_texts, user_trg_texts=user_trg_texts, 259 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=cross_domain, 260 | prob_fn=prob_fn, save_dir=save_dir, 
member_label=1) 261 | save_users_rank_results(users=shadow_test_users, save_probs=save_probs, rerun=rerun, mask=mask, 262 | user_src_texts=test_user_src_texts, user_trg_texts=test_user_trg_texts, 263 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=cross_domain, 264 | prob_fn=prob_fn, save_dir=save_dir, member_label=0) 265 | 266 | 267 | def get_target_ranks(num_users=200, num_words=5000, mask=False, h=128, emb_h=128, user_data_ratio=0., 268 | tied=False, save_probs=False): 269 | user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \ 270 | = load_sated_data_by_user(num_users, num_words, test_on_user=True, user_data_ratio=user_data_ratio) 271 | 272 | train_users = sorted(user_src_texts.keys()) 273 | test_users = sorted(test_user_src_texts.keys()) 274 | 275 | save_dir = OUTPUT_PATH + 'target_{}{}/'.format(num_users, '_dr' if 0. < user_data_ratio < 1. else '') 276 | if not os.path.exists(save_dir): 277 | os.mkdir(save_dir) 278 | 279 | model_path = 'sated_nmt'.format(num_users) 280 | 281 | if 0. < user_data_ratio < 1.: 282 | model_path += '_dr{}'.format(user_data_ratio) 283 | heldout_src_texts, heldout_trg_texts = load_train_users_heldout_data(train_users, src_vocabs, trg_vocabs) 284 | for u in train_users: 285 | user_src_texts[u] += heldout_src_texts[u] 286 | user_trg_texts[u] += heldout_trg_texts[u] 287 | 288 | model = build_nmt_model(Vs=num_words, Vt=num_words, mask=mask, drop_p=0., h=h, demb=emb_h, tied=tied) 289 | model.load_weights(MODEL_PATH + '{}_{}.h5'.format(model_path, num_users)) 290 | 291 | src_input_var, trg_input_var = model.inputs 292 | prediction = model.output 293 | trg_label_var = K.placeholder((None, None), dtype='float32') 294 | 295 | prediction = K.softmax(prediction) 296 | prob_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [prediction]) 297 | 298 | save_users_rank_results(users=train_users, save_probs=save_probs, 299 | user_src_texts=user_src_texts, user_trg_texts=user_trg_texts, 300 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=False, 301 | prob_fn=prob_fn, save_dir=save_dir, member_label=1) 302 | save_users_rank_results(users=test_users, save_probs=save_probs, 303 | user_src_texts=test_user_src_texts, user_trg_texts=test_user_trg_texts, 304 | src_vocabs=src_vocabs, trg_vocabs=trg_vocabs, cross_domain=False, 305 | prob_fn=prob_fn, save_dir=save_dir, member_label=0) 306 | 307 | 308 | def ranks_to_feats(ranks, prop=1.0, dim=100, num_words=5000, shuffle=True): 309 | X = [] 310 | i = 0 311 | for user_ranks in ranks: 312 | indices = np.arange(len(user_ranks)) 313 | if shuffle: 314 | np.random.shuffle(indices) 315 | n = int(len(indices) * prop) 316 | r = [] 317 | for idx in indices[:n]: 318 | r.append(user_ranks[idx]) 319 | r = np.concatenate(r) 320 | # print i, np.average(r) 321 | feats = histogram_feats(r, bins=dim, num_words=num_words) 322 | X.append(feats) 323 | i += 1 324 | # quit() 325 | return np.vstack(X) 326 | 327 | 328 | def user_mi_attack(num_exp=10, dim=100, prop=1.0, num_words=5000, cross_domain=True): 329 | f = np.load(OUTPUT_PATH + 'target_user_ranks.npz') 330 | X_test = ranks_to_feats(f['arr_0'], prop=prop, dim=dim, num_words=num_words) 331 | y_test = f['arr_1'] 332 | 333 | X = [] 334 | y = [] 335 | for exp_id in range(num_exp): 336 | f = np.load(OUTPUT_PATH + 'shadow_user_ranks_{}{}.npz'.format(exp_id, '_cd' if cross_domain else '')) 337 | feats = ranks_to_feats(f['arr_0'], prop=prop, dim=dim, num_words=num_words) 338 | X.append(feats) 339 | 
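# `arr_1` is used as the per-user membership label (1 = member, 0 = held-out,
# following the member_label convention used when ranks are saved), aligned
# with the rank histograms appended to X above.  Each user is summarised by
# `histogram_feats` as a histogram of their token ranks; shadow-model users,
# whose labels are known, form the attack classifier's training set, while the
# target model's users (X_test / y_test above) are the evaluation set.  After
# L1 normalisation and standard scaling an SVC is fit and reported with
# accuracy and ROC AUC.
# e.g.: user_mi_attack(num_exp=10, dim=100, num_words=5000, cross_domain=False)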
y.append(f['arr_1']) 340 | 341 | X_train = np.vstack(X) 342 | y_train = np.concatenate(y) 343 | 344 | print X_train.shape, y_train.shape 345 | normalizer = Normalizer(norm='l1') 346 | X_train = normalizer.fit_transform(X_train) 347 | X_test = normalizer.fit_transform(X_test) 348 | 349 | scaler = StandardScaler() 350 | X_train = scaler.fit_transform(X_train) 351 | X_test = scaler.transform(X_test) 352 | 353 | # clf = RandomForestClassifier(n_estimators=20) 354 | clf = SVC() 355 | clf.fit(X_train, y_train) 356 | 357 | y_score = clf.decision_function(X_test) # [:, 1] 358 | y_pred = clf.predict(X_test) 359 | 360 | print classification_report(y_pred=y_pred, y_true=y_test) 361 | print 'ACC:', accuracy_score(y_test, y_pred) 362 | print 'AUC:', roc_auc_score(y_test, y_score) 363 | 364 | 365 | def test_vocab(): 366 | user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \ 367 | = load_sated_data_by_user(300, 5000, test_on_user=True, user_data_ratio=0.) 368 | train_data = [] 369 | test_data = [] 370 | 371 | for user in user_trg_texts: 372 | train_data += user_trg_texts[user] 373 | train_data = words_to_indices(train_data, trg_vocabs) 374 | train_data = flatten_data(train_data) 375 | 376 | for user in test_user_trg_texts: 377 | test_data += test_user_trg_texts[user] 378 | test_data = words_to_indices(test_data, trg_vocabs) 379 | test_data = flatten_data(test_data) 380 | 381 | n = float(len(train_data)) 382 | b = np.sum(train_data >= 1000) / n 383 | print 1 - b, b, n 384 | 385 | n = float(len(test_data)) 386 | b = np.sum(test_data >= 1000) / n 387 | print 1 - b, b, n 388 | 389 | 390 | if __name__ == '__main__': 391 | get_target_ranks(num_users=300, save_probs=False) 392 | for i in range(10): 393 | get_shadow_ranks(exp_id=i, num_users=300, cross_domain=False, rerun=True) 394 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras.backend as K 3 | from keras.layers import Layer 4 | from keras.legacy import interfaces 5 | from keras.engine import InputSpec 6 | from keras import activations, initializers, regularizers, constraints 7 | from keras.layers.recurrent import RNN 8 | 9 | from collections import namedtuple 10 | 11 | 12 | def words_to_indices(data, vocab): 13 | return [[vocab[w] for w in t] for t in data] 14 | 15 | 16 | def iterate_minibatches(inputs, targets, batchsize, shuffle=False): 17 | assert len(inputs) == len(targets) 18 | if shuffle: 19 | indices = np.arange(len(inputs)) 20 | np.random.shuffle(indices) 21 | 22 | for start_idx in range(0, len(inputs) - batchsize + 1, batchsize): 23 | if shuffle: 24 | excerpt = indices[start_idx:start_idx + batchsize] 25 | else: 26 | excerpt = slice(start_idx, start_idx + batchsize) 27 | 28 | yield inputs[excerpt], targets[excerpt] 29 | 30 | 31 | def flatten_data(data): 32 | return np.asarray([w for t in data for w in t]).astype(np.int32) 33 | 34 | 35 | class _CuDNNRNN(RNN): 36 | def __init__(self, 37 | return_sequences=False, 38 | return_state=False, 39 | go_backwards=False, 40 | stateful=False, 41 | **kwargs): 42 | if K.backend() != 'tensorflow': 43 | raise RuntimeError('CuDNN RNNs are only available ' 44 | 'with the TensorFlow backend.') 45 | super(RNN, self).__init__(**kwargs) 46 | self.return_sequences = return_sequences 47 | self.return_state = return_state 48 | self.go_backwards = go_backwards 49 | self.stateful = stateful 50 | self.supports_masking = 
False 51 | self.input_spec = [InputSpec(ndim=3)] 52 | if hasattr(self.cell.state_size, '__len__'): 53 | state_size = self.cell.state_size 54 | else: 55 | state_size = [self.cell.state_size] 56 | self.state_spec = [InputSpec(shape=(None, dim)) 57 | for dim in state_size] 58 | self.constants_spec = None 59 | self._states = None 60 | self._num_constants = None 61 | 62 | def _canonical_to_params(self, weights, biases): 63 | import tensorflow as tf 64 | weights = [tf.reshape(x, (-1,)) for x in weights] 65 | biases = [tf.reshape(x, (-1,)) for x in biases] 66 | return tf.concat(weights + biases, 0) 67 | 68 | def call(self, inputs, mask=None, training=None, initial_state=None): 69 | if isinstance(mask, list): 70 | mask = mask[0] 71 | if mask is not None: 72 | raise ValueError('Masking is not supported for CuDNN RNNs.') 73 | 74 | # input shape: `(samples, time (padded with zeros), input_dim)` 75 | # note that the .build() method of subclasses MUST define 76 | # self.input_spec and self.state_spec with complete input shapes. 77 | if isinstance(inputs, list): 78 | initial_state = inputs[1:] 79 | inputs = inputs[0] 80 | elif initial_state is not None: 81 | pass 82 | elif self.stateful: 83 | initial_state = self.states 84 | else: 85 | initial_state = self.get_initial_state(inputs) 86 | 87 | if len(initial_state) != len(self.states): 88 | raise ValueError('Layer has ' + str(len(self.states)) + 89 | ' states but was passed ' + 90 | str(len(initial_state)) + 91 | ' initial states.') 92 | 93 | if self.go_backwards: 94 | # Reverse time axis. 95 | inputs = K.reverse(inputs, 1) 96 | output, states = self._process_batch(inputs, initial_state, training) 97 | 98 | if self.stateful: 99 | updates = [] 100 | for i in range(len(states)): 101 | updates.append((self.states[i], states[i])) 102 | self.add_update(updates, inputs) 103 | 104 | if self.return_state: 105 | return [output] + states 106 | else: 107 | return output 108 | 109 | def get_config(self): 110 | config = {'return_sequences': self.return_sequences, 111 | 'return_state': self.return_state, 112 | 'go_backwards': self.go_backwards, 113 | 'stateful': self.stateful} 114 | base_config = super(RNN, self).get_config() 115 | return dict(list(base_config.items()) + list(config.items())) 116 | 117 | @classmethod 118 | def from_config(cls, config): 119 | return cls(**config) 120 | 121 | @property 122 | def trainable_weights(self): 123 | if self.trainable and self.built: 124 | return [self.kernel, self.recurrent_kernel, self.bias] 125 | return [] 126 | 127 | @property 128 | def non_trainable_weights(self): 129 | if not self.trainable and self.built: 130 | return [self.kernel, self.recurrent_kernel, self.bias] 131 | return [] 132 | 133 | @property 134 | def losses(self): 135 | return super(RNN, self).losses 136 | 137 | def get_losses_for(self, inputs=None): 138 | return super(RNN, self).get_losses_for(inputs=inputs) 139 | 140 | 141 | class CuDNNLSTM(_CuDNNRNN): 142 | def __init__(self, units, 143 | kernel_initializer='glorot_uniform', 144 | recurrent_initializer='orthogonal', 145 | bias_initializer='zeros', 146 | unit_forget_bias=True, 147 | kernel_regularizer=None, 148 | recurrent_regularizer=None, 149 | bias_regularizer=None, 150 | activity_regularizer=None, 151 | kernel_constraint=None, 152 | recurrent_constraint=None, 153 | bias_constraint=None, 154 | return_sequences=False, 155 | return_state=False, 156 | stateful=False, 157 | dropout=0., 158 | **kwargs): 159 | self.units = units 160 | super(CuDNNLSTM, self).__init__( 161 | return_sequences=return_sequences, 162 
| return_state=return_state, 163 | stateful=stateful, 164 | **kwargs) 165 | 166 | self.kernel_initializer = initializers.get(kernel_initializer) 167 | self.recurrent_initializer = initializers.get(recurrent_initializer) 168 | self.bias_initializer = initializers.get(bias_initializer) 169 | self.unit_forget_bias = unit_forget_bias 170 | 171 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 172 | self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 173 | self.bias_regularizer = regularizers.get(bias_regularizer) 174 | self.activity_regularizer = regularizers.get(activity_regularizer) 175 | 176 | self.kernel_constraint = constraints.get(kernel_constraint) 177 | self.recurrent_constraint = constraints.get(recurrent_constraint) 178 | self.bias_constraint = constraints.get(bias_constraint) 179 | self.dropout = dropout 180 | 181 | @property 182 | def cell(self): 183 | Cell = namedtuple('cell', 'state_size') 184 | cell = Cell(state_size=(self.units, self.units)) 185 | return cell 186 | 187 | def build(self, input_shape): 188 | super(CuDNNLSTM, self).build(input_shape) 189 | if isinstance(input_shape, list): 190 | input_shape = input_shape[0] 191 | input_dim = input_shape[-1] 192 | 193 | from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops 194 | self._cudnn_lstm = cudnn_rnn_ops.CudnnLSTM( 195 | num_layers=1, 196 | num_units=self.units, 197 | input_size=input_dim, 198 | input_mode='linear_input', 199 | dropout=self.dropout, 200 | ) 201 | 202 | self.kernel = self.add_weight(shape=(input_dim, self.units * 4), 203 | name='kernel', 204 | initializer=self.kernel_initializer, 205 | regularizer=self.kernel_regularizer, 206 | constraint=self.kernel_constraint) 207 | self.recurrent_kernel = self.add_weight( 208 | shape=(self.units, self.units * 4), 209 | name='recurrent_kernel', 210 | initializer=self.recurrent_initializer, 211 | regularizer=self.recurrent_regularizer, 212 | constraint=self.recurrent_constraint) 213 | 214 | if self.unit_forget_bias: 215 | def bias_initializer(shape, *args, **kwargs): 216 | return K.concatenate([ 217 | self.bias_initializer((self.units * 5,), *args, **kwargs), 218 | initializers.Ones()((self.units,), *args, **kwargs), 219 | self.bias_initializer((self.units * 2,), *args, **kwargs), 220 | ]) 221 | else: 222 | bias_initializer = self.bias_initializer 223 | self.bias = self.add_weight(shape=(self.units * 8,), 224 | name='bias', 225 | initializer=bias_initializer, 226 | regularizer=self.bias_regularizer, 227 | constraint=self.bias_constraint) 228 | 229 | self.kernel_i = self.kernel[:, :self.units] 230 | self.kernel_f = self.kernel[:, self.units: self.units * 2] 231 | self.kernel_c = self.kernel[:, self.units * 2: self.units * 3] 232 | self.kernel_o = self.kernel[:, self.units * 3:] 233 | 234 | self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units] 235 | self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: self.units * 2] 236 | self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: self.units * 3] 237 | self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:] 238 | 239 | self.bias_i_i = self.bias[:self.units] 240 | self.bias_f_i = self.bias[self.units: self.units * 2] 241 | self.bias_c_i = self.bias[self.units * 2: self.units * 3] 242 | self.bias_o_i = self.bias[self.units * 3: self.units * 4] 243 | self.bias_i = self.bias[self.units * 4: self.units * 5] 244 | self.bias_f = self.bias[self.units * 5: self.units * 6] 245 | self.bias_c = self.bias[self.units * 6: self.units * 7] 246 | self.bias_o = 
self.bias[self.units * 7:] 247 | 248 | self.built = True 249 | 250 | def _process_batch(self, inputs, initial_state, training): 251 | if training is None: 252 | training = K.learning_phase() 253 | 254 | import tensorflow as tf 255 | inputs = tf.transpose(inputs, (1, 0, 2)) 256 | input_h = initial_state[0] 257 | input_c = initial_state[1] 258 | input_h = tf.expand_dims(input_h, axis=0) 259 | input_c = tf.expand_dims(input_c, axis=0) 260 | 261 | params = self._canonical_to_params( 262 | weights=[ 263 | self.kernel_i, 264 | self.kernel_f, 265 | self.kernel_c, 266 | self.kernel_o, 267 | self.recurrent_kernel_i, 268 | self.recurrent_kernel_f, 269 | self.recurrent_kernel_c, 270 | self.recurrent_kernel_o, 271 | ], 272 | biases=[ 273 | self.bias_i_i, 274 | self.bias_f_i, 275 | self.bias_c_i, 276 | self.bias_o_i, 277 | self.bias_i, 278 | self.bias_f, 279 | self.bias_c, 280 | self.bias_o, 281 | ], 282 | ) 283 | outputs, h, c = self._cudnn_lstm( 284 | inputs, 285 | input_h=input_h, 286 | input_c=input_c, 287 | params=params, 288 | is_training=training) 289 | 290 | if self.stateful or self.return_state: 291 | h = h[0] 292 | c = c[0] 293 | if self.return_sequences: 294 | output = tf.transpose(outputs, (1, 0, 2)) 295 | else: 296 | output = outputs[-1] 297 | return output, [h, c] 298 | 299 | def get_config(self): 300 | config = { 301 | 'units': self.units, 302 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 303 | 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 304 | 'bias_initializer': initializers.serialize(self.bias_initializer), 305 | 'unit_forget_bias': self.unit_forget_bias, 306 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 307 | 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 308 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 309 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 310 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 311 | 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 312 | 'bias_constraint': constraints.serialize(self.bias_constraint)} 313 | base_config = super(CuDNNLSTM, self).get_config() 314 | return dict(list(base_config.items()) + list(config.items())) 315 | 316 | 317 | class Attention(Layer): 318 | def __init__(self, units, 319 | activation='linear', 320 | use_bias=True, 321 | kernel_initializer='glorot_uniform', 322 | bias_initializer='zeros', 323 | kernel_regularizer=None, 324 | bias_regularizer=None, 325 | activity_regularizer=None, 326 | kernel_constraint=None, 327 | bias_constraint=None, 328 | **kwargs): 329 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 330 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 331 | super(Attention, self).__init__(**kwargs) 332 | self.units = units 333 | self.activation = activations.get(activation) 334 | self.use_bias = use_bias 335 | self.kernel_initializer = initializers.get(kernel_initializer) 336 | self.bias_initializer = initializers.get(bias_initializer) 337 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 338 | self.bias_regularizer = regularizers.get(bias_regularizer) 339 | self.activity_regularizer = regularizers.get(activity_regularizer) 340 | self.kernel_constraint = constraints.get(kernel_constraint) 341 | self.bias_constraint = constraints.get(bias_constraint) 342 | # self.input_spec = InputSpec(min_ndim=2) 343 | self.supports_masking = True 344 | 345 | def build(self, input_shape): 346 | if 
not isinstance(input_shape, list) or len(input_shape) != 2: 347 | raise ValueError('An attention layer should be called ' 348 | 'on a list of 2 inputs.') 349 | enc_dim = input_shape[0][-1] 350 | dec_dim = input_shape[1][-1] 351 | 352 | self.W_enc = self.add_weight(shape=(enc_dim, self.units), 353 | initializer=self.kernel_initializer, 354 | name='W_enc', 355 | regularizer=self.kernel_regularizer, 356 | constraint=self.kernel_constraint) 357 | 358 | self.W_dec = self.add_weight(shape=(dec_dim, self.units), 359 | initializer=self.kernel_initializer, 360 | name='W_dec', 361 | regularizer=self.kernel_regularizer, 362 | constraint=self.kernel_constraint) 363 | 364 | self.W_score = self.add_weight(shape=(self.units, 1), 365 | initializer=self.kernel_initializer, 366 | name='W_score', 367 | regularizer=self.kernel_regularizer, 368 | constraint=self.kernel_constraint) 369 | 370 | if self.use_bias: 371 | self.bias_enc = self.add_weight(shape=(self.units,), 372 | initializer=self.bias_initializer, 373 | name='bias_enc', 374 | regularizer=self.bias_regularizer, 375 | constraint=self.bias_constraint) 376 | self.bias_dec = self.add_weight(shape=(self.units,), 377 | initializer=self.bias_initializer, 378 | name='bias_dec', 379 | regularizer=self.bias_regularizer, 380 | constraint=self.bias_constraint) 381 | self.bias_score = self.add_weight(shape=(1,), 382 | initializer=self.bias_initializer, 383 | name='bias_score', 384 | regularizer=self.bias_regularizer, 385 | constraint=self.bias_constraint) 386 | 387 | else: 388 | self.bias_enc = None 389 | self.bias_dec = None 390 | self.bias_score = None 391 | 392 | # self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim}) 393 | self.built = True 394 | 395 | def call(self, inputs, **kwargs): 396 | if not isinstance(inputs, list) or len(inputs) != 2: 397 | raise ValueError('An attention layer should be called ' 398 | 'on a list of 2 inputs.') 399 | encodings, decodings = inputs 400 | d_enc = K.dot(encodings, self.W_enc) 401 | d_dec = K.dot(decodings, self.W_dec) 402 | 403 | if self.use_bias: 404 | d_enc = K.bias_add(d_enc, self.bias_enc) 405 | d_dec = K.bias_add(d_dec, self.bias_dec) 406 | 407 | if self.activation is not None: 408 | d_enc = self.activation(d_enc) 409 | d_dec = self.activation(d_dec) 410 | 411 | enc_seqlen = K.shape(d_enc)[1] 412 | d_dec_shape = K.shape(d_dec) 413 | 414 | stacked_d_dec = K.tile(d_dec, [enc_seqlen, 1, 1]) # enc time x batch x dec time x da 415 | stacked_d_dec = K.reshape(stacked_d_dec, [enc_seqlen, d_dec_shape[0], d_dec_shape[1], d_dec_shape[2]]) 416 | stacked_d_dec = K.permute_dimensions(stacked_d_dec, [2, 1, 0, 3]) # dec time x batch x enc time x da 417 | tanh_add = K.tanh(stacked_d_dec + d_enc) # dec time x batch x enc time x da 418 | scores = K.dot(tanh_add, self.W_score) 419 | if self.use_bias: 420 | scores = K.bias_add(scores, self.bias_score) 421 | scores = K.squeeze(scores, 3) # batch x dec time x enc time 422 | 423 | weights = K.softmax(scores) # dec time x batch x enc time 424 | weights = K.expand_dims(weights) 425 | 426 | weighted_encodings = weights * encodings # dec time x batch x enc time x h 427 | contexts = K.sum(weighted_encodings, axis=2) # dec time x batch x h 428 | contexts = K.permute_dimensions(contexts, [1, 0, 2]) # batch x dec time x h 429 | 430 | return contexts 431 | 432 | def compute_output_shape(self, input_shape): 433 | assert isinstance(input_shape, list) and len(input_shape) == 2 434 | assert input_shape[-1] 435 | output_shape = list(input_shape[1]) 436 | output_shape[-1] = self.units 437 | 
return tuple(output_shape) 438 | 439 | def get_config(self): 440 | config = { 441 | 'units': self.units, 442 | 'activation': activations.serialize(self.activation), 443 | 'use_bias': self.use_bias, 444 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 445 | 'bias_initializer': initializers.serialize(self.bias_initializer), 446 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 447 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 448 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 449 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 450 | 'bias_constraint': constraints.serialize(self.bias_constraint) 451 | } 452 | base_config = super(Attention, self).get_config() 453 | return dict(list(base_config.items()) + list(config.items())) 454 | 455 | 456 | class DenseTransposeTied(Layer): 457 | @interfaces.legacy_dense_support 458 | def __init__(self, units, 459 | tied_to=None, # Enter a layer as input to enforce weight-tying 460 | activation=None, 461 | use_bias=True, 462 | kernel_initializer='glorot_uniform', 463 | bias_initializer='zeros', 464 | kernel_regularizer=None, 465 | bias_regularizer=None, 466 | activity_regularizer=None, 467 | kernel_constraint=None, 468 | bias_constraint=None, 469 | **kwargs): 470 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 471 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 472 | super(DenseTransposeTied, self).__init__(**kwargs) 473 | self.units = units 474 | # We add these two properties to save the tied weights 475 | self.tied_to = tied_to 476 | self.tied_weights = self.tied_to.weights 477 | self.activation = activations.get(activation) 478 | self.use_bias = use_bias 479 | self.kernel_initializer = initializers.get(kernel_initializer) 480 | self.bias_initializer = initializers.get(bias_initializer) 481 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 482 | self.bias_regularizer = regularizers.get(bias_regularizer) 483 | self.activity_regularizer = regularizers.get(activity_regularizer) 484 | self.kernel_constraint = constraints.get(kernel_constraint) 485 | self.bias_constraint = constraints.get(bias_constraint) 486 | self.input_spec = InputSpec(min_ndim=2) 487 | self.supports_masking = True 488 | 489 | def build(self, input_shape): 490 | assert len(input_shape) >= 2 491 | input_dim = input_shape[-1] 492 | 493 | # We remove the weights and bias because we do not want them to be trainable 494 | self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim}) 495 | if self.use_bias: 496 | self.bias = self.add_weight(shape=(self.units,), 497 | initializer=self.bias_initializer, 498 | name='bias', 499 | regularizer=self.bias_regularizer, 500 | constraint=self.bias_constraint) 501 | else: 502 | self.bias = None 503 | self.built = True 504 | 505 | def call(self, inputs, **kwargs): 506 | # Return the transpose layer mapping using the explicit weight matrices 507 | output = K.dot(inputs, K.transpose(self.tied_weights[0])) 508 | if self.use_bias: 509 | output = K.bias_add(output, self.bias, data_format='channels_last') 510 | 511 | if self.activation is not None: 512 | output = self.activation(output) 513 | 514 | return output 515 | 516 | def compute_output_shape(self, input_shape): 517 | assert input_shape and len(input_shape) >= 2 518 | assert input_shape[-1] 519 | output_shape = list(input_shape) 520 | output_shape[-1] = self.units 521 | return tuple(output_shape) 522 | 523 | def get_config(self): 524 | config = { 525 | 
'units': self.units, 526 | 'activation': activations.serialize(self.activation), 527 | 'use_bias': self.use_bias, 528 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 529 | 'bias_initializer': initializers.serialize(self.bias_initializer), 530 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 531 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 532 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 533 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 534 | 'bias_constraint': constraints.serialize(self.bias_constraint) 535 | } 536 | base_config = super(DenseTransposeTied, self).get_config() 537 | return dict(list(base_config.items()) + list(config.items())) --------------------------------------------------------------------------------
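
The utilities and custom layers in helper.py above are consumed by the model-building scripts elsewhere in this repository (reddit_lm.py, sated_nmt.py, dialogue.py). The two sketches below are illustrative only and are not taken from those scripts; all data, dimensions, and variable names in them are assumptions.

First, a quick look at the small preprocessing helpers (`words_to_indices`, `flatten_data`, `iterate_minibatches`) on made-up data:

```python
from helper import words_to_indices, flatten_data, iterate_minibatches

# Toy vocabulary and tokenized texts (hypothetical; real vocabularies come from build_vocab).
vocab = {'<s>': 0, 'hello': 1, 'world': 2, '</s>': 3}
texts = [['<s>', 'hello', 'world', '</s>'], ['<s>', 'world', '</s>']]

indices = words_to_indices(texts, vocab)   # [[0, 1, 2, 3], [0, 2, 3]]
flat = flatten_data(indices)               # flat int32 array of all token ids

# Pair each token with its successor as a toy next-word-prediction setup and
# iterate over shuffled mini-batches of aligned (inputs, targets) slices.
inputs, targets = flat[:-1], flat[1:]
batches = list(iterate_minibatches(inputs, targets, batchsize=2, shuffle=True))
```

Second, a minimal sketch of how the `Attention` and `DenseTransposeTied` layers could be wired into a Keras encoder-decoder. This is not the paper's exact architecture: the vocabulary sizes and dimensions (`SRC_VOCAB`, `TRG_VOCAB`, `EMB_DIM`, `HID_DIM`) are placeholders, and the plain `LSTM` layer stands in for the CuDNN variants defined above, which require a GPU and the TensorFlow backend.

```python
from keras import Model
from keras.layers import Input, Embedding, LSTM, Concatenate, TimeDistributed, Dense

from helper import Attention, DenseTransposeTied

SRC_VOCAB = TRG_VOCAB = 5000        # assumed vocabulary sizes
EMB_DIM = HID_DIM = 128             # hidden size == embedding size keeps all shapes compatible

src_in = Input(shape=(None,), dtype='int32')
trg_in = Input(shape=(None,), dtype='int32')

src_emb = Embedding(SRC_VOCAB, EMB_DIM)(src_in)
trg_emb_layer = Embedding(TRG_VOCAB, EMB_DIM)
trg_emb = trg_emb_layer(trg_in)     # call the layer first so its weights exist before tying

# Encoder/decoder; the decoder is initialized with the encoder's final states.
enc_seq, enc_h, enc_c = LSTM(HID_DIM, return_sequences=True, return_state=True)(src_emb)
dec_seq = LSTM(HID_DIM, return_sequences=True)(trg_emb, initial_state=[enc_h, enc_c])

# Attention (from helper.py) takes [encoder states, decoder states] and returns
# one context vector per decoder step.
contexts = Attention(units=HID_DIM)([enc_seq, dec_seq])
merged = Concatenate(axis=-1)([dec_seq, contexts])
proj = TimeDistributed(Dense(EMB_DIM, activation='tanh'))(merged)

# DenseTransposeTied (from helper.py) reuses the transposed target-embedding matrix
# as the output projection, so the softmax layer shares weights with the embedding.
probs = DenseTransposeTied(TRG_VOCAB, tied_to=trg_emb_layer, activation='softmax')(proj)

model = Model(inputs=[src_in, trg_in], outputs=probs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.summary()
```

Training a model like this would feed source sentences and shifted target sentences as inputs, with the next target token at each step as the label (older Keras versions expect a trailing length-1 axis on sparse targets). The rank-collection scripts then query the trained target and shadow models token by token and record where each true next word falls in the sorted output distribution; the resulting per-user rank arrays are the inputs to the normalizer-plus-SVC audit code earlier in this listing.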