├── .gitignore ├── README.md ├── data ├── MQ2008 │ └── Fold1 │ │ ├── test.txt │ │ ├── train.txt │ │ └── vali.txt └── processed │ ├── test_feature.npy │ ├── test_group.npy │ ├── test_label.npy │ ├── train_feature.npy │ ├── train_group.npy │ ├── train_label.npy │ ├── vali_feature.npy │ ├── vali_group.npy │ └── vali_label.npy └── src ├── __init__.py ├── freeze_graph.sh ├── main.py ├── metrics.py ├── model.py ├── prepare_data.py ├── tf_common ├── __init__.py ├── nadam.py └── nn_module.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *.py[cod] 3 | *$py.class 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | MANIFEST 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | .static_storage/ 55 | .media/ 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | env.bak/ 90 | venv.bak/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | # 106 | .idea 107 | src/.idea 108 | logs 109 | weights 110 | src/freeze_graph.py 111 | data/Letor/* 112 | data/np* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-LTR 2 | 3 | Ongoing projects for implementing various Learning to Rank (LTR) models. 4 | 5 | - pointwise 6 | - classification 7 | - DNN 8 | - LR 9 | - pairwise 10 | - RankNet 11 | - LambdaRank 12 | - listwise 13 | - ListNet (TODO) 14 | 15 | # References 16 | [1] Hang Li, *A Short Introduction to Learning to Rank* 17 | 18 | [2] Christopher J.C. 
Burges, *From RankNet to LambdaRank to LambdaMART: An Overview* 19 | 20 | # Acknowledgments 21 | This project gets inspirations from the following projects: 22 | - [learning-rank-public](https://github.com/andreweskeclarke/learning-rank-public) 23 | - [learning2rank](https://github.com/shiba24/learning2rank) -------------------------------------------------------------------------------- /data/processed/test_feature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/test_feature.npy -------------------------------------------------------------------------------- /data/processed/test_group.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/test_group.npy -------------------------------------------------------------------------------- /data/processed/test_label.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/test_label.npy -------------------------------------------------------------------------------- /data/processed/train_feature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/train_feature.npy -------------------------------------------------------------------------------- /data/processed/train_group.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/train_group.npy -------------------------------------------------------------------------------- /data/processed/train_label.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/train_label.npy -------------------------------------------------------------------------------- /data/processed/vali_feature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/vali_feature.npy -------------------------------------------------------------------------------- /data/processed/vali_group.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/vali_group.npy -------------------------------------------------------------------------------- /data/processed/vali_label.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/data/processed/vali_label.npy -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/src/__init__.py -------------------------------------------------------------------------------- /src/freeze_graph.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type=$1 3 | python freeze_graph.py --input_graph=../weights/$model_type/graph.pb \ 4 | --input_checkpoint=../weights/$model_type/model.checkpoint \ 5 | --output_graph=../weights/$model_type/freeze_graph.pb \ 6 | --output_node_names="ranking/feature,ranking/training,ranking/score" 7 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import numpy as np 4 | 5 | import utils 6 | from model import LogisticRegression, DNN, RankNet, LambdaRank 7 | from prepare_data import label_file_pat, group_file_pat, feature_file_pat 8 | 9 | def load_data(type): 10 | 11 | labels = np.load(label_file_pat%type) 12 | qids = np.load(group_file_pat % type) 13 | features = np.load(feature_file_pat%type) 14 | 15 | X = { 16 | "feature": features, 17 | "label": labels, 18 | "qid": qids 19 | } 20 | return X 21 | 22 | 23 | utils._makedirs("../logs") 24 | logger = utils._get_logger("../logs", "tf-%s.log" % utils._timestamp()) 25 | 26 | params_common = { 27 | # you might have to tune the batch size to get ranknet and lambdarank working 28 | # keep in mind the following: 29 | # 1. batch size should be large enough to ensure there are samples of different 30 | # relevance labels from the same group, especially when you use "sample" as "batch_sampling_method" 31 | # this ensures the gradients are nonzero and stable across batches, 32 | # which is important for pairwise methods, e.g., ranknet and lambdarank 33 | # 2.
batch size should not be very large since the lambda_ij matrix in ranknet and lambdarank 34 | # (which are of size batch_size x batch_size) will consume large memory space 35 | "batch_size": 128, 36 | "epoch": 50, 37 | "feature_dim": 46, 38 | 39 | "batch_sampling_method": "sample", 40 | "shuffle": True, 41 | 42 | "optimizer_type": "adam", 43 | "init_lr": 0.001, 44 | "beta1": 0.975, 45 | "beta2": 0.999, 46 | "decay_steps": 1000, 47 | "decay_rate": 0.9, 48 | "schedule_decay": 0.004, 49 | "random_seed": 2018, 50 | "eval_every_num_update": 100, 51 | } 52 | 53 | 54 | def train_lr(): 55 | params = { 56 | "offline_model_dir": "../weights/lr", 57 | } 58 | params.update(params_common) 59 | 60 | X_train, X_valid = load_data("train"), load_data("vali") 61 | 62 | model = LogisticRegression("ranking", params, logger) 63 | model.fit(X_train, validation_data=X_valid) 64 | model.save_session() 65 | 66 | 67 | def train_dnn(): 68 | params = { 69 | "offline_model_dir": "../weights/dnn", 70 | 71 | # deep part score fn 72 | "fc_type": "fc", 73 | "fc_dim": 32, 74 | "fc_dropout": 0., 75 | } 76 | params.update(params_common) 77 | 78 | X_train, X_valid = load_data("train"), load_data("vali") 79 | 80 | model = DNN("ranking", params, logger) 81 | model.fit(X_train, validation_data=X_valid) 82 | model.save_session() 83 | 84 | 85 | def train_ranknet(): 86 | params = { 87 | "offline_model_dir": "../weights/ranknet", 88 | 89 | # deep part score fn 90 | "fc_type": "fc", 91 | "fc_dim": 32, 92 | "fc_dropout": 0., 93 | 94 | # ranknet param 95 | "factorization": True, 96 | "sigma": 1., 97 | } 98 | params.update(params_common) 99 | 100 | X_train, X_valid = load_data("train"), load_data("vali") 101 | 102 | model = RankNet("ranking", params, logger) 103 | model.fit(X_train, validation_data=X_valid) 104 | model.save_session() 105 | 106 | 107 | def train_lambdarank(): 108 | params = { 109 | "offline_model_dir": "../weights/lambdarank", 110 | 111 | # deep part score fn 112 | "fc_type": "fc", 113 | "fc_dim": 32, 114 | "fc_dropout": 0., 115 | 116 | # lambdarank param 117 | "sigma": 1., 118 | } 119 | params.update(params_common) 120 | 121 | X_train, X_valid = load_data("train"), load_data("vali") 122 | 123 | model = LambdaRank("ranking", params, logger) 124 | model.fit(X_train, validation_data=X_valid) 125 | model.save_session() 126 | 127 | 128 | def main(): 129 | if len(sys.argv) > 1: 130 | if sys.argv[1] == "lr": 131 | train_lr() 132 | elif sys.argv[1] == "dnn": 133 | train_dnn() 134 | elif sys.argv[1] == "ranknet": 135 | train_ranknet() 136 | elif sys.argv[1] == "lambdarank": 137 | train_lambdarank() 138 | else: 139 | train_lr() 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | # taken from: https://github.com/andreweskeclarke/learning-rank-public 6 | def calc_err(predicted_order): 7 | err = 0 8 | prev_one_min_rel_prod = 1 9 | previous_rel = 0 10 | T = len(predicted_order) if len(predicted_order) < 10 else 10 11 | for r in range(T): 12 | rel_r = calc_ri(predicted_order, r) 13 | one_min_rel_prod = (1 - previous_rel) * prev_one_min_rel_prod 14 | err += (1 / (r+1)) * rel_r * one_min_rel_prod 15 | prev_one_min_rel_prod = one_min_rel_prod 16 | previous_rel = rel_r 17 | 18 | return err 19 | 20 | 21 | def calc_ri(predicted_order, i): 22 | return (2 ** predicted_order[i] - 1) / (2 ** np.max(predicted_order)) 23 
| 24 | 25 | def dcg(predicted_order): 26 | i = np.log(1. + np.arange(1,len(predicted_order)+1)) 27 | l = 2 ** (np.array(predicted_order)) - 1 28 | return np.sum(l/i) 29 | 30 | 31 | def ndcg(score, top_ten=True): 32 | end = 10 if top_ten else len(score) 33 | sorted_score = np.sort(score)[::-1] 34 | dcg_ = dcg(score[:end]) 35 | if dcg_ == 0: 36 | return 0 37 | dcg_max = dcg(sorted_score[:end]) 38 | return dcg_/dcg_max 39 | 40 | 41 | if __name__ == "__main__": 42 | predicted_order_ = [4, 4, 2, 3, 2, 4, 0, 1, 1, 4, 1, 3, 3, 2, 3, 4, 2, 1, 0, 0] 43 | print(calc_err(predicted_order_)) 44 | print(dcg(predicted_order_)) 45 | print((ndcg(predicted_order_))) 46 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import numpy as np 4 | import pandas as pd 5 | import tensorflow as tf 6 | 7 | import utils 8 | from metrics import ndcg, calc_err 9 | from tf_common.nn_module import resnet_block, dense_block 10 | from tf_common.nadam import NadamOptimizer 11 | 12 | 13 | class BaseRankModel(object): 14 | 15 | def __init__(self, model_name, params, logger, training=True): 16 | self.model_name = model_name 17 | self.params = params 18 | self.logger = logger 19 | utils._makedirs(self.params["offline_model_dir"], force=training) 20 | 21 | self._init_tf_vars() 22 | self.loss, self.num_pairs, self.score, self.train_op = self._build_model() 23 | 24 | self.sess, self.saver = self._init_session() 25 | 26 | 27 | def _init_tf_vars(self): 28 | with tf.name_scope(self.model_name): 29 | #### input for training and inference 30 | self.feature = tf.placeholder(tf.float32, shape=[None, self.params["feature_dim"]], name="feature") 31 | self.training = tf.placeholder(tf.bool, shape=[], name="training") 32 | #### input for training 33 | self.label = tf.placeholder(tf.float32, shape=[None, 1], name="label") 34 | self.sorted_label = tf.placeholder(tf.float32, shape=[None, 1], name="sorted_label") 35 | self.qid = tf.placeholder(tf.float32, shape=[None, 1], name="qid") 36 | #### vars for training 37 | self.global_step = tf.Variable(0, trainable=False) 38 | self.learning_rate = tf.train.exponential_decay(self.params["init_lr"], self.global_step, 39 | self.params["decay_steps"], self.params["decay_rate"]) 40 | self.batch_size = tf.placeholder(tf.int32, shape=[], name="batch_size") 41 | 42 | 43 | def _build_model(self): 44 | return None, None, None, None 45 | 46 | 47 | def _score_fn_inner(self, x, reuse=False): 48 | # deep 49 | hidden_units = [self.params["fc_dim"] * 4, self.params["fc_dim"] * 2, self.params["fc_dim"]] 50 | dropouts = [self.params["fc_dropout"]] * len(hidden_units) 51 | out = dense_block(x, hidden_units=hidden_units, dropouts=dropouts, densenet=False, reuse=reuse, 52 | training=self.training, seed=self.params["random_seed"]) 53 | # score 54 | score = tf.layers.dense(out, 1, activation=None, 55 | kernel_initializer=tf.glorot_uniform_initializer(seed=self.params["random_seed"])) 56 | 57 | return score 58 | 59 | 60 | def _score_fn(self, x, reuse=False): 61 | # https://stackoverflow.com/questions/45670224/why-the-tf-name-scope-with-same-name-is-different 62 | with tf.name_scope(self.model_name+"/"): 63 | score = self._score_fn_inner(x, reuse) 64 | # https://stackoverflow.com/questions/46980287/output-node-for-tensorflow-graph-created-with-tf-layers 65 | # add an identity node to output graph 66 | score = tf.identity(score, "score") 67 | 68 | return score 69 | 70 | 71 | def 
_jacobian(self, y_flat, x): 72 | """ 73 | https://github.com/tensorflow/tensorflow/issues/675 74 | for ranknet and lambdarank 75 | """ 76 | loop_vars = [ 77 | tf.constant(0, tf.int32), 78 | tf.TensorArray(tf.float32, size=self.batch_size), 79 | ] 80 | 81 | _, jacobian = tf.while_loop( 82 | lambda j, _: j < self.batch_size, 83 | lambda j, result: (j + 1, result.write(j, tf.gradients(y_flat[j], x))), 84 | loop_vars) 85 | 86 | return jacobian.stack() 87 | 88 | 89 | def _get_derivative(self, score, Wk, lambda_ij, x): 90 | """ 91 | for ranknet and lambdarank 92 | :param score: 93 | :param Wk: 94 | :param lambda_ij: 95 | :return: 96 | """ 97 | # dsi_dWk = tf.map_fn(lambda s: tf.gradients(s, [Wk])[0], score) # do not work 98 | # dsi_dWk = tf.stack([tf.gradients(si, x)[0] for si in tf.unstack(score, axis=1)], axis=2) # do not work 99 | dsi_dWk = self._jacobian(score, Wk) 100 | dsi_dWk_minus_dsj_dWk = tf.expand_dims(dsi_dWk, 1) - tf.expand_dims(dsi_dWk, 0) 101 | shape = tf.concat( 102 | [tf.shape(lambda_ij), tf.ones([tf.rank(dsi_dWk_minus_dsj_dWk) - tf.rank(lambda_ij)], dtype=tf.int32)], 103 | axis=0) 104 | grad = tf.reduce_mean(tf.reshape(lambda_ij, shape) * dsi_dWk_minus_dsj_dWk, axis=[0, 1]) 105 | return tf.reshape(grad, tf.shape(Wk)) 106 | 107 | 108 | def _get_train_op(self, loss): 109 | """ 110 | for model that gradient can be computed with respect to loss, e.g., LogisticRegression and RankNet 111 | """ 112 | with tf.name_scope("optimization"): 113 | if self.params["optimizer_type"] == "nadam": 114 | optimizer = NadamOptimizer(learning_rate=self.learning_rate, beta1=self.params["beta1"], 115 | beta2=self.params["beta2"], epsilon=1e-8, 116 | schedule_decay=self.params["schedule_decay"]) 117 | elif self.params["optimizer_type"] == "adam": 118 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.params["beta1"], 119 | beta2=self.params["beta2"], epsilon=1e-8) 120 | 121 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 122 | with tf.control_dependencies(update_ops): 123 | train_op = optimizer.minimize(loss, global_step=self.global_step) 124 | 125 | return train_op 126 | 127 | 128 | def _init_session(self): 129 | config = tf.ConfigProto(device_count={"gpu": 1}) 130 | config.gpu_options.allow_growth = True 131 | config.intra_op_parallelism_threads = 4 132 | config.inter_op_parallelism_threads = 4 133 | sess = tf.Session(config=config) 134 | sess.run(tf.global_variables_initializer()) 135 | # max_to_keep=None, keep all the models 136 | saver = tf.train.Saver(max_to_keep=None) 137 | return sess, saver 138 | 139 | 140 | def save_session(self): 141 | # write graph for freeze_graph.py 142 | tf.train.write_graph(self.sess.graph.as_graph_def(), self.params["offline_model_dir"], "graph.pb", as_text=True) 143 | self.saver.save(self.sess, self.params["offline_model_dir"] + "/model.checkpoint") 144 | 145 | 146 | def restore_session(self): 147 | self.saver.restore(self.sess, self.params["offline_model_dir"] + "/model.checkpoint") 148 | 149 | 150 | def _get_batch_index(self, seq, step): 151 | n = len(seq) 152 | res = [] 153 | for i in range(0, n, step): 154 | res.append(seq[i:i + step]) 155 | # last batch 156 | if len(res) * step < n: 157 | res.append(seq[len(res) * step:]) 158 | return res 159 | 160 | 161 | def _get_feed_dict(self, X, idx, training=False): 162 | feed_dict = { 163 | self.feature: X["feature"][idx], 164 | self.label: X["label"][idx].reshape((-1, 1)), 165 | self.qid: X["qid"][idx].reshape((-1, 1)), 166 | self.sorted_label: np.sort(X["label"][idx].reshape((-1, 
1)))[::-1], 167 | self.training: training, 168 | self.batch_size: len(idx), 169 | } 170 | 171 | return feed_dict 172 | 173 | 174 | def fit(self, X, validation_data): 175 | start_time = time.time() 176 | l = X["feature"].shape[0] 177 | self.logger.info("fit on %d sample" % l) 178 | qid_unique = np.unique(X["qid"]) 179 | num_qid_unique = len(qid_unique) 180 | if self.params["batch_sampling_method"] == "group": 181 | train_idx_shuffle = np.arange(num_qid_unique) 182 | else: 183 | train_idx_shuffle = np.arange(l) 184 | total_loss = 0. 185 | loss_decay = 0.9 186 | total_batch = 0 187 | # evaluate before training 188 | loss_mean_train, err_mean_train, ndcg_mean_train, ndcg_all_mean_train = self.evaluate(X) 189 | if validation_data is not None: 190 | loss_mean_valid, err_mean_valid, ndcg_mean_valid, ndcg_all_mean_valid = self.evaluate(validation_data) 191 | self.logger.info( 192 | "[epoch-{}, batch-{}] -- Train Loss: {:5f} NDCG: {:5f} ({:5f}) ERR: {:5f} -- Valid Loss: {:5f} NDCG: {:5f} ({:5f}) ERR: {:5f} -- {:5f} s".format( 193 | 0, 0, loss_mean_train, ndcg_mean_train, ndcg_all_mean_train, err_mean_train, 194 | loss_mean_valid, ndcg_mean_valid, ndcg_all_mean_valid, err_mean_valid, 195 | time.time() - start_time)) 196 | else: 197 | self.logger.info( 198 | "[epoch-{}, batch-{}] -- Train Loss: {:5f} NDCG: {:5f} ({:5f}) ERR: {:5f} -- {:5f} s".format( 199 | 0, 0, loss_mean_train, ndcg_mean_train, ndcg_all_mean_train, err_mean_train, 200 | time.time() - start_time)) 201 | for epoch in range(self.params["epoch"]): 202 | self.logger.info("epoch: %d" % (epoch + 1)) 203 | np.random.seed(epoch) 204 | if self.params["shuffle"]: 205 | np.random.shuffle(train_idx_shuffle) 206 | batches = self._get_batch_index(train_idx_shuffle, self.params["batch_size"]) 207 | for i, idx in enumerate(batches): 208 | if self.params["batch_sampling_method"] == "group": 209 | ind = utils._get_intersect_index(X["qid"], qid_unique[idx]) 210 | else: 211 | ind = idx 212 | feed_dict = self._get_feed_dict(X, ind, training=True) 213 | loss, lr, opt = self.sess.run((self.loss, self.learning_rate, self.train_op), feed_dict=feed_dict) 214 | total_loss = loss_decay * total_loss + (1. 
- loss_decay) * loss 215 | total_batch += 1 216 | if total_batch % self.params["eval_every_num_update"] == 0: 217 | loss_mean_train, err_mean_train, ndcg_mean_train, ndcg_all_mean_train = self.evaluate(X) 218 | if validation_data is not None: 219 | loss_mean_valid, err_mean_valid, ndcg_mean_valid, ndcg_all_mean_valid = self.evaluate(validation_data) 220 | self.logger.info( 221 | "[epoch-{}, batch-{}] -- Train Loss: {:5f} NDCG: {:5f} ({:5f}) ERR: {:5f} -- Valid Loss: {:5f} NDCG: {:5f} ({:5f}) ERR: {:5f} -- {:5f} s".format( 222 | epoch + 1, total_batch, loss_mean_train, ndcg_mean_train, ndcg_all_mean_train, err_mean_train, 223 | loss_mean_valid, ndcg_mean_valid, ndcg_all_mean_valid, err_mean_valid, time.time() - start_time)) 224 | else: 225 | self.logger.info( 226 | "[epoch-{}, batch-{}] -- Train Loss: {:5f} NDCG: {:5f} ({:5f}) ERR: {:5f} -- {:5f} s".format( 227 | epoch + 1, total_batch, loss_mean_train, ndcg_mean_train, ndcg_all_mean_train, err_mean_train, 228 | time.time() - start_time)) 229 | 230 | 231 | def predict(self, X): 232 | l = X["feature"].shape[0] 233 | train_idx = np.arange(l) 234 | batches = self._get_batch_index(train_idx, self.params["batch_size"]) 235 | y_pred = [] 236 | y_pred_append = y_pred.append 237 | for idx in batches: 238 | feed_dict = self._get_feed_dict(X, idx, training=False) 239 | pred = self.sess.run((self.score), feed_dict=feed_dict) 240 | y_pred_append(pred) 241 | y_pred = np.vstack(y_pred).reshape((-1, 1)) 242 | return y_pred 243 | 244 | 245 | def evaluate(self, X): 246 | qid_unique = np.unique(X["qid"]) 247 | n = len(qid_unique) 248 | losses = np.zeros(n) 249 | ndcgs = np.zeros(n) 250 | ndcgs_all = np.zeros(n) 251 | errs = np.zeros(n) 252 | for e,qid in enumerate(qid_unique): 253 | ind = np.where(X["qid"] == qid)[0] 254 | feed_dict = self._get_feed_dict(X, ind, training=False) 255 | loss, score = self.sess.run((self.loss, self.score), feed_dict=feed_dict) 256 | df = pd.DataFrame({"label": X["label"][ind].flatten(), "score": score.flatten()}) 257 | df.sort_values("score", ascending=False, inplace=True) 258 | 259 | losses[e] = loss 260 | ndcgs[e] = ndcg(df["label"]) 261 | ndcgs_all[e] = ndcg(df["label"], top_ten=False) 262 | errs[e] = calc_err(df["label"]) 263 | losses_mean = np.mean(losses) 264 | ndcgs_mean = np.mean(ndcgs) 265 | ndcgs_all_mean = np.mean(ndcgs_all) 266 | errs_mean = np.mean(errs) 267 | return losses_mean, errs_mean, ndcgs_mean, ndcgs_all_mean 268 | 269 | 270 | class DNN(BaseRankModel): 271 | 272 | def __init__(self, model_name, params, logger, training=True): 273 | super(DNN, self).__init__(model_name, params, logger, training) 274 | 275 | def _build_model(self): 276 | # score 277 | score = logits = self._score_fn(self.feature) 278 | 279 | logloss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=self.label) 280 | loss = tf.reduce_mean(logloss) 281 | num_pairs = tf.shape(self.feature)[0] 282 | 283 | return loss, num_pairs, score, self._get_train_op(loss) 284 | 285 | 286 | class LogisticRegression(DNN): 287 | 288 | def __init__(self, model_name, params, logger, training=True): 289 | super(LogisticRegression, self).__init__(model_name, params, logger, training) 290 | 291 | 292 | def _score_fn_inner(self, x, reuse=False): 293 | score = tf.layers.dense(x, 1, activation=None, 294 | kernel_initializer=tf.glorot_uniform_initializer(seed=self.params["random_seed"])) 295 | return score 296 | 297 | 298 | class RankNet(BaseRankModel): 299 | 300 | def __init__(self, model_name, params, logger, training=True): 301 | super(RankNet, 
self).__init__(model_name, params, logger, training) 302 | 303 | 304 | def _build_model(self): 305 | if self.params["factorization"]: 306 | return self._build_factorized_model() 307 | else: 308 | return self._build_unfactorized_model() 309 | 310 | 311 | def _build_unfactorized_model(self): 312 | # score 313 | score = self._score_fn(self.feature) 314 | 315 | # 316 | S_ij = self.label - tf.transpose(self.label) 317 | S_ij = tf.maximum(tf.minimum(1., S_ij), -1.) 318 | P_ij = (1 / 2) * (1 + S_ij) 319 | s_i_minus_s_j = logits = score - tf.transpose(score) 320 | 321 | logloss = tf.nn.sigmoid_cross_entropy_with_logits(logits=s_i_minus_s_j, labels=P_ij) 322 | 323 | # only extracted the loss of pairs of the same group 324 | mask1 = tf.equal(self.qid - tf.transpose(self.qid), 0) 325 | mask1 = tf.cast(mask1, tf.float32) 326 | # exclude the pair of sample and itself 327 | n = tf.shape(self.feature)[0] 328 | mask2 = tf.ones([n, n]) - tf.diag(tf.ones([n])) 329 | mask = mask1 * mask2 330 | num_pairs = tf.reduce_sum(mask) 331 | 332 | loss = tf.cond(tf.equal(num_pairs, 0), lambda: 0., lambda: tf.reduce_sum(logloss * mask) / num_pairs) 333 | 334 | return loss, num_pairs, score, self._get_train_op(loss) 335 | 336 | 337 | def _build_factorized_model(self): 338 | # score 339 | score = self._score_fn(self.feature) 340 | 341 | # 342 | S_ij = self.label - tf.transpose(self.label) 343 | S_ij = tf.maximum(tf.minimum(1., S_ij), -1.) 344 | P_ij = (1 / 2) * (1 + S_ij) 345 | s_i_minus_s_j = logits = score - tf.transpose(score) 346 | sigma = self.params["sigma"] 347 | lambda_ij = sigma * ((1 / 2) * (1 - S_ij) - tf.nn.sigmoid(-sigma*s_i_minus_s_j)) 348 | 349 | logloss = tf.nn.sigmoid_cross_entropy_with_logits(logits=s_i_minus_s_j, labels=P_ij) 350 | 351 | # only extracted the loss of pairs of the same group 352 | mask1 = tf.equal(self.qid - tf.transpose(self.qid), 0) 353 | mask1 = tf.cast(mask1, tf.float32) 354 | # exclude the pair of sample and itself 355 | n = tf.shape(self.feature)[0] 356 | mask2 = tf.ones([n, n]) - tf.diag(tf.ones([n])) 357 | mask = mask1 * mask2 358 | num_pairs = tf.reduce_sum(mask) 359 | 360 | loss = tf.cond(tf.equal(num_pairs, 0), lambda: 0., lambda: tf.reduce_sum(logloss * mask) / num_pairs) 361 | 362 | lambda_ij = lambda_ij * mask 363 | 364 | vars = tf.trainable_variables() 365 | grads = [self._get_derivative(score, Wk, lambda_ij, self.feature) for Wk in vars] 366 | 367 | with tf.name_scope("optimization"): 368 | if self.params["optimizer_type"] == "nadam": 369 | optimizer = NadamOptimizer(learning_rate=self.learning_rate, beta1=self.params["beta1"], 370 | beta2=self.params["beta2"], epsilon=1e-8, 371 | schedule_decay=self.params["schedule_decay"]) 372 | elif self.params["optimizer_type"] == "adam": 373 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.params["beta1"], 374 | beta2=self.params["beta2"], epsilon=1e-8) 375 | 376 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 377 | with tf.control_dependencies(update_ops): 378 | train_op = optimizer.apply_gradients(zip(grads, vars)) 379 | 380 | return loss, num_pairs, score, train_op 381 | 382 | 383 | class LambdaRank(BaseRankModel): 384 | 385 | def __init__(self, model_name, params, logger, training=True): 386 | super(LambdaRank, self).__init__(model_name, params, logger, training) 387 | 388 | 389 | def _build_model(self): 390 | # score 391 | score = self._score_fn(self.feature) 392 | 393 | # 394 | S_ij = self.label - tf.transpose(self.label) 395 | S_ij = tf.maximum(tf.minimum(1., S_ij), -1.) 
396 | P_ij = (1 / 2) * (1 + S_ij) 397 | s_i_minus_s_j = logits = score - tf.transpose(score) 398 | sigma = self.params["sigma"] 399 | lambda_ij = sigma * ((1 / 2) * (1 - S_ij) - tf.nn.sigmoid(-sigma*s_i_minus_s_j)) 400 | # lambda_ij = -sigma * tf.nn.sigmoid(-sigma*s_i_minus_s_j) 401 | 402 | logloss = tf.nn.sigmoid_cross_entropy_with_logits(logits=s_i_minus_s_j, labels=P_ij) 403 | 404 | # only extracted the loss of pairs of the same group 405 | mask1 = tf.equal(self.qid - tf.transpose(self.qid), 0) 406 | mask1 = tf.cast(mask1, tf.float32) 407 | # exclude the pair of sample and itself 408 | n = tf.shape(self.feature)[0] 409 | mask2 = tf.ones([n, n]) - tf.diag(tf.ones([n])) 410 | mask = mask1 * mask2 411 | num_pairs = tf.reduce_sum(mask) 412 | 413 | loss = tf.cond(tf.equal(num_pairs, 0), lambda: 0., lambda: tf.reduce_sum(logloss * mask) / num_pairs) 414 | 415 | lambda_ij = lambda_ij * mask 416 | 417 | # multiply by delta ndcg 418 | # current dcg 419 | index = tf.reshape(tf.range(1., tf.cast(self.batch_size, dtype=tf.float32) + 1), tf.shape(self.label)) 420 | cg_discount = tf.log(1. + index) 421 | rel = 2 ** self.label - 1 422 | sorted_rel = 2 ** self.sorted_label - 1 423 | dcg_m = rel / cg_discount 424 | dcg = tf.reduce_sum(dcg_m) 425 | # every possible swapped dcg 426 | stale_ij = tf.tile(dcg_m, [1, self.batch_size]) 427 | new_ij = rel / tf.transpose(cg_discount) 428 | stale_ji = tf.transpose(stale_ij) 429 | new_ji = tf.transpose(new_ij) 430 | # new dcg 431 | dcg_new = dcg - stale_ij + new_ij - stale_ji + new_ji 432 | # delta ndcg 433 | # sorted_label = tf.contrib.framework.sort(self.label, direction="DESCENDING") 434 | dcg_max = tf.reduce_sum(sorted_rel / cg_discount) 435 | ndcg_delta = tf.abs(dcg_new - dcg) / dcg_max 436 | lambda_ij = lambda_ij * ndcg_delta 437 | 438 | vars = tf.trainable_variables() 439 | grads = [self._get_derivative(score, Wk, lambda_ij, self.feature) for Wk in vars] 440 | 441 | with tf.name_scope("optimization"): 442 | if self.params["optimizer_type"] == "nadam": 443 | optimizer = NadamOptimizer(learning_rate=self.learning_rate, beta1=self.params["beta1"], 444 | beta2=self.params["beta2"], epsilon=1e-8, 445 | schedule_decay=self.params["schedule_decay"]) 446 | elif self.params["optimizer_type"] == "adam": 447 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.params["beta1"], 448 | beta2=self.params["beta2"], epsilon=1e-8) 449 | 450 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 451 | with tf.control_dependencies(update_ops): 452 | train_op = optimizer.apply_gradients(zip(grads, vars)) 453 | 454 | return loss, num_pairs, score, train_op 455 | 456 | 457 | class ListNet(BaseRankModel): 458 | 459 | def __init__(self, model_name, params, logger, training=True): 460 | super(ListNet, self).__init__(model_name, params, logger, training) 461 | -------------------------------------------------------------------------------- /src/prepare_data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | 5 | 6 | label_file_pat = "../data/processed/%s_label.npy" 7 | group_file_pat = "../data/processed/%s_group.npy" 8 | feature_file_pat = "../data/processed/%s_feature.npy" 9 | 10 | 11 | def convert(type): 12 | data_path = os.path.join("..", "data/MQ2008/Fold1/"+ type + ".txt") 13 | 14 | labels = [] 15 | features = [] 16 | groups = [] 17 | with open(data_path, "r") as f: 18 | for line in f: 19 | if not line: 20 | break 21 | if "#" in line: 22 | line = line[:line.index("#")] 
23 | splits = line.strip().split(" ") 24 | labels.append(splits[0]) 25 | groups.append(splits[1].split(":")[1]) 26 | features.append([split.split(":")[1] for split in splits[2:]]) 27 | np.save(label_file_pat % (type), np.array(labels, dtype=int)) 28 | np.save(group_file_pat%(type), np.array(groups, dtype=int)) 29 | np.save(feature_file_pat%(type), np.array(features, dtype=float)) 30 | 31 | 32 | if __name__ == "__main__": 33 | convert("train") 34 | convert("vali") 35 | convert("test") -------------------------------------------------------------------------------- /src/tf_common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-LTR/fb302aafd18b3418dac97631cb75ad536cf5f176/src/tf_common/__init__.py -------------------------------------------------------------------------------- /src/tf_common/nadam.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | from tensorflow.python.eager import context 4 | from tensorflow.python.framework import ops 5 | from tensorflow.python.ops import array_ops 6 | from tensorflow.python.ops import control_flow_ops 7 | from tensorflow.python.ops import math_ops 8 | from tensorflow.python.ops import resource_variable_ops 9 | from tensorflow.python.ops import state_ops 10 | from tensorflow.python.ops import variable_scope 11 | from tensorflow.python.training import optimizer 12 | from tensorflow.python.training import training_ops 13 | 14 | 15 | class NadamOptimizer(optimizer.Optimizer): 16 | def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8, 17 | schedule_decay=0.004, use_locking=False, name="Nadam"): 18 | super(NadamOptimizer, self).__init__(use_locking, name) 19 | self._lr = learning_rate 20 | self._beta1 = beta1 21 | self._beta2 = beta2 22 | self._epsilon = epsilon 23 | self._schedule_decay = schedule_decay 24 | # momentum cache decay 25 | self._momentum_cache_decay = tf.cast(0.96, tf.float32) 26 | self._momentum_cache_const = tf.pow(self._momentum_cache_decay, 1. * schedule_decay) 27 | 28 | # Tensor versions of the constructor arguments, created in _prepare(). 29 | self._lr_t = None 30 | self._beta1_t = None 31 | self._beta2_t = None 32 | self._epsilon_t = None 33 | self._schedule_decay_t = None 34 | 35 | # Variables to accumulate the powers of the beta parameters. 36 | # Created in _create_slots when we know the variables to optimize. 37 | self._beta1_power = None 38 | self._beta2_power = None 39 | self._iterations = None 40 | self._m_schedule = None 41 | 42 | # Created in SparseApply if needed. 43 | self._updated_lr = None 44 | 45 | 46 | def _prepare(self): 47 | self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate") 48 | self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1") 49 | self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2") 50 | self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon") 51 | self._schedule_decay_t = ops.convert_to_tensor(self._schedule_decay, name="schedule_decay") 52 | 53 | def _create_slots(self, var_list): 54 | # Create the beta1 and beta2 accumulators on the same device as the first 55 | # variable. Sort the var_list to make sure this device is consistent across 56 | # workers (these need to go on the same PS, otherwise some updates are 57 | # silently ignored). 
58 | first_var = min(var_list, key=lambda x: x.name) 59 | 60 | create_new = self._iterations is None 61 | if not create_new and context.in_graph_mode(): 62 | create_new = (self._iterations.graph is not first_var.graph) 63 | 64 | if create_new: 65 | with ops.colocate_with(first_var): 66 | self._beta1_power = variable_scope.variable(self._beta1, 67 | name="beta1_power", 68 | trainable=False) 69 | self._beta2_power = variable_scope.variable(self._beta2, 70 | name="beta2_power", 71 | trainable=False) 72 | self._iterations = variable_scope.variable(0., 73 | name="iterations", 74 | trainable=False) 75 | self._m_schedule = variable_scope.variable(1., 76 | name="m_schedule", 77 | trainable=False) 78 | # Create slots for the first and second moments. 79 | for v in var_list: 80 | self._zeros_slot(v, "m", self._name) 81 | self._zeros_slot(v, "v", self._name) 82 | 83 | def _get_momentum_cache(self, schedule_decay_t, t): 84 | return tf.pow(self._momentum_cache_decay, t * schedule_decay_t) 85 | # return beta1_t * (1. - 0.5 * (tf.pow(self._momentum_cache_decay, t * schedule_decay_t))) 86 | 87 | 88 | """very slow 89 | we simply use the nadam update rule without warming momentum schedule 90 | def _apply_dense(self, grad, var): 91 | t = math_ops.cast(self._iterations, var.dtype.base_dtype) + 1. 92 | m_schedule = math_ops.cast(self._m_schedule, var.dtype.base_dtype) 93 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 94 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 95 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 96 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 97 | schedule_decay_t = math_ops.cast(self._schedule_decay_t, var.dtype.base_dtype) 98 | 99 | # Due to the recommendations in [2], i.e. warming momentum schedule 100 | # see keras Nadam 101 | momentum_cache_t = self._get_momentum_cache(beta1_t, schedule_decay_t, t) 102 | momentum_cache_t_1 = self._get_momentum_cache(beta1_t, schedule_decay_t, t+1.) 103 | m_schedule_new = m_schedule * momentum_cache_t 104 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 105 | 106 | # the following equations given in [1] 107 | # m_t = beta1 * m + (1 - beta1) * g_t 108 | m = self.get_slot(var, "m") 109 | m_t = state_ops.assign(m, beta1_t * m + (1. - beta1_t) * grad, use_locking=self._use_locking) 110 | g_prime = grad / (1. - m_schedule_new) 111 | m_t_prime = m_t / (1. - m_schedule_next) 112 | m_t_bar = (1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime 113 | 114 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 115 | v = self.get_slot(var, "v") 116 | v_t = state_ops.assign(v, beta2_t * v + (1. - beta2_t) * tf.square(grad), use_locking=self._use_locking) 117 | v_t_prime = v_t / (1. 
- tf.pow(beta2_t, t)) 118 | 119 | var_update = state_ops.assign_sub(var, 120 | lr_t * m_t_bar / (tf.sqrt(v_t_prime) + epsilon_t), 121 | use_locking=self._use_locking) 122 | 123 | return control_flow_ops.group(*[var_update, m_t, v_t]) 124 | """ 125 | # nadam update rule without warming momentum schedule 126 | def _apply_dense(self, grad, var): 127 | m = self.get_slot(var, "m") 128 | v = self.get_slot(var, "v") 129 | return training_ops.apply_adam( 130 | var, 131 | m, 132 | v, 133 | math_ops.cast(self._beta1_power, var.dtype.base_dtype), 134 | math_ops.cast(self._beta2_power, var.dtype.base_dtype), 135 | math_ops.cast(self._lr_t, var.dtype.base_dtype), 136 | math_ops.cast(self._beta1_t, var.dtype.base_dtype), 137 | math_ops.cast(self._beta2_t, var.dtype.base_dtype), 138 | math_ops.cast(self._epsilon_t, var.dtype.base_dtype), 139 | grad, 140 | use_locking=self._use_locking, 141 | use_nesterov=True).op 142 | 143 | def _resource_apply_dense(self, grad, var): 144 | m = self.get_slot(var, "m") 145 | v = self.get_slot(var, "v") 146 | return training_ops.resource_apply_adam( 147 | var.handle, 148 | m.handle, 149 | v.handle, 150 | math_ops.cast(self._beta1_power, grad.dtype.base_dtype), 151 | math_ops.cast(self._beta2_power, grad.dtype.base_dtype), 152 | math_ops.cast(self._lr_t, grad.dtype.base_dtype), 153 | math_ops.cast(self._beta1_t, grad.dtype.base_dtype), 154 | math_ops.cast(self._beta2_t, grad.dtype.base_dtype), 155 | math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), 156 | grad, 157 | use_locking=self._use_locking, 158 | use_nesterov=True) 159 | 160 | # keras Nadam update rule 161 | def _apply_sparse(self, grad, var): 162 | t = math_ops.cast(self._iterations, var.dtype.base_dtype) + 1. 163 | m_schedule = math_ops.cast(self._m_schedule, var.dtype.base_dtype) 164 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 165 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 166 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 167 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 168 | schedule_decay_t = math_ops.cast(self._schedule_decay_t, var.dtype.base_dtype) 169 | 170 | # Due to the recommendations in [2], i.e. warming momentum schedule 171 | momentum_cache_power = self._get_momentum_cache(schedule_decay_t, t) 172 | momentum_cache_t = beta1_t * (1. - 0.5 * momentum_cache_power) 173 | momentum_cache_t_1 = beta1_t * (1. - 0.5 * momentum_cache_power * self._momentum_cache_const) 174 | m_schedule_new = m_schedule * momentum_cache_t 175 | m_schedule_next = m_schedule_new * momentum_cache_t_1 176 | 177 | # the following equations given in [1] 178 | # m_t = beta1 * m + (1 - beta1) * g_t 179 | m = self.get_slot(var, "m") 180 | m_t = state_ops.scatter_update(m, grad.indices, 181 | beta1_t * array_ops.gather(m, grad.indices) + 182 | (1. - beta1_t) * grad.values, 183 | use_locking=self._use_locking) 184 | g_prime_slice = grad.values / (1. - m_schedule_new) 185 | m_t_prime_slice = array_ops.gather(m_t, grad.indices) / (1. - m_schedule_next) 186 | m_t_bar_slice = (1. - momentum_cache_t) * g_prime_slice + momentum_cache_t_1 * m_t_prime_slice 187 | 188 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 189 | v = self.get_slot(var, "v") 190 | v_t = state_ops.scatter_update(v, grad.indices, 191 | beta2_t * array_ops.gather(v, grad.indices) + 192 | (1. - beta2_t) * tf.square(grad.values), 193 | use_locking=self._use_locking) 194 | v_t_prime_slice = array_ops.gather(v_t, grad.indices) / (1. 
- tf.pow(beta2_t, t)) 195 | 196 | var_update = state_ops.scatter_sub(var, grad.indices, 197 | lr_t * m_t_bar_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t), 198 | use_locking=self._use_locking) 199 | 200 | return control_flow_ops.group(*[var_update, m_t, v_t]) 201 | 202 | def _finish(self, update_ops, name_scope): 203 | # Update the power accumulators. 204 | with ops.control_dependencies(update_ops): 205 | with ops.colocate_with(self._iterations): 206 | update_beta1 = self._beta1_power.assign( 207 | self._beta1_power * self._beta1_t, 208 | use_locking=self._use_locking) 209 | update_beta2 = self._beta2_power.assign( 210 | self._beta2_power * self._beta2_t, 211 | use_locking=self._use_locking) 212 | t = self._iterations + 1. 213 | update_iterations = self._iterations.assign(t, use_locking=self._use_locking) 214 | momentum_cache_power = self._get_momentum_cache(self._schedule_decay_t, t) 215 | momentum_cache_t = self._beta1_t * (1. - 0.5 * momentum_cache_power) 216 | update_m_schedule = self._m_schedule.assign( 217 | self._m_schedule * momentum_cache_t, 218 | use_locking=self._use_locking) 219 | return control_flow_ops.group( 220 | *update_ops + [update_beta1, update_beta2] + [update_iterations, update_m_schedule], 221 | name=name_scope) -------------------------------------------------------------------------------- /src/tf_common/nn_module.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | """ 6 | https://explosion.ai/blog/deep-learning-formula-nlp 7 | embed -> encode -> attend -> predict 8 | """ 9 | def batch_normalization(x, training, name): 10 | # with tf.variable_scope(name, reuse=) 11 | bn_train = tf.layers.batch_normalization(x, training=True, reuse=None, name=name) 12 | bn_inference = tf.layers.batch_normalization(x, training=False, reuse=True, name=name) 13 | z = tf.cond(training, lambda: bn_train, lambda: bn_inference) 14 | return z 15 | 16 | 17 | #### Step 1 18 | def embed(x, size, dim, seed=0, flatten=False, reduce_sum=False): 19 | # std = np.sqrt(2 / dim) 20 | std = 0.001 21 | minval = -std 22 | maxval = std 23 | emb = tf.Variable(tf.random_uniform([size, dim], minval, maxval, dtype=tf.float32, seed=seed)) 24 | # None * max_seq_len * embed_dim 25 | out = tf.nn.embedding_lookup(emb, x) 26 | if flatten: 27 | out = tf.layers.flatten(out) 28 | if reduce_sum: 29 | out = tf.reduce_sum(out, axis=1) 30 | return out 31 | 32 | 33 | def embed_subword(x, size, dim, sequence_length, seed=0, mask_zero=False, maxlen=None): 34 | # std = np.sqrt(2 / dim) 35 | std = 0.001 36 | minval = -std 37 | maxval = std 38 | emb = tf.Variable(tf.random_uniform([size, dim], minval, maxval, dtype=tf.float32, seed=seed)) 39 | # None * max_seq_len * max_word_len * embed_dim 40 | out = tf.nn.embedding_lookup(emb, x) 41 | if mask_zero: 42 | # word_len: None * max_seq_len 43 | # mask: shape=None * max_seq_len * max_word_len 44 | mask = tf.sequence_mask(sequence_length, maxlen) 45 | mask = tf.expand_dims(mask, axis=-1) 46 | mask = tf.cast(mask, tf.float32) 47 | out = out * mask 48 | # None * max_seq_len * embed_dim 49 | # according to facebook subword paper, it's sum 50 | out = tf.reduce_sum(out, axis=2) 51 | return out 52 | 53 | 54 | def word_dropout(x, training, dropout=0, seed=0): 55 | # word dropout (dropout the entire embedding for some words) 56 | """ 57 | tf.layers.Dropout doesn't work as it can't switch training or inference 58 | """ 59 | if dropout > 0: 60 | input_shape = tf.shape(x) 61 | noise_shape = 
[input_shape[0], input_shape[1], 1] 62 | x = tf.layers.Dropout(rate=dropout, noise_shape=noise_shape, seed=seed)(x, training=training) 63 | return x 64 | 65 | 66 | #### Step 2 67 | def fasttext(x): 68 | return x 69 | 70 | 71 | def textcnn(x, num_filters=8, filter_sizes=[2, 3], bn=False, training=False, 72 | timedistributed=False, scope_name="textcnn", reuse=False): 73 | # x: None * step_dim * embed_dim 74 | conv_blocks = [] 75 | for i, filter_size in enumerate(filter_sizes): 76 | scope_name_i = "%s_textcnn_%s"%(str(scope_name), str(filter_size)) 77 | with tf.variable_scope(scope_name_i, reuse=reuse): 78 | if timedistributed: 79 | input_shape = tf.shape(x) 80 | step_dim = input_shape[1] 81 | embed_dim = input_shape[2] 82 | x = tf.transpose(x, [0, 2, 1]) 83 | # None * embed_dim * step_dim 84 | x = tf.reshape(x, [input_shape[0] * embed_dim, step_dim, 1]) 85 | conv = tf.layers.conv1d( 86 | input=x, 87 | filters=1, 88 | kernel_size=filter_size, 89 | padding="same", 90 | activation=None, 91 | strides=1, 92 | reuse=reuse, 93 | name=scope_name_i) 94 | conv = tf.reshape(conv, [input_shape[0], embed_dim, step_dim]) 95 | conv = tf.transpose(conv, [0, 2, 1]) 96 | else: 97 | conv = tf.layers.conv1d( 98 | inputs=x, 99 | filters=num_filters, 100 | kernel_size=filter_size, 101 | padding="same", 102 | activation=None, 103 | strides=1, 104 | reuse=reuse, 105 | name=scope_name_i) 106 | if bn: 107 | conv = tf.layers.BatchNormalization()(conv, training) 108 | conv = tf.nn.relu(conv) 109 | conv_blocks.append(conv) 110 | if len(conv_blocks) > 1: 111 | z = tf.concat(conv_blocks, axis=-1) 112 | else: 113 | z = conv_blocks[0] 114 | return z 115 | 116 | 117 | def textrnn(x, num_units, cell_type, sequence_length, num_layers=1, mask_zero=False, scope_name="textrnn", reuse=False): 118 | for i in range(num_layers): 119 | scope_name_i = "%s_textrnn_%s_%s_%s" % (str(scope_name), cell_type, str(i), str(num_units)) 120 | with tf.variable_scope(scope_name_i, reuse=reuse): 121 | if cell_type == "gru": 122 | cell_fw = tf.nn.rnn_cell.GRUCell(num_units) 123 | elif cell_type == "lstm": 124 | cell_fw = tf.nn.rnn_cell.LSTMCell(num_units) 125 | if mask_zero: 126 | x, _ = tf.nn.dynamic_rnn(cell_fw, x, dtype=tf.float32, sequence_length=sequence_length, scope=scope_name_i) 127 | else: 128 | x, _ = tf.nn.dynamic_rnn(cell_fw, x, dtype=tf.float32, sequence_length=None, scope=scope_name_i) 129 | return x 130 | 131 | 132 | def textbirnn(x, num_units, cell_type, sequence_length, num_layers=1, mask_zero=False, scope_name="textbirnn", reuse=False): 133 | for i in range(num_layers): 134 | scope_name_i = "%s_textbirnn_%s_%s_%s" % (str(scope_name), cell_type, str(i), str(num_units)) 135 | with tf.variable_scope(scope_name_i, reuse=reuse): 136 | if cell_type == "gru": 137 | cell_fw = tf.nn.rnn_cell.GRUCell(num_units) 138 | cell_bw = tf.nn.rnn_cell.GRUCell(num_units) 139 | elif cell_type == "lstm": 140 | cell_fw = tf.nn.rnn_cell.LSTMCell(num_units) 141 | cell_bw = tf.nn.rnn_cell.LSTMCell(num_units) 142 | if mask_zero: 143 | (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn( 144 | cell_fw, cell_bw, x, dtype=tf.float32, sequence_length=sequence_length, scope=scope_name_i) 145 | else: 146 | (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn( 147 | cell_fw, cell_bw, x, dtype=tf.float32, sequence_length=None, scope=scope_name_i) 148 | x = tf.concat([output_fw, output_bw], axis=-1) 149 | return x 150 | 151 | 152 | 153 | def encode(x, method, params, sequence_length=None, mask_zero=False, scope_name="encode", reuse=False): 154 | """ 155 
| :param x: shape=(None,seqlen,dim) 156 | :param params: 157 | :return: shape=(None,seqlen,dim) 158 | """ 159 | dim_f = params["embedding_dim"] 160 | dim_c = len(params["cnn_filter_sizes"]) * params["cnn_num_filters"] 161 | dim_r = params["rnn_num_units"] 162 | dim_b = params["rnn_num_units"] * 2 163 | out_list = [] 164 | params["encode_dim"] = 0 165 | for m in method.split("+"): 166 | if m == "fasttext": 167 | z = fasttext(x) 168 | out_list.append(z) 169 | params["encode_dim"] += dim_f 170 | elif m == "textcnn": 171 | z = textcnn(x, num_filters=params["cnn_num_filters"], filter_sizes=params["cnn_filter_sizes"], 172 | timedistributed=params["cnn_timedistributed"], scope_name=scope_name, reuse=reuse) 173 | out_list.append(z) 174 | params["encode_dim"] += dim_c 175 | elif m == "textrnn": 176 | z = textrnn(x, num_units=params["rnn_num_units"], cell_type=params["rnn_cell_type"], 177 | sequence_length=sequence_length, mask_zero=mask_zero, scope_name=scope_name, reuse=reuse) 178 | out_list.append(z) 179 | params["encode_dim"] += dim_r 180 | elif m == "textbirnn": 181 | z = textbirnn(x, num_units=params["rnn_num_units"], cell_type=params["rnn_cell_type"], 182 | sequence_length=sequence_length, mask_zero=mask_zero, scope_name=scope_name, reuse=reuse) 183 | out_list.append(z) 184 | params["encode_dim"] += dim_b 185 | z = tf.concat(out_list, axis=-1) 186 | return z 187 | 188 | 189 | def attention(x, feature_dim, sequence_length=None, mask_zero=False, maxlen=None, epsilon=1e-8, seed=0, 190 | scope_name="attention", reuse=False): 191 | input_shape = tf.shape(x) 192 | step_dim = input_shape[1] 193 | # feature_dim = input_shape[2] 194 | x = tf.reshape(x, [-1, feature_dim]) 195 | """ 196 | The last dimension of the inputs to `Dense` should be defined. Found `None`. 197 | 198 | can't use `tf.layers.Dense` here 199 | eij = tf.layers.Dense(1)(x) 200 | 201 | see: https://github.com/tensorflow/tensorflow/issues/13348 202 | workaround: specify the feature_dim as input 203 | """ 204 | with tf.variable_scope(scope_name, reuse=reuse): 205 | eij = tf.layers.dense(x, 1, activation=tf.nn.tanh, 206 | kernel_initializer=tf.glorot_uniform_initializer(seed=seed), 207 | reuse=reuse, 208 | name=scope_name) 209 | eij = tf.reshape(eij, [-1, step_dim]) 210 | a = tf.exp(eij) 211 | 212 | # apply mask after the exp.
will be re-normalized next 213 | if mask_zero: 214 | # None * step_dim 215 | mask = tf.sequence_mask(sequence_length, maxlen) 216 | mask = tf.cast(mask, tf.float32) 217 | a = a * mask 218 | 219 | # in some cases especially in the early stages of training the sum may be almost zero 220 | a /= tf.cast(tf.reduce_sum(a, axis=1, keep_dims=True) + epsilon, tf.float32) 221 | 222 | a = tf.expand_dims(a, axis=-1) 223 | return a 224 | 225 | 226 | def attend(x, sequence_length=None, method="ave", context=None, feature_dim=None, mask_zero=False, maxlen=None, 227 | bn=False, training=False, seed=0, scope_name="attention", reuse=False): 228 | if method == "ave": 229 | if mask_zero: 230 | # None * step_dim 231 | mask = tf.sequence_mask(sequence_length, maxlen) 232 | mask = tf.reshape(mask, (-1, tf.shape(x)[1], 1)) 233 | mask = tf.cast(mask, tf.float32) 234 | z = tf.reduce_sum(x * mask, axis=1) 235 | l = tf.reduce_sum(mask, axis=1) 236 | # in some cases especially in the early stages of training the sum may be almost zero 237 | epsilon = 1e-8 238 | z /= tf.cast(l + epsilon, tf.float32) 239 | else: 240 | z = tf.reduce_mean(x, axis=1) 241 | elif method == "sum": 242 | if mask_zero: 243 | # None * step_dim 244 | mask = tf.sequence_mask(sequence_length, maxlen) 245 | mask = tf.reshape(mask, (-1, tf.shape(x)[1], 1)) 246 | mask = tf.cast(mask, tf.float32) 247 | z = tf.reduce_sum(x * mask, axis=1) 248 | else: 249 | z = tf.reduce_sum(x, axis=1) 250 | elif method == "max": 251 | if mask_zero: 252 | # None * step_dim 253 | mask = tf.sequence_mask(sequence_length, maxlen) 254 | mask = tf.expand_dims(mask, axis=-1) 255 | mask = tf.tile(mask, (1, 1, tf.shape(x)[2])) 256 | masked_data = tf.where(tf.equal(mask, tf.zeros_like(mask)), 257 | tf.ones_like(x) * -np.inf, x) # if masked assume value is -inf 258 | z = tf.reduce_max(masked_data, axis=1) 259 | else: 260 | z = tf.reduce_max(x, axis=1) 261 | elif method == "attention": 262 | if context is not None: 263 | step_dim = tf.shape(x)[1] 264 | context = tf.expand_dims(context, axis=1) 265 | context = tf.tile(context, [1, step_dim, 1]) 266 | y = tf.concat([x, context], axis=-1) 267 | else: 268 | y = x 269 | a = attention(y, feature_dim, sequence_length, mask_zero, maxlen, seed=seed, scope_name=scope_name, reuse=reuse) 270 | z = tf.reduce_sum(x * a, axis=1) 271 | if bn: 272 | z = tf.layers.BatchNormalization()(z, training=training) 273 | return z 274 | 275 | 276 | #### Step 4 277 | def _dense_block_mode1(x, hidden_units, dropouts, densenet=False, scope_name="dense_block", reuse=False, training=False, seed=0, bn=False): 278 | """ 279 | :param x: 280 | :param hidden_units: 281 | :param dropouts: 282 | :param densenet: enable densenet 283 | :return: 284 | Ref: https://github.com/titu1994/DenseNet 285 | """ 286 | for i, (h, d) in enumerate(zip(hidden_units, dropouts)): 287 | scope_name_i = "%s-dense_block_mode1-%s"%(str(scope_name), str(i)) 288 | with tf.variable_scope(scope_name, reuse=reuse): 289 | z = tf.layers.dense(x, h, kernel_initializer=tf.glorot_uniform_initializer(seed=seed * i), 290 | reuse=reuse, 291 | name=scope_name_i) 292 | if bn: 293 | z = batch_normalization(z, training=training, name=scope_name_i+"-bn") 294 | z = tf.nn.relu(z) 295 | z = tf.layers.Dropout(d, seed=seed * i)(z, training=training) if d > 0 else z 296 | if densenet: 297 | x = tf.concat([x, z], axis=-1) 298 | else: 299 | x = z 300 | return x 301 | 302 | 303 | def _dense_block_mode2(x, hidden_units, dropouts, densenet=False, training=False, seed=0, bn=False, name="dense_block"): 304 | """ 305 | :param 
x: 306 | :param hidden_units: 307 | :param dropouts: 308 | :param densenet: enable densenet 309 | :return: 310 | Ref: https://github.com/titu1994/DenseNet 311 | """ 312 | for i, (h, d) in enumerate(zip(hidden_units, dropouts)): 313 | if bn: 314 | z = batch_normalization(x, training=training, name=name + "-" + str(i)) 315 | z = tf.nn.relu(z) 316 | z = tf.layers.Dropout(d, seed=seed * i)(z, training=training) if d > 0 else z 317 | z = tf.layers.Dense(h, kernel_initializer=tf.glorot_uniform_initializer(seed=seed * i), dtype=tf.float32, 318 | bias_initializer=tf.zeros_initializer())(z) 319 | if densenet: 320 | x = tf.concat([x, z], axis=-1) 321 | else: 322 | x = z 323 | return x 324 | 325 | 326 | def dense_block(x, hidden_units, dropouts, densenet=False, scope_name="dense_block", reuse=False, training=False, seed=0, bn=False): 327 | return _dense_block_mode1(x, hidden_units, dropouts, densenet, scope_name, reuse, training, seed, bn) 328 | 329 | 330 | def _resnet_branch_mode1(x, hidden_units, dropouts, training, seed=0): 331 | h1, h2, h3 = hidden_units 332 | dr1, dr2, dr3 = dropouts 333 | name = "resnet_block" 334 | # branch 2 335 | x2 = tf.layers.Dense(h1, kernel_initializer=tf.glorot_uniform_initializer(seed=seed * 2), dtype=tf.float32, 336 | bias_initializer=tf.zeros_initializer())(x) 337 | x2 = tf.layers.BatchNormalization()(x2, training=training) 338 | # x2 = batch_normalization(x2, training=training, name=name + "-" + str(1)) 339 | x2 = tf.nn.relu(x2) 340 | x2 = tf.layers.Dropout(dr1, seed=seed * 1)(x2, training=training) if dr1 > 0 else x2 341 | 342 | x2 = tf.layers.Dense(h2, kernel_initializer=tf.glorot_uniform_initializer(seed=seed * 3), dtype=tf.float32, 343 | bias_initializer=tf.zeros_initializer())(x2) 344 | x2 = tf.layers.BatchNormalization()(x2, training=training) 345 | # x2 = batch_normalization(x2, training=training, name=name + "-" + str(2)) 346 | x2 = tf.nn.relu(x2) 347 | x2 = tf.layers.Dropout(dr2, seed=seed * 2)(x2, training=training) if dr2 > 0 else x2 348 | 349 | x2 = tf.layers.Dense(h3, kernel_initializer=tf.glorot_uniform_initializer(seed=seed * 4), dtype=tf.float32, 350 | bias_initializer=tf.zeros_initializer())(x2) 351 | x2 = tf.layers.BatchNormalization()(x2, training=training) 352 | # x2 = batch_normalization(x2, training=training, name=name + "-" + str(3)) 353 | 354 | return x2 355 | 356 | 357 | def _resnet_block_mode1(x, hidden_units, dropouts, cardinality=1, dense_shortcut=False, training=False, seed=0): 358 | """A block that has a dense layer at shortcut. 359 | # Arguments 360 | input_tensor: input tensor 361 | kernel_size: default 3, the kernel size of middle conv layer at main path 362 | filters: list of integers, the filters of 3 conv layer at main path 363 | stage: integer, current stage label, used for generating layer names 364 | block: 'a','b'..., current block label, used for generating layer names 365 | # Returns 366 | Output tensor for the block. 
367 | Note that from stage 3, the first conv layer at main path is with strides=(2,2) 368 | And the shortcut should have strides=(2,2) as well 369 | """ 370 | h1, h2, h3 = hidden_units 371 | dr1, dr2, dr3 = dropouts 372 | name = "resnet_block" 373 | xs = [] 374 | # branch 0 375 | if dense_shortcut: 376 | x0 = tf.layers.Dense(h3, kernel_initializer=tf.glorot_uniform_initializer(seed=seed * 1), dtype=tf.float32, 377 | bias_initializer=tf.zeros_initializer())(x) 378 | x0 = tf.layers.BatchNormalization()(x0, training=training) 379 | # x0 = batch_normalization(x0, training=training, name=name + "-" + str(0)) 380 | xs.append(x0) 381 | else: 382 | xs.append(x) 383 | 384 | # branch 1 ~ cardinality 385 | for i in range(cardinality): 386 | xs.append(_resnet_branch_mode1(x, hidden_units, dropouts, training, seed)) 387 | 388 | x = tf.add_n(xs) 389 | x = tf.nn.relu(x) 390 | x = tf.layers.Dropout(dr3, seed=seed * 4)(x, training=training) if dr3 > 0 else x 391 | return x 392 | 393 | 394 | def _resnet_branch_mode2(x, hidden_units, dropouts, training=False, seed=0, scope_name="_resnet_branch_mode2", reuse=False): 395 | h1, h2, h3 = hidden_units 396 | dr1, dr2, dr3 = dropouts 397 | # name = "resnet" 398 | with tf.variable_scope(scope_name, reuse=reuse): 399 | # branch 2: bn-relu->weight 400 | x2 = tf.layers.BatchNormalization()(x) 401 | # x2 = batch_normalization(x, training=training, name=scope_name + "-bn-" + str(1)) 402 | x2 = tf.nn.relu(x2) 403 | x2 = tf.layers.Dropout(dr1)(x2, training=training) if dr1 > 0 else x2 404 | x2 = tf.layers.dense(x2, h1, kernel_initializer=tf.glorot_uniform_initializer(seed * 1), 405 | bias_initializer=tf.zeros_initializer(), 406 | name=scope_name+"-dense-"+str(1), 407 | reuse=reuse) 408 | 409 | x2 = tf.layers.BatchNormalization()(x2) 410 | # x2 = batch_normalization(x2, training=training, name=scope_name + "-bn-" + str(2)) 411 | x2 = tf.nn.relu(x2) 412 | x2 = tf.layers.Dropout(dr2)(x2, training=training) if dr2 > 0 else x2 413 | x2 = tf.layers.dense(x2, h2, kernel_initializer=tf.glorot_uniform_initializer(seed * 2), 414 | bias_initializer=tf.zeros_initializer(), 415 | name=scope_name + "-dense-" + str(2), 416 | reuse=reuse) 417 | 418 | x2 = tf.layers.BatchNormalization()(x2) 419 | # x2 = batch_normalization(x2, training=training, name=scope_name + "-bn-" + str(3)) 420 | x2 = tf.nn.relu(x2) 421 | x2 = tf.layers.Dropout(dr3)(x2, training=training) if dr3 > 0 else x2 422 | x2 = tf.layers.dense(x2, h3, kernel_initializer=tf.glorot_uniform_initializer(seed * 3), 423 | bias_initializer=tf.zeros_initializer(), 424 | name=scope_name + "-dense-" + str(3), 425 | reuse=reuse) 426 | 427 | return x2 428 | 429 | 430 | def _resnet_block_mode2(x, hidden_units, dropouts, cardinality=1, dense_shortcut=False, training=False, seed=0, 431 | scope_name="_resnet_block_mode2", reuse=False): 432 | """A block that has a dense layer at shortcut. 433 | # Arguments 434 | input_tensor: input tensor 435 | kernel_size: default 3, the kernel size of middle conv layer at main path 436 | filters: list of integers, the filters of 3 conv layer at main path 437 | stage: integer, current stage label, used for generating layer names 438 | block: 'a','b'..., current block label, used for generating layer names 439 | # Returns 440 | Output tensor for the block. 
441 | Note that from stage 3, the first conv layer at main path is with strides=(2,2) 442 | And the shortcut should have strides=(2,2) as well 443 | """ 444 | h1, h2, h3 = hidden_units 445 | dr1, dr2, dr3 = dropouts 446 | 447 | xs = [] 448 | # branch 0 449 | if dense_shortcut: 450 | with tf.variable_scope(scope_name, reuse=reuse): 451 | x0 = tf.layers.dense(x, h3, kernel_initializer=tf.glorot_uniform_initializer(seed * 1), 452 | bias_initializer=tf.zeros_initializer(), 453 | reuse=reuse, 454 | name=scope_name+"-dense-"+str("0")) 455 | xs.append(x0) 456 | else: 457 | xs.append(x) 458 | 459 | # branch 1 ~ cardinality 460 | for i in range(cardinality): 461 | xs.append(_resnet_branch_mode2(x, hidden_units, dropouts, training, seed, scope_name, reuse)) 462 | 463 | x = tf.add_n(xs) 464 | return x 465 | 466 | 467 | def resnet_block(input_tensor, hidden_units, dropouts, cardinality=1, dense_shortcut=False, training=False, seed=0, 468 | scope_name="resnet_block", reuse=False): 469 | return _resnet_block_mode2(input_tensor, hidden_units, dropouts, cardinality, dense_shortcut, training, seed, 470 | scope_name, reuse) 471 | 472 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import datetime 4 | import logging 5 | import logging.handlers 6 | import shutil 7 | import numpy as np 8 | 9 | 10 | def _timestamp(): 11 | now = datetime.datetime.now() 12 | now_str = now.strftime("%Y%m%d%H%M") 13 | return now_str 14 | 15 | 16 | def _get_logger(logdir, logname, loglevel=logging.INFO): 17 | fmt = "[%(asctime)s] %(levelname)s: %(message)s" 18 | formatter = logging.Formatter(fmt) 19 | 20 | handler = logging.handlers.RotatingFileHandler( 21 | filename=os.path.join(logdir, logname), 22 | maxBytes=2 * 1024 * 1024 * 1024, 23 | backupCount=10) 24 | handler.setFormatter(formatter) 25 | 26 | logger = logging.getLogger("") 27 | logger.addHandler(handler) 28 | logger.setLevel(loglevel) 29 | return logger 30 | 31 | 32 | def _makedirs(dir, force=False): 33 | if os.path.exists(dir): 34 | if force: 35 | shutil.rmtree(dir) 36 | os.makedirs(dir) 37 | else: 38 | os.makedirs(dir) 39 | 40 | 41 | def _get_intersect_index(all, subset): 42 | lst = [] 43 | for xi in subset: 44 | i = np.where(all == xi)[0] 45 | lst.append(i) 46 | return np.hstack(lst).tolist() 47 | --------------------------------------------------------------------------------
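For reference, a minimal end-to-end sketch of how the pieces above fit together: convert the MQ2008 Fold1 text files to the `.npy` arrays under `data/processed/`, then fit a pairwise model, equivalent to running `python prepare_data.py` followed by `python main.py ranknet` from inside `src/`. The helper `load_split` is a hypothetical stand-in that mirrors `load_data()` in `main.py`, and the hyper-parameter values shown are illustrative rather than tuned.

```python
# Minimal usage sketch (assumes the working directory is src/, as main.py does).
import numpy as np

import utils
from prepare_data import convert, label_file_pat, group_file_pat, feature_file_pat
from model import RankNet

# 1) Parse the LETOR-format text files into .npy arrays under ../data/processed/.
for split in ["train", "vali", "test"]:
    convert(split)

# 2) Load a split back into the dict layout expected by BaseRankModel.fit/evaluate.
#    (Hypothetical helper; main.load_data() does the same thing.)
def load_split(split):
    return {
        "feature": np.load(feature_file_pat % split),
        "label": np.load(label_file_pat % split),
        "qid": np.load(group_file_pat % split),
    }

X_train, X_valid = load_split("train"), load_split("vali")

# 3) Train a pairwise model; values below mirror params_common / train_ranknet() in main.py.
params = {
    "offline_model_dir": "../weights/ranknet",
    "fc_type": "fc", "fc_dim": 32, "fc_dropout": 0.,
    "factorization": True, "sigma": 1.,
    "batch_size": 128, "epoch": 50, "feature_dim": 46,
    "batch_sampling_method": "sample", "shuffle": True,
    "optimizer_type": "adam", "init_lr": 0.001, "beta1": 0.975, "beta2": 0.999,
    "decay_steps": 1000, "decay_rate": 0.9, "schedule_decay": 0.004,
    "random_seed": 2018, "eval_every_num_update": 100,
}

utils._makedirs("../logs")
logger = utils._get_logger("../logs", "tf-%s.log" % utils._timestamp())

# The scope name "ranking" matches the node names used by freeze_graph.sh
# (ranking/feature, ranking/training, ranking/score).
model = RankNet("ranking", params, logger)
model.fit(X_train, validation_data=X_valid)
model.save_session()  # writes graph.pb and model.checkpoint under ../weights/ranknet
```

After `save_session()`, `src/freeze_graph.sh ranknet` can be used to freeze the saved graph and checkpoint into `freeze_graph.pb` for serving.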