├── .spyproject
│   ├── codestyle.ini
│   ├── encoding.ini
│   ├── vcs.ini
│   └── workspace.ini
├── README.md
├── README_EN.md
├── requirements.txt
└── script
    ├── bi_encoder
    │   ├── bi-encoder-batch.py
    │   └── bi-encoder-data.py
    └── cross_encoder
        ├── cross_encoder_random_on_multi_eval.py
        ├── try_sbert_neg_sampler.py
        ├── try_sbert_neg_sampler_valid.py
        └── valid_cross_encoder_on_bi_encoder.py
/.spyproject/codestyle.ini: -------------------------------------------------------------------------------- 1 | [codestyle] 2 | indentation = True 3 | 4 | [main] 5 | version = 0.1.0 6 | 7 | -------------------------------------------------------------------------------- /.spyproject/encoding.ini: -------------------------------------------------------------------------------- 1 | [encoding] 2 | text_encoding = utf-8 3 | 4 | [main] 5 | version = 0.1.0 6 | 7 | -------------------------------------------------------------------------------- /.spyproject/vcs.ini: -------------------------------------------------------------------------------- 1 | [vcs] 2 | use_version_control = False 3 | version_control_system = 4 | 5 | [main] 6 | version = 0.1.0 7 | 8 | -------------------------------------------------------------------------------- /.spyproject/workspace.ini: -------------------------------------------------------------------------------- 1 | [workspace] 2 | restore_data_on_startup = True 3 | save_data_on_exit = True 4 | save_history = True 5 | save_non_project_files = False 6 | 7 | [main] 8 | version = 0.1.0 9 | recent_files = ['/home/svjack/.config/spyder-py3/temp.py', '/home/svjack/temp_dir/bi_cross_model/script/bi-encoder-batch.py', '/home/svjack/temp_dir/bi_cross_model/script/bi-encoder-data.py', '/home/svjack/temp_dir/bi_cross_model/script/bi_encoder/bi-encoder-batch.py', '/home/svjack/temp_dir/bi_cross_model/script/bi_encoder/bi-encoder-data.py', '/home/svjack/temp_dir/bi_cross_model/download.sh', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_data_prepare_train.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_data_prepare.py', '/home/svjack/temp_dir/bi_cross_model/script/try_sbert_neg_sampler.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_enccoder_v0.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/try_sbert_neg_sampler.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/try_sbert_neg_sampler_valid.py', '/home/svjack/temp_dir/bi_cross_model/script/choose_right_params.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_random_on_multi_eval.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/cross_encoder_random_on_multi_eval.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/valid_cross_encoder_on_bi_encoder.py'] 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 8 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 27 | 28 | 36 | 37 | 38 |
39 |

40 | 41 |

Sbert-ChineseExample

42 | 43 |

44 | Sentence-Transformers Chinese information-retrieval example 45 |
46 |
47 |
48 |

49 |

50 | 51 | [In English](README_EN.md) 52 | 53 | ## Table of Contents 54 | * [About the Project](#about-the-project) 55 | * [Built With](#built-with) 56 | * [Getting Started](#getting-started) 57 | * [Installation](#installation) 58 | * [Usage](#usage) 59 | * [Roadmap](#roadmap) 60 | * [Contributing](#contributing) 61 | * [License](#license) 62 | * [Contact](#contact) 63 | * [Acknowledgements](#acknowledgements) 64 | 65 | 66 | ## About The Project 68 | 69 | 72 | 73 | 74 | 77 | 102 | Sentence Transformers is a multilingual, multimodal sentence-embedding framework that makes it easy to compute dense vector representations for sentences and text paragraphs on top of the Huggingface Transformers library.
103 | 104 | The goal of this project is to perform MS MARCO-style information retrieval on a Chinese dataset by training a bi_encoder and a cross_encoder, paired with a customized pandas-style Elasticsearch interface so that the outputs (text and vectors) can be serialized conveniently.
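A minimal sketch of what the trained bi_encoder provides (the model path and the example sentences here are illustrative assumptions, not fixed by this project):

```python
from sentence_transformers import SentenceTransformer, util

# assumption: a bi-encoder trained by script/bi_encoder/bi-encoder-batch.py
# was saved under "bi_encoder_save"
model = SentenceTransformer("bi_encoder_save")

query = "怎样学习python"
answers = ["可以从官方教程开始入门", "今天天气不错"]

# dense vector representations for the query and the candidate answers
q_emb = model.encode(query, convert_to_tensor=True)
a_emb = model.encode(answers, convert_to_tensor=True)

# cosine similarity; the relevant answer should score higher
print(util.pytorch_cos_sim(q_emb, a_emb))
```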
105 | 106 | 116 | ### Built With 118 | 131 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) 132 | * [Elasticsearch](https://github.com/elastic/elasticsearch) 133 | * [Faiss](https://github.com/facebookresearch/faiss) 134 | 135 | 136 | ## Getting Started 138 | 142 | 143 | 152 | ### Installation 154 | * pip 155 | ```sh 156 | pip install -r requirements.txt 157 | ``` 158 | * Install Elasticsearch and start the service 160 | 161 | 162 | 163 | ## Usage 165 | 170 | 171 |

172 | 173 | 1. Download the dataset from Google Drive 174 |

175 | 176 |

177 | 2. Prepare the bi_encoder data 178 |

179 | 180 |

181 | 3. Train the bi_encoder 182 |

183 | 184 |

185 | 4. Prepare the cross_encoder training data 186 |

187 | 188 |

189 | 5. Prepare the cross_encoder validation data 190 |

191 | 192 |

193 | 6. Train the cross_encoder 194 |

195 | 196 |

197 | 7. Demonstrate bi_encoder and cross_encoder inference (a retrieval sketch follows this list) 198 |

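The retrieval half of step 7 as a minimal sketch (the model path, data path, and query are illustrative assumptions; the Faiss usage mirrors script/cross_encoder/try_sbert_neg_sampler.py):

```python
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

# assumptions: a trained bi-encoder under "bi_encoder_save" and the prepared CSV
model = SentenceTransformer("bi_encoder_save")
answers = pd.read_csv("data/train_part.csv")["answer"].dropna().tolist()

# index every candidate answer by its dense embedding
emb = np.asarray(model.encode(answers))
index = faiss.IndexFlatL2(emb.shape[1])
index.add(emb)

# retrieve the 5 nearest answers for a query
_, idxs = index.search(np.asarray(model.encode(["怎样学习python"])), 5)
print([answers[i] for i in idxs[0]])
```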
199 | 200 | 201 | ## 引导 202 | ## Roadmap 203 | 206 | 212 |
213 | * 1 This project uses a customized overload of the es-pandas interface (with support for storing vectors) so that Elasticsearch can be manipulated simply from pandas. 214 |
215 | * 2 The hard-sample extraction in try_sbert_neg_sampler.py (samples the model finds hard to distinguish) comes from https://guzpenha.github.io/transformer_rankers/; Elasticsearch can also be used to generate hard samples, and the corresponding functions are defined in valid_cross_encoder_on_bi_encoder.py. 218 |
219 | * 3 The cross_encoder training above requires checking the degree of semantic difference between sentences in advance; combining samples with similar semantics is helpful for model training. 221 |
222 | * 4 Added some utilities for comparing multi-class results from Sentence-Transformers. 223 | 224 | 225 | 226 | ## Contributing 228 | 237 | 238 | 239 | 240 | ## License 241 | 242 | Distributed under the MIT License. See `LICENSE` for more information. 243 | 244 | 245 | 246 | 247 | ## Contact 248 | 249 | 252 | svjack - svjackbt@gmail.com 253 | ehangzhou@outlook.com 254 | 255 | 258 | Project Link: [https://github.com/svjack/Sbert-ChineseExample](https://github.com/svjack/Sbert-ChineseExample) 259 | 260 | 261 | 262 | ## Acknowledgements 263 | 276 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) 277 | * [Elasticsearch](https://github.com/elastic/elasticsearch) 278 | * [Faiss](https://github.com/facebookresearch/faiss) 279 | * [Transformer Rankers](https://github.com/Guzpenha/transformer_rankers) 280 | * [es_pandas](https://github.com/fuyb1992/es_pandas) 281 | 282 | 283 | 284 | 285 | [contributors-shield]: https://img.shields.io/github/contributors/othneildrew/Best-README-Template.svg?style=flat-square 286 | [contributors-url]: https://github.com/othneildrew/Best-README-Template/graphs/contributors 287 | [forks-shield]: https://img.shields.io/github/forks/othneildrew/Best-README-Template.svg?style=flat-square 288 | [forks-url]: https://github.com/othneildrew/Best-README-Template/network/members 289 | [stars-shield]: https://img.shields.io/github/stars/othneildrew/Best-README-Template.svg?style=flat-square 290 | [stars-url]: https://github.com/othneildrew/Best-README-Template/stargazers 291 | [issues-shield]: https://img.shields.io/github/issues/othneildrew/Best-README-Template.svg?style=flat-square 292 | [issues-url]: https://github.com/othneildrew/Best-README-Template/issues 293 | [license-shield]: https://img.shields.io/github/license/othneildrew/Best-README-Template.svg?style=flat-square 294 | [license-url]: https://github.com/othneildrew/Best-README-Template/blob/master/LICENSE.txt 295 | [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=flat-square&logo=linkedin&colorB=555 296 | [linkedin-url]: https://linkedin.com/in/othneildrew 297 | [product-screenshot]: images/screenshot.png 298 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | 8 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 27 | 28 | 36 | 37 | 38 | 39 |
40 |

41 | 46 | 47 | 51 |

Sbert-ChineseExample

52 | 53 |

54 | 55 | 56 | 57 | Sentence-Transformers information-retrieval example for Chinese
59 | 62 |
63 |
64 | 68 | 74 |

75 |

76 | 77 | [Chinese README](README.md) 78 | 79 | 80 | ## Table of Contents 81 | 82 | * [About the Project](#about-the-project) 83 | * [Built With](#built-with) 84 | * [Getting Started](#getting-started) 85 | 88 | * [Installation](#installation) 89 | * [Usage](#usage) 90 | * [Roadmap](#roadmap) 91 | * [Contributing](#contributing) 92 | * [License](#license) 93 | * [Contact](#contact) 94 | * [Acknowledgements](#acknowledgements) 95 | 96 | 97 | 98 | 99 | ## About The Project 100 | 101 | 104 | 105 | 106 | 109 | 134 | Sentence Transformers is a multilingual sentence-embedding framework that provides an easy way to compute dense 135 | vector representations for sentences and paragraphs (based on HuggingFace Transformers). 136 | 137 | This repository targets an MS MARCO-style task on a Chinese dataset: it trains a bi_encoder and a cross_encoder, and uses an easy 138 | Elasticsearch-on-pandas interface to make the resulting output (text and vectors) serializable. 139 | 140 | 150 | ### Built With 151 | 164 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) 165 | * [Elasticsearch](https://github.com/elastic/elasticsearch) 166 | * [Faiss](https://github.com/facebookresearch/faiss) 167 | 168 | 169 | ## Getting Started 170 | 174 | 175 | 184 | 185 | ### Installation 186 | * pip 187 | ```sh 188 | pip install -r requirements.txt 189 | ``` 190 | * Install Elasticsearch and start the service 191 | 206 | 207 | 208 | 209 | ## Usage 210 | 215 | 216 |

217 | 218 | 1. Download the data from Google Drive 219 |

220 | 221 |

222 | 2. Prepare the bi_encoder data 223 |

224 | 225 |

226 | 3. Train the bi_encoder 227 |

228 | 229 |

230 | 4. Prepare the cross_encoder training data 231 |

232 | 233 |

234 | 5. Prepare the cross_encoder validation data 235 |

236 | 237 |

238 | 6. Train the cross_encoder 239 |

240 | 241 |

242 | 7. Demonstrate bi_encoder and cross_encoder inference (a re-ranking sketch follows this list) 243 |

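A minimal re-ranking sketch for step 7 (the model path and example texts are illustrative assumptions; in the full pipeline the candidates come from the bi_encoder retrieval stage):

```python
from sentence_transformers.cross_encoder import CrossEncoder

# assumption: a cross-encoder trained by cross_encoder_random_on_multi_eval.py
model = CrossEncoder("output/cross_encoder_save", max_length=512)

query = "怎样学习python"
candidates = ["可以从官方教程开始入门", "今天天气不错", "多写小项目练手"]

# score each (query, candidate) pair, then sort candidates by descending score
scores = model.predict([[query, c] for c in candidates])
reranked = [c for _, c in sorted(zip(scores, candidates), reverse=True)]
print(reranked)
```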
244 | 245 | 246 | ## Roadmap 247 | 250 | 256 |
257 | * 1 This repository uses an edited es-pandas interface (supporting serialized vectors) to provide simple pandas-based manipulation of Elasticsearch. 258 |
259 | * 2 try_sbert_neg_sampler.py samples hard negatives, derived from a class provided by https://guzpenha.github.io/transformer_rankers/; 261 | Elasticsearch can also be used to generate hard samples, and the related functions are defined in valid_cross_encoder_on_bi_encoder.py. 262 |
263 | * 3 Before training the cross_encoder on your dataset, take a look at the semantic similarity between different questions; 264 | combining some samples with similar semantics may help. 265 |
266 | * 4 Added some utilities to Sbert to support multi-class evaluation (returned as a dictionary) 267 | 268 | ## Contributing 269 | 278 | 279 | 280 | 281 | ## License 282 | 283 | Distributed under the MIT License. See `LICENSE` for more information. 284 | 285 | 286 | 287 | 288 | ## Contact 289 | 290 | 293 | svjack - svjackbt@gmail.com 294 | 295 | 298 | Project Link: [https://github.com/svjack/Sbert-ChineseExample](https://github.com/svjack/Sbert-ChineseExample) 299 | 300 | 301 | 302 | ## Acknowledgements 303 | 316 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) 317 | * [Elasticsearch](https://github.com/elastic/elasticsearch) 318 | * [Faiss](https://github.com/facebookresearch/faiss) 319 | * [Transformer Rankers](https://github.com/Guzpenha/transformer_rankers) 320 | * [es_pandas](https://github.com/fuyb1992/es_pandas) 321 | 322 | 323 | 324 | 325 | [contributors-shield]: https://img.shields.io/github/contributors/othneildrew/Best-README-Template.svg?style=flat-square 326 | [contributors-url]: https://github.com/othneildrew/Best-README-Template/graphs/contributors 327 | [forks-shield]: https://img.shields.io/github/forks/othneildrew/Best-README-Template.svg?style=flat-square 328 | [forks-url]: https://github.com/othneildrew/Best-README-Template/network/members 329 | [stars-shield]: https://img.shields.io/github/stars/othneildrew/Best-README-Template.svg?style=flat-square 330 | [stars-url]: https://github.com/othneildrew/Best-README-Template/stargazers 331 | [issues-shield]: https://img.shields.io/github/issues/othneildrew/Best-README-Template.svg?style=flat-square 332 | [issues-url]: https://github.com/othneildrew/Best-README-Template/issues 333 | [license-shield]: https://img.shields.io/github/license/othneildrew/Best-README-Template.svg?style=flat-square 334 | [license-url]: https://github.com/othneildrew/Best-README-Template/blob/master/LICENSE.txt 335 | [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=flat-square&logo=linkedin&colorB=555 336 | [linkedin-url]: https://linkedin.com/in/othneildrew 337 | [product-screenshot]: images/screenshot.png 338 | 339 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | editdistance==0.5.3 2 | elasticsearch==7.8.1 3 | elasticsearch-dbapi==0.1.3 4 | es-pandas==0.0.16 5 | progressbar2==3.53.1 6 | seaborn==0.10.1 7 | sentence-transformers==0.3.9 -------------------------------------------------------------------------------- /script/bi_encoder/bi-encoder-batch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import gzip 4 | import logging 5 | import math 6 | import os 7 | import random 8 | import tarfile 9 | from collections import defaultdict 10 | from datetime import datetime 11 | from glob import glob 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | from sentence_transformers import (LoggingHandler, SentenceTransformer, evaluation, losses, models, util) 17 | from torch.utils.data import DataLoader, Dataset, IterableDataset 18 | 19 | 20 | #train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join(os.path.abspath(""), "{}_part.csv".format(save_type))).dropna(), ["train", "test", "valid"]) 21 | train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join("../data/", "{}_part.csv".format(save_type))).dropna(), ["train",
"test", "valid"]) 22 | 23 | from sentence_transformers import InputExample 24 | class TripletsDataset(Dataset): 25 | def __init__(self, model, qa_df): 26 | assert set(["question", "answer", "q_idx"]).intersection(set(qa_df.columns.tolist())) == set(["question", "answer", "q_idx"]) 27 | self.model = model 28 | self.qa_df = qa_df 29 | self.q_idx_set = set(qa_df["q_idx"].value_counts().index.tolist()) 30 | 31 | def __getitem__(self, index): 32 | #raise NotImplementedError 33 | label = torch.tensor(1, dtype=torch.long) 34 | choice_s = self.qa_df.iloc[index] 35 | query_text, pos_text, q_idx = choice_s.loc["question"], choice_s.loc["answer"], choice_s.loc["q_idx"] 36 | query_text, pos_text, q_idx = choice_s.loc["question"], choice_s.loc["answer"], choice_s.loc["q_idx"] 37 | neg_q_idx = np.random.choice(list(self.q_idx_set.difference(set([q_idx])))) 38 | neg_text = self.qa_df[self.qa_df["q_idx"] == neg_q_idx].sample()["answer"].iloc[0] 39 | #### InputExample(texts=['I can\'t log in to my account.', 40 | #'Unable to access my account.', 41 | #'I need help with the payment process.'], 42 | #label=1), 43 | return InputExample(texts = [query_text, pos_text, neg_text], label = 1) 44 | ''' 45 | return [self.model.tokenize(query_text), 46 | self.model.tokenize(pos_text), 47 | self.model.tokenize(neg_text)], label 48 | ''' 49 | #return (query_text, pos_text, q_idx) 50 | 51 | def __len__(self): 52 | return self.qa_df.shape[0] 53 | 54 | 55 | class NoSameLabelsBatchSampler: 56 | def __init__(self, dataset, batch_size): 57 | self.dataset = dataset 58 | self.idx_org = list(range(len(dataset))) 59 | random.shuffle(self.idx_org) 60 | self.idx_copy = self.idx_org.copy() 61 | self.batch_size = batch_size 62 | 63 | def __iter__(self): 64 | batch = [] 65 | labels = set() 66 | num_miss = 0 67 | 68 | num_batches_returned = 0 69 | while num_batches_returned < self.__len__(): 70 | if len(self.idx_copy) == 0: 71 | random.shuffle(self.idx_org) 72 | self.idx_copy = self.idx_org.copy() 73 | 74 | idx = self.idx_copy.pop(0) 75 | #label = self.dataset[idx][1].cpu().tolist() 76 | label = self.dataset.qa_df["q_idx"].iloc[idx] 77 | 78 | if label not in labels: 79 | num_miss = 0 80 | batch.append(idx) 81 | labels.add(label) 82 | if len(batch) == self.batch_size: 83 | yield batch 84 | batch = [] 85 | labels = set() 86 | num_batches_returned += 1 87 | else: 88 | num_miss += 1 89 | self.idx_copy.append(idx) #Add item again to the end 90 | 91 | if num_miss >= len(self.idx_copy): #To many failures, flush idx_copy and start with clean 92 | self.idx_copy = [] 93 | 94 | def __len__(self): 95 | return math.ceil(len(self.dataset) / self.batch_size) 96 | 97 | 98 | def transform_part_df_into_Evaluator_format(part_df): 99 | req = part_df.copy() 100 | req["qid"] = req["question"].fillna("").map(hash).map(str) 101 | req["cid"] = req["answer"].fillna("").map(hash).map(str) 102 | queries = dict(map(tuple ,req[["qid", "question"]].drop_duplicates().values.tolist())) 103 | corpus = dict(map(tuple ,req[["cid", "answer"]].drop_duplicates().values.tolist())) 104 | qid_cid_set_df = req[["qid", "cid"]].groupby("qid")["cid"].apply(set).apply(sorted).apply(tuple).reset_index() 105 | qid_cid_set_df.columns = ["qid", "cid_set"] 106 | relevant_docs = dict(map(tuple ,qid_cid_set_df.drop_duplicates().values.tolist())) 107 | relevant_docs = dict(map(lambda t2: (t2[0], set(t2[1])) ,relevant_docs.items())) 108 | return queries, corpus, relevant_docs 109 | 110 | 111 | dev_queries, dev_corpus, dev_rel_docs = 
transform_part_df_into_Evaluator_format(valid_part.sample(frac=0.1)) 112 | ir_evaluator = evaluation.InformationRetrievalEvaluator(dev_queries, dev_corpus, dev_rel_docs, name='ms-marco-train_eval', batch_size=2) 113 | 114 | 115 | 116 | model_str = "xlm-roberta-base" 117 | #word_embedding_model = models.Transformer(model_str, max_seq_length=512) 118 | word_embedding_model = models.Transformer(model_str, max_seq_length=256) 119 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) 120 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) 121 | 122 | 123 | 124 | train_dataset = TripletsDataset(model=model, qa_df = train_part.sample(frac = 1.0, replace=False)) 125 | bs_obj = NoSameLabelsBatchSampler(train_dataset, batch_size=8) 126 | train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=1, batch_sampler=bs_obj, num_workers=1) 127 | train_loss = losses.MultipleNegativesRankingLoss(model=model) 128 | 129 | 130 | model_save_path = os.path.join(os.path.abspath(""), "bi_encoder_save") 131 | if not os.path.exists(model_save_path): 132 | os.mkdir(model_save_path) 133 | 134 | 135 | model.fit(train_objectives=[(train_dataloader, train_loss)], 136 | evaluator=ir_evaluator, 137 | epochs=10, 138 | warmup_steps=1000, 139 | output_path=model_save_path, 140 | evaluation_steps=5000, 141 | use_amp=True 142 | ) 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /script/bi_encoder/bi-encoder-data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os 4 | from copy import deepcopy 5 | from functools import reduce 6 | from glob import glob 7 | 8 | import editdistance 9 | import numpy as np 10 | import pandas as pd 11 | 12 | ###https://github.com/brightmart/nlp_chinese_corpus 13 | ###https://github.com/brightmart/nlp_chinese_corpus#4%E7%A4%BE%E5%8C%BA%E9%97%AE%E7%AD%94json%E7%89%88webtext2019zh-%E5%A4%A7%E8%A7%84%E6%A8%A1%E9%AB%98%E8%B4%A8%E9%87%8F%E6%95%B0%E6%8D%AE%E9%9B%86 14 | ###https://drive.google.com/open?id=1u2yW_XohbYL2YAK6Bzc5XrngHstQTf0v 15 | 16 | data_dir = r"/home/svjack/temp_dir/webtext2019zh" 17 | json_files = glob(os.path.join(data_dir, "*.json")) 18 | train_json = list(filter(lambda path: "train" in path.lower(), json_files))[0] 19 | def json_reader(path, chunksize = 100): 20 | assert path.endswith(".json") 21 | return pd.read_json(path, lines = True, chunksize=chunksize) 22 | 23 | train_reader = json_reader(train_json, chunksize=10000) 24 | times = 100 25 | df_list = [] 26 | for i, df in enumerate(train_reader): 27 | df_list.append(df) 28 | if i + 1 >= times: 29 | break 30 | 31 | train_head_df = pd.concat(df_list, axis = 0) 32 | content_len_df = pd.concat([train_head_df["content"], train_head_df["content"].map(len)], axis = 1) 33 | content_len_df.columns = ["content", "c_len"] 34 | 35 | 36 | qa_df = train_head_df[["title", "content"]].copy() 37 | qa_df = qa_df.rename(columns = {"title": "question", "content": "answer"}).fillna("") 38 | 39 | qa_df = qa_df[qa_df["question"].map(len) <= 500] 40 | qa_df = qa_df[qa_df["answer"].map(len) <= 500] 41 | 42 | 43 | quests = deepcopy(qa_df["question"]) 44 | question_cmp = pd.concat([quests.sort_values().shift(1), quests.sort_values()], axis = 1) 45 | question_cmp["edit_val"] = question_cmp.fillna("").apply(lambda s: editdistance.eval(s.iloc[0], s.iloc[1]) / (len(s.iloc[0]) + len(s.iloc[1])), axis = 1) 46 | 
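# A worked illustration of the normalized score above (added comment; the example pair is
# illustrative, not from the dataset): for adjacent sorted questions "怎么减肥" and "怎样减肥",
# editdistance.eval gives 1 (one substitution), so edit_val = 1 / (4 + 4) = 0.125, which is
# below the 0.2 threshold used below, and the two questions end up grouped under the same q_idx.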
question_cmp.columns = ["q0", "q1", "edit_val"] 47 | 48 | threshold = 0.2 49 | question_nest_list = [[]] 50 | for idx ,r in question_cmp.iterrows(): 51 | q0, q1, v = r.iloc[0], r.iloc[1], r.iloc[2] 52 | if v < threshold: 53 | question_nest_list[-1].append(q0) 54 | question_nest_list[-1].append(q1) 55 | else: 56 | question_nest_list.append([]) 57 | 58 | 59 | idx_question_df_zip = pd.DataFrame(list(map(lambda x: [x] ,question_nest_list))) 60 | 61 | idx_question_df_zip = idx_question_df_zip[idx_question_df_zip.iloc[:, 0].map(len) > 0] 62 | idx_question_df_zip.columns = ["question"] 63 | idx_question_df_zip["q_idx"] = np.arange(idx_question_df_zip.shape[0]).tolist() 64 | 65 | idx_question_df = idx_question_df_zip.explode("question") 66 | 67 | #idx_question_df = pd.DataFrame(reduce(lambda a, b: a + b, map(lambda idx: list(map(lambda q: (idx, q), question_nest_list[idx])), range(len(question_nest_list))))) 68 | #idx_question_df.columns = ["q_idx", "question"] 69 | #idx_question_df.drop_duplicates().to_csv(os.path.join("/home/svjack/temp_dir/", "idx_question_df.csv"), index = False) 70 | 71 | idx_question_df_dd = idx_question_df.drop_duplicates() 72 | 73 | 74 | 75 | qa_df_dd = qa_df.drop_duplicates() 76 | cat_qa_df_with_idx = pd.merge(qa_df_dd, idx_question_df_dd, on = "question", how = "inner") 77 | q_idx_set = set(cat_qa_df_with_idx["q_idx"].value_counts().index.tolist()) 78 | 79 | q_idx_size_bigger_or_eql_3 = ((cat_qa_df_with_idx["q_idx"].value_counts() >= 3).reset_index()).groupby("q_idx")["index"].apply(set).apply(list)[True] 80 | q_idx_size_bigger_or_eql_3_df = cat_qa_df_with_idx[cat_qa_df_with_idx["q_idx"].isin(q_idx_size_bigger_or_eql_3)].copy() 81 | 82 | 83 | def produce_label_list(length = 10, p_list = [0.1, 0.1, 0.8]): 84 | from functools import reduce 85 | assert sum(p_list) == 1 86 | p_array = np.asarray(p_list) 87 | assert all((p_array[:-1] <= p_array[1:]).astype(bool).tolist()) 88 | num_array = (p_array * length).astype(np.int32) 89 | num_list = num_array.tolist() 90 | num_list = list(map(lambda x: max(x, 1), num_list)) 91 | num_list[-1] = length - sum(num_list[:-1]) 92 | return np.random.permutation(reduce(lambda a, b: a + b ,map(lambda idx: [idx] * num_list[idx], range(len(p_list))))) 93 | 94 | q_idx_size_bigger_or_eql_3_df["r_idx"] = q_idx_size_bigger_or_eql_3_df.index.tolist() 95 | 96 | def map_r_idx_list_to_split_label_zip(r_idx_list): 97 | split_label_list = produce_label_list(len(r_idx_list)) 98 | assert len(split_label_list) == len(r_idx_list) 99 | return zip(*[r_idx_list, split_label_list]) 100 | 101 | r_idx_split_label_items = reduce(lambda a, b: a + b ,q_idx_size_bigger_or_eql_3_df.groupby("q_idx")["r_idx"].apply(set).apply(list).apply(map_r_idx_list_to_split_label_zip).apply(list).tolist()) 102 | r_idx_split_label_df = pd.DataFrame(r_idx_split_label_items) 103 | r_idx_split_label_df.columns = ["r_idx", "split_label"] 104 | assert r_idx_split_label_df.shape[0] == pd.merge(q_idx_size_bigger_or_eql_3_df, r_idx_split_label_df, on = "r_idx", how = "inner").shape[0] 105 | 106 | q_idx_size_bigger_or_eql_3_df_before_split = pd.merge(q_idx_size_bigger_or_eql_3_df, r_idx_split_label_df, on = "r_idx", how = "inner") 107 | train_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 2].copy() 108 | train_part = pd.concat([train_part, cat_qa_df_with_idx[(1 - cat_qa_df_with_idx["q_idx"].isin(q_idx_size_bigger_or_eql_3)).astype(bool)].copy()], axis = 0) 109 | valid_part = 
q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 0].copy() 110 | test_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 1].copy() 111 | 112 | assert set(valid_part["q_idx"].tolist()) == set(test_part["q_idx"].tolist()) 113 | assert set(valid_part["q_idx"].tolist()) == set(valid_part["q_idx"].tolist()).intersection(train_part["q_idx"].tolist()) 114 | 115 | train_part.to_csv(os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir, "data", "train_part.csv"), index = False) 116 | test_part.to_csv(os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir, "data", "test_part.csv"), index = False) 117 | valid_part.to_csv(os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir, "data", "valid_part.csv"), index = False) 118 | -------------------------------------------------------------------------------- /script/cross_encoder/cross_encoder_random_on_multi_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | #es_host = 'localhost:9200' 4 | import csv 5 | import gzip 6 | import json 7 | import logging 8 | import os 9 | import tarfile 10 | import time 11 | from datetime import datetime 12 | #from es_pandas import es_pandas 13 | from functools import reduce 14 | from glob import glob 15 | from typing import Callable, Dict, Iterable, List, Type 16 | 17 | import numpy as np 18 | import pandas as pd 19 | import seaborn as sns 20 | import torch 21 | import transformers 22 | #from elasticsearch import Elasticsearch, helpers 23 | from sentence_transformers import (InputExample, LoggingHandler, 24 | SentenceTransformer, util) 25 | from sentence_transformers.cross_encoder import CrossEncoder 26 | from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator 27 | from sentence_transformers.evaluation import (SentenceEvaluator, 28 | SequentialEvaluator) 29 | from torch import nn 30 | from torch.optim import Optimizer 31 | from torch.utils.data import DataLoader 32 | from tqdm.autonotebook import tqdm, trange 33 | from transformers import (AutoConfig, AutoModelForSequenceClassification, 34 | AutoTokenizer) 35 | 36 | 37 | logger = logging.getLogger(__name__) 38 | pd.set_option("display.max_rows", 200) 39 | 40 | 41 | class DictionaryEvaluator(SequentialEvaluator): 42 | def __init__(self, evaluators: Iterable[SentenceEvaluator], main_score_function = lambda x: x): 43 | super(DictionaryEvaluator, self).__init__(evaluators, main_score_function) 44 | #self.eval_name_ext_dict = dict(map(lambda t2: (t2[0].lower()[:t2[0].lower().find("Evaluator".lower())], t2[1]) ,map(lambda eval_ext: (eval_ext.name, eval_ext) ,self.evaluators))) 45 | self.eval_name_ext_dict = dict(map(lambda t2: (t2[0], t2[1]) ,map(lambda eval_ext: (eval_ext.name, eval_ext) ,self.evaluators))) 46 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 47 | scores = {} 48 | #for evaluator in self.evaluators: 49 | for eval_ext_name, evaluator in self.eval_name_ext_dict.items(): 50 | #scores.append(evaluator(model, output_path, epoch, steps)) 51 | #eval_output_path = output_path + eval_ext_name 52 | eval_output_path = output_path + "_" + eval_ext_name 53 | #scores[eval_ext_name] = evaluator(model, output_path, epoch, steps) 54 | if eval_output_path is not None and not(os.path.exists(eval_output_path)): 55 | os.makedirs(eval_output_path, exist_ok=True) 
56 | scores[eval_ext_name] = evaluator(model, eval_output_path, epoch, steps) 57 | return self.main_score_function(scores) 58 | 59 | class CrossEncoder_Dict_Eval(CrossEncoder): 60 | def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback): 61 | """Runs evaluation during the training""" 62 | if evaluator is not None: 63 | if isinstance(evaluator, DictionaryEvaluator): 64 | if type(self.best_score) != type({}): 65 | self.best_score = dict(map(lambda eval_ext_name: (eval_ext_name, self.best_score), evaluator.eval_name_ext_dict.keys())) 66 | score_dict = evaluator(self, output_path=output_path, epoch=epoch, steps=steps) 67 | if callback is not None: 68 | callback(score_dict, epoch, steps) 69 | for eval_ext_name, eval_score in score_dict.items(): 70 | if eval_score > self.best_score[eval_ext_name]: 71 | self.best_score[eval_ext_name] = eval_score 72 | if save_best_model: 73 | #eval_output_path = output_path + eval_ext_name 74 | eval_output_path = output_path + "_" + eval_ext_name 75 | self.save(eval_output_path) 76 | else: 77 | score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps) 78 | if callback is not None: 79 | callback(score, epoch, steps) 80 | if score > self.best_score: 81 | self.best_score = score 82 | if save_best_model: 83 | self.save(output_path) 84 | 85 | 86 | class CERerankingEvaluatorSUM: 87 | """ 88 | This class evaluates a CrossEncoder model for the task of re-ranking. 89 | 90 | Given a query and a list of documents, it computes the score [query, doc_i] for all possible 91 | documents and sorts them in decreasing order. Then, MRR@10 is computed to measure the quality of the ranking. 92 | 93 | :param samples: Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query, 94 | positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents. 
95 | """ 96 | def __init__(self, samples, mrr_at_k: int = 10, name: str = '', num_dev_queries = 600): 97 | self.samples = samples 98 | self.name = name 99 | self.mrr_at_k = mrr_at_k 100 | 101 | self.num_dev_queries = num_dev_queries 102 | 103 | if isinstance(self.samples, dict): 104 | self.samples = list(self.samples.values()) 105 | #output/training_ms-marco_cross-encoder-xlm-roberta-base-2021-01-12_21-10-39mrr-train-eva/CERerankingEvaluator_mrr-train-eval_results.csv 106 | self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv" 107 | self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)] 108 | 109 | self.score_json_file = self.csv_file.replace(".csv", ".json") 110 | 111 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 112 | if epoch != -1: 113 | if steps == -1: 114 | out_txt = " after epoch {}:".format(epoch) 115 | else: 116 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 117 | else: 118 | out_txt = ":" 119 | 120 | logger.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt) 121 | 122 | all_mrr_scores = [] 123 | num_queries = 0 124 | num_positives = [] 125 | num_negatives = [] 126 | scores_list = [] 127 | 128 | samples = list(self.samples) 129 | samples_indexes = np.random.permutation(np.arange(len(samples))) 130 | samples = list(map(lambda idx: samples[idx], samples_indexes[:self.num_dev_queries])) 131 | #for instance in self.samples: 132 | for instance in samples: 133 | query = instance['query'] 134 | positive = list(instance['positive']) 135 | negative = list(instance['negative']) 136 | docs = positive + negative 137 | is_relevant = [True]*len(positive) + [False]*len(negative) 138 | 139 | if len(positive) == 0 or len(negative) == 0: 140 | continue 141 | 142 | num_queries += 1 143 | num_positives.append(len(positive)) 144 | num_negatives.append(len(negative)) 145 | 146 | model_input = [[query, doc] for doc in docs] 147 | pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False) 148 | scores_list.extend(list(pred_scores)) 149 | pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order 150 | 151 | mrr_score = 0 152 | for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]): 153 | if is_relevant[index]: 154 | mrr_score += 1 / (rank+1) 155 | 156 | all_mrr_scores.append(mrr_score) 157 | 158 | mean_mrr = np.mean(all_mrr_scores) 159 | logger.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives))) 160 | logger.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100)) 161 | 162 | if output_path is not None: 163 | csv_path = os.path.join(output_path, self.csv_file) 164 | output_file_exists = os.path.isfile(csv_path) 165 | with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f: 166 | writer = csv.writer(f) 167 | if not output_file_exists: 168 | writer.writerow(self.csv_headers) 169 | writer.writerow([epoch, steps, mean_mrr]) 170 | json_path = os.path.join(output_path, self.score_json_file) 171 | output_file_exists = os.path.isfile(json_path) 172 | with open(json_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f: 173 | writer = csv.writer(f) 174 | if not output_file_exists: 175 | writer.writerow(["epoch", "steps", "score@"]) 176 | 
writer.writerow([epoch, steps, json.dumps({"scores_list": list(map(float ,scores_list))})]) 177 | return mean_mrr 178 | 179 | class ScoreCalculator(object): 180 | def __init__(self, 181 | queries_ids, 182 | relevant_docs, 183 | mrr_at_k: List[int] = [10], 184 | ndcg_at_k: List[int] = [10], 185 | accuracy_at_k: List[int] = [1, 3, 5, 10], 186 | precision_recall_at_k: List[int] = [1, 3, 5, 10], 187 | map_at_k: List[int] = [100], 188 | ): 189 | "queries_ids list of query, relevant_docs key query value set or list of relevant_docs" 190 | self.queries_ids = queries_ids 191 | self.relevant_docs = relevant_docs 192 | 193 | self.mrr_at_k = mrr_at_k 194 | self.ndcg_at_k = ndcg_at_k 195 | self.accuracy_at_k = accuracy_at_k 196 | self.precision_recall_at_k = precision_recall_at_k 197 | self.map_at_k = map_at_k 198 | def compute_metrics(self, queries_result_list: List[object]): 199 | # Init score computation values 200 | num_hits_at_k = {k: 0 for k in self.accuracy_at_k} 201 | precisions_at_k = {k: [] for k in self.precision_recall_at_k} 202 | recall_at_k = {k: [] for k in self.precision_recall_at_k} 203 | MRR = {k: 0 for k in self.mrr_at_k} 204 | ndcg = {k: [] for k in self.ndcg_at_k} 205 | AveP_at_k = {k: [] for k in self.map_at_k} 206 | 207 | # Compute scores on results 208 | #### elements with hits one hit is a dict : {"corpus_id": corpus_text, "score": score} 209 | #### corpus_id replace by corpus text 210 | for query_itr in range(len(queries_result_list)): 211 | query_id = self.queries_ids[query_itr] 212 | 213 | # Sort scores 214 | top_hits = sorted(queries_result_list[query_itr], key=lambda x: x['score'], reverse=True) 215 | query_relevant_docs = self.relevant_docs[query_id] 216 | 217 | # Accuracy@k - We count the result correct, if at least one relevant doc is accross the top-k documents 218 | for k_val in self.accuracy_at_k: 219 | for hit in top_hits[0:k_val]: 220 | if hit['corpus_id'] in query_relevant_docs: 221 | num_hits_at_k[k_val] += 1 222 | break 223 | 224 | # Precision and Recall@k 225 | for k_val in self.precision_recall_at_k: 226 | num_correct = 0 227 | for hit in top_hits[0:k_val]: 228 | if hit['corpus_id'] in query_relevant_docs: 229 | num_correct += 1 230 | 231 | precisions_at_k[k_val].append(num_correct / k_val) 232 | recall_at_k[k_val].append(num_correct / len(query_relevant_docs)) 233 | 234 | # MRR@k 235 | for k_val in self.mrr_at_k: 236 | for rank, hit in enumerate(top_hits[0:k_val]): 237 | if hit['corpus_id'] in query_relevant_docs: 238 | MRR[k_val] += 1.0 / (rank + 1) 239 | break 240 | 241 | # NDCG@k 242 | for k_val in self.ndcg_at_k: 243 | predicted_relevance = [1 if top_hit['corpus_id'] in query_relevant_docs else 0 for top_hit in top_hits[0:k_val]] 244 | true_relevances = [1] * len(query_relevant_docs) 245 | 246 | ndcg_value = self.compute_dcg_at_k(predicted_relevance, k_val) / self.compute_dcg_at_k(true_relevances, k_val) 247 | ndcg[k_val].append(ndcg_value) 248 | 249 | # MAP@k 250 | for k_val in self.map_at_k: 251 | num_correct = 0 252 | sum_precisions = 0 253 | 254 | for rank, hit in enumerate(top_hits[0:k_val]): 255 | if hit['corpus_id'] in query_relevant_docs: 256 | num_correct += 1 257 | sum_precisions += num_correct / (rank + 1) 258 | 259 | avg_precision = sum_precisions / min(k_val, len(query_relevant_docs)) 260 | AveP_at_k[k_val].append(avg_precision) 261 | 262 | # Compute averages 263 | for k in num_hits_at_k: 264 | #num_hits_at_k[k] /= len(self.queries) 265 | num_hits_at_k[k] /= len(queries_result_list) 266 | 267 | for k in precisions_at_k: 268 | 
precisions_at_k[k] = np.mean(precisions_at_k[k]) 269 | 270 | for k in recall_at_k: 271 | recall_at_k[k] = np.mean(recall_at_k[k]) 272 | 273 | for k in ndcg: 274 | ndcg[k] = np.mean(ndcg[k]) 275 | 276 | for k in MRR: 277 | #MRR[k] /= len(self.queries) 278 | MRR[k] /= len(queries_result_list) 279 | 280 | for k in AveP_at_k: 281 | AveP_at_k[k] = np.mean(AveP_at_k[k]) 282 | return {'accuracy@k': num_hits_at_k, 'precision@k': precisions_at_k, 'recall@k': recall_at_k, 'ndcg@k': ndcg, 'mrr@k': MRR, 'map@k': AveP_at_k} 283 | @staticmethod 284 | def compute_dcg_at_k(relevances, k): 285 | dcg = 0 286 | for i in range(min(len(relevances), k)): 287 | dcg += relevances[i] / np.log2(i + 2) #+2 as we start our idx at 0 288 | return dcg 289 | 290 | def map_dev_samples_to_score_calculator_format(dev_samples): 291 | if isinstance(dev_samples, dict): 292 | dev_samples = list(dev_samples.values()) 293 | queries_ids = list(map(lambda x: x["query"] ,dev_samples)) 294 | relevant_docs = dict(map(lambda idx: (dev_samples[idx]["query"], dev_samples[idx]["positive"]), range(len(dev_samples)))) 295 | return ScoreCalculator(queries_ids, relevant_docs) 296 | 297 | class ScoreEvaluator: 298 | """ 299 | This class evaluates a CrossEncoder model for the task of re-ranking. 300 | 301 | Given a query and a list of documents, it computes the score [query, doc_i] for all possible 302 | documents and sorts them in decreasing order. Then, MAP is computed to measure the quality of the ranking. 303 | 304 | :param samples: Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query, 305 | positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents. 306 | """ 307 | def __init__(self, samples, name: str = '', num_dev_queries = 600): 308 | self.samples = samples 309 | self.name = name 310 | self.num_dev_queries = num_dev_queries 311 | 312 | if isinstance(self.samples, dict): 313 | self.samples = list(self.samples.values()) 314 | 315 | #self.score_calculator_ext = map_dev_samples_to_score_calculator_format(self.samples) 316 | 317 | self.csv_file = "ScoreEvaluator" + ("_" + name if name else '') + "_results.csv" 318 | self.csv_headers = ["epoch", "steps", "MAP@"] 319 | 320 | self.score_json_file = self.csv_file.replace(".csv", ".json") 321 | 322 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 323 | if epoch != -1: 324 | if steps == -1: 325 | out_txt = " after epoch {}:".format(epoch) 326 | else: 327 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 328 | else: 329 | out_txt = ":" 330 | 331 | logger.info("ScoreEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt) 332 | 333 | all_scores = [] 334 | num_queries = 0 335 | num_positives = [] 336 | num_negatives = [] 337 | 338 | scores_list = [] 339 | 340 | samples = list(self.samples) 341 | samples_indexes = np.random.permutation(np.arange(len(samples))) 342 | samples = list(map(lambda idx: samples[idx], samples_indexes[:self.num_dev_queries])) 343 | #for instance in self.samples: 344 | for instance in samples: 345 | query = instance['query'] 346 | positive = list(instance['positive']) 347 | negative = list(instance['negative']) 348 | docs = positive + negative 349 | is_relevant = [True]*len(positive) + [False]*len(negative) 350 | 351 | if len(positive) == 0 or len(negative) == 0: 352 | continue 353 | 354 | num_queries += 1 355 | num_positives.append(len(positive)) 356 | 
num_negatives.append(len(negative)) 357 | 358 | model_input = [[query, doc] for doc in docs] 359 | pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False) 360 | 361 | scores_list.extend(list(pred_scores)) 362 | 363 | #### elements with hits one hit is a dict : {"corpus_id": corpus_text, "score": score} 364 | #### corpus_id replace by corpus text 365 | queries_result_list = list(map(lambda idx: {"corpus_id": docs[idx], "score": pred_scores[idx]}, range(len(docs)))) 366 | score_calculator_ext = map_dev_samples_to_score_calculator_format({0: instance}) 367 | score_dict = score_calculator_ext.compute_metrics([queries_result_list]) 368 | all_scores.append(score_dict["map@k"][100]) 369 | 370 | mean_map = np.mean(all_scores) 371 | logger.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives))) 372 | logger.info("MAP@: {:.2f}".format(mean_map*100)) 373 | 374 | if output_path is not None: 375 | csv_path = os.path.join(output_path, self.csv_file) 376 | output_file_exists = os.path.isfile(csv_path) 377 | with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f: 378 | writer = csv.writer(f) 379 | if not output_file_exists: 380 | writer.writerow(self.csv_headers) 381 | #writer.writerow([epoch, steps, mean_mrr]) 382 | writer.writerow([epoch, steps, mean_map]) 383 | json_path = os.path.join(output_path, self.score_json_file) 384 | output_file_exists = os.path.isfile(json_path) 385 | with open(json_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f: 386 | writer = csv.writer(f) 387 | if not output_file_exists: 388 | writer.writerow(["epoch", "steps", "score@"]) 389 | writer.writerow([epoch, steps, json.dumps({"scores_list": list(map(float ,scores_list))})]) 390 | #return mean_mrr 391 | return mean_map 392 | 393 | 394 | def read_part_file(path, adjust_neg_pos_ration = None): 395 | json_full_paths = glob(os.path.join(path, "*.json")) 396 | assert len(json_full_paths) == 4 397 | req = {} 398 | for path in json_full_paths: 399 | val_name = path.split("/")[-1].replace(".json", "") 400 | with open(path, "r", encoding = "utf-8") as f: 401 | j_obj = json.load(f) 402 | if len(j_obj) == 1: 403 | assert type(j_obj[list(j_obj.keys())[0]]) == type([]) 404 | j_obj = j_obj[list(j_obj.keys())[0]] 405 | else: 406 | assert val_name.endswith("_index_dict") and len(val_name.split("_")) == 3 407 | val_name = "_".join(np.asarray(val_name.split("_"))[[1, 0, 2]].tolist()) 408 | j_obj = dict(map(lambda t2: (t2[1], t2[0]) ,j_obj.items())) 409 | req[val_name] = j_obj 410 | if adjust_neg_pos_ration is not None: 411 | assert type(adjust_neg_pos_ration) == type(0) 412 | neg_tuple_set = req["neg_tuple_set"] 413 | pos_tuple_set = req["pos_tuple_set"] 414 | print("ori pos {} neg {}".format(len(pos_tuple_set), len(neg_tuple_set))) 415 | if len(pos_tuple_set) < len(neg_tuple_set): 416 | if len(neg_tuple_set) / len(pos_tuple_set) >= adjust_neg_pos_ration: 417 | pos_num, neg_num = len(pos_tuple_set), len(pos_tuple_set) * adjust_neg_pos_ration 418 | else: 419 | pos_num, neg_num = int(len(neg_tuple_set) / adjust_neg_pos_ration), len(neg_tuple_set) 420 | else: 421 | min_size = min(map(len, [neg_tuple_set, pos_tuple_set])) 422 | snip_size = int(max(1, min_size / 10000)) 423 | snip_num = int(min_size / snip_size) 424 | pos_num = int(min((snip_num / 
adjust_neg_pos_ration) * snip_size, min_size)) 425 | neg_num = int(min((snip_num / adjust_neg_pos_ration) * snip_size * adjust_neg_pos_ration, min_size)) 426 | neg_tuple_set = set(list(map(tuple ,neg_tuple_set))[:neg_num]) 427 | pos_tuple_set = set(list(map(tuple ,pos_tuple_set))[:pos_num]) 428 | req["neg_tuple_set"] = neg_tuple_set 429 | req["pos_tuple_set"] = pos_tuple_set 430 | return req 431 | 432 | def construct_train_samples(json_obj, neg_random = False): 433 | train_samples = [] 434 | label = 1 435 | for t2 in json_obj["pos_tuple_set"]: 436 | q_index, a_index = t2[0], t2[1] 437 | q, a = json_obj["index_question_dict"][q_index], json_obj["index_answer_dict"][a_index] 438 | #q = q + "" 439 | train_samples.append(InputExample(texts=[q, a], label=label)) 440 | if neg_random: 441 | neg_len = len(json_obj["index_answer_dict"]) 442 | label = 0 443 | for t2 in json_obj["neg_tuple_set"]: 444 | q_index, a_index = t2[0], t2[1] 445 | if neg_random: 446 | q, a = json_obj["index_question_dict"][q_index], json_obj["index_answer_dict"][np.random.randint(0, neg_len)] 447 | else: 448 | q, a = json_obj["index_question_dict"][q_index], json_obj["index_answer_dict"][a_index] 449 | #q = q + "" 450 | train_samples.append(InputExample(texts=[q, a], label=label)) 451 | train_indexes = np.random.permutation(np.arange(len(train_samples))) 452 | return list(map(lambda idx: train_samples[idx], train_indexes)) 453 | 454 | def construct_dev_samples(json_obj ,num_dev_queries = int(2e3), num_max_dev_negatives = 200, neg_random = False): 455 | dev_samples_list = construct_train_samples(json_obj, neg_random) 456 | dev_samples_df = pd.DataFrame(list(map(lambda item_:(item_.texts[0], item_.texts[1], item_.label) ,dev_samples_list)), columns = ["q", "a", "l"]) 457 | dev_q_qid_dict = dict(map(lambda t2: (t2[1], t2[0]), enumerate(dev_samples_df["q"].drop_duplicates().tolist()))) 458 | print("dev_samples_df shape {}, dev_q_qid_dict len {}".format(dev_samples_df.shape, len(dev_q_qid_dict))) 459 | dev_samples = {} 460 | for q, qid in dev_q_qid_dict.items(): 461 | if qid not in dev_samples and len(dev_samples) < num_dev_queries: 462 | dev_samples[qid] = {'query': q, 'positive': set(), 'negative': set()} 463 | for qid in dev_samples.keys(): 464 | qid_relate_df = dev_samples_df[dev_samples_df["q"] == dev_samples[qid]["query"]] 465 | dev_samples[qid]["positive"] = set(qid_relate_df[qid_relate_df["l"] == 1]["a"].tolist()) 466 | for neg_a in set(qid_relate_df[qid_relate_df["l"] == 0]["a"].tolist()): 467 | if len(dev_samples[qid]['negative']) < num_max_dev_negatives: 468 | dev_samples[qid]['negative'].add(neg_a) 469 | return dev_samples 470 | 471 | 472 | def merge_dev_samples(dev_samples_list): 473 | assert len(dev_samples_list) > 1 474 | query_pos_set_dict = {} 475 | query_neg_set_dict = {} 476 | for dev_samples in dev_samples_list: 477 | for qid, item_ in dev_samples.items(): 478 | # item_ {'query': q, 'positive': set(), 'negative': set()} 479 | query = item_["query"] 480 | positive = item_["positive"] 481 | negative = item_["negative"] 482 | if query not in query_pos_set_dict: 483 | query_pos_set_dict[query] = positive 484 | else: 485 | for ele in positive: 486 | query_pos_set_dict[query].add(ele) 487 | if query not in query_neg_set_dict: 488 | query_neg_set_dict[query] = negative 489 | else: 490 | for ele in negative: 491 | query_neg_set_dict[query].add(ele) 492 | assert set(query_pos_set_dict.keys()) == set(query_neg_set_dict.keys()) 493 | merge_dev_samples = {} 494 | for qid, query in enumerate(query_pos_set_dict.keys()): 495 
| merge_dev_samples[qid] = {'query': query, 'positive': query_pos_set_dict[query], 'negative': query_neg_set_dict[query]} 496 | return merge_dev_samples 497 | 498 | 499 | def merge_dev_samples_add_neg(dev_samples_list, add_num = None, after_size = 500): 500 | assert len(dev_samples_list) > 1 501 | query_pos_set_dict = {} 502 | query_neg_set_dict = {} 503 | all_negs = reduce(lambda a, b: a.union(b) ,map(lambda dev_samples: reduce(lambda a, b: a.union(b) ,map(lambda item_:set(item_["negative"]) ,dev_samples.values())), dev_samples_list)) 504 | assert type(all_negs) == type(set([])) 505 | assert set(map(type, all_negs)) == set([type("")]) 506 | all_negs = list(all_negs) 507 | for dev_samples in dev_samples_list: 508 | for qid, item_ in dev_samples.items(): 509 | # item_ {'query': q, 'positive': set(), 'negative': set()} 510 | query = item_["query"] 511 | positive = item_["positive"] 512 | negative = item_["negative"] 513 | if query not in query_pos_set_dict: 514 | query_pos_set_dict[query] = positive 515 | else: 516 | for ele in positive: 517 | query_pos_set_dict[query].add(ele) 518 | if query not in query_neg_set_dict: 519 | query_neg_set_dict[query] = negative 520 | else: 521 | for ele in negative: 522 | query_neg_set_dict[query].add(ele) 523 | if add_num is not None: 524 | assert type(add_num) == type(0) 525 | all_negs_indexes = np.random.permutation(np.arange(len(all_negs))) 526 | all_negs_add = list(map(lambda idx: all_negs[idx], all_negs_indexes[:add_num])) 527 | for ele in all_negs_add: 528 | if ele not in positive: 529 | query_neg_set_dict[query].add(ele) 530 | assert set(query_pos_set_dict.keys()) == set(query_neg_set_dict.keys()) 531 | merge_dev_samples = {} 532 | for qid, query in enumerate(query_pos_set_dict.keys()): 533 | if len(merge_dev_samples) >= after_size: 534 | break 535 | merge_dev_samples[qid] = {'query': query, 'positive': query_pos_set_dict[query], 'negative': query_neg_set_dict[query]} 536 | return merge_dev_samples 537 | 538 | 539 | train_json_obj = read_part_file("train_file_faiss_10", adjust_neg_pos_ration = None) 540 | train_samples = construct_train_samples(train_json_obj, neg_random = True) 541 | 542 | 543 | valid_json_obj = read_part_file("valid_file_faiss", adjust_neg_pos_ration = None) 544 | 545 | 546 | dev_samples = construct_dev_samples(valid_json_obj, neg_random = True) 547 | 548 | 549 | 550 | model_name = 'xlm-roberta-base' 551 | model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 552 | model = CrossEncoder_Dict_Eval(model_name, num_labels=1, max_length=512) 553 | 554 | 555 | 556 | train_batch_size = 10 557 | num_epochs = 10 558 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) 559 | 560 | 561 | 562 | mrr_sum_evaluator = CERerankingEvaluatorSUM(dev_samples, name='mrr-train-eval') 563 | map_evaluator = ScoreEvaluator(dev_samples, name = "map-train-eval") 564 | dict_evaluator = DictionaryEvaluator([mrr_sum_evaluator, map_evaluator]) 565 | 566 | 567 | warmup_steps = 1000 568 | logging.info("Warmup-steps: {}".format(warmup_steps)) 569 | 570 | 571 | # Train the model 572 | model.fit(train_dataloader=train_dataloader, 573 | evaluator=dict_evaluator, 574 | epochs=num_epochs, 575 | evaluation_steps=5000, 576 | warmup_steps=warmup_steps, 577 | output_path=model_save_path, 578 | use_amp=False) 579 | 580 | #Save latest model 581 | model.save(model_save_path+'-latest') 582 | 583 | 584 | 585 | 
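# A minimal inference sketch appended for illustration (assumption: the training run above
# has finished; the example query/answer pair is made up). The DictionaryEvaluator saves one
# best checkpoint per evaluator under output_path + "_" + evaluator.name, so the best-MRR and
# best-MAP models can be loaded and compared separately.
for eval_name in ["mrr-train-eval", "map-train-eval"]:
    best_path = model_save_path + "_" + eval_name
    if os.path.exists(best_path):
        best_model = CrossEncoder(best_path, num_labels=1, max_length=512)
        # score an illustrative (query, candidate) pair; higher means more relevant
        print(eval_name, best_model.predict([["怎样学习python", "可以从官方教程开始入门"]]))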
-------------------------------------------------------------------------------- /script/cross_encoder/try_sbert_neg_sampler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | import json 6 | import logging 7 | import os 8 | import pickle 9 | import random 10 | import time 11 | import traceback 12 | from functools import reduce 13 | 14 | import faiss 15 | import numpy as np 16 | import pandas as pd 17 | import scipy.spatial 18 | import torch 19 | from elasticsearch import Elasticsearch, helpers 20 | from es_pandas import es_pandas 21 | from IPython import embed 22 | from sentence_transformers import InputExample, SentenceTransformer, util 23 | 24 | es_host = 'localhost:9200' 25 | train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir, "data", "{}_part.csv".format(save_type)) 26 | ).dropna(), ["train", "test", "valid"]) 27 | 28 | 29 | class es_pandas_edit(es_pandas): 30 | @staticmethod 31 | def serialize(row, columns, use_pandas_json, iso_dates): 32 | if use_pandas_json: 33 | return json.dumps(dict(zip(columns, row)), iso_dates=iso_dates) 34 | return dict(zip(columns, [None if (all(pd.isna(r)) if (hasattr(r, "__len__") and type(r) != type("")) else pd.isna(r)) else r for r in row])) 35 | def to_pandas_iter(self, index, query_rule=None, heads=[], dtype={}, infer_dtype=False, show_progress=True, 36 | chunk_size = None, **kwargs): 37 | if query_rule is None: 38 | query_rule = {'query': {'match_all': {}}} 39 | count = self.es.count(index=index, body=query_rule)['count'] 40 | if count < 1: 41 | raise Exception('Empty for %s' % index) 42 | query_rule['_source'] = heads 43 | anl = helpers.scan(self.es, query=query_rule, index=index, **kwargs) 44 | source_iter = self.get_source(anl, show_progress = show_progress, count = count) 45 | print(source_iter) 46 | if chunk_size is None: 47 | df = pd.DataFrame(list(source_iter)).set_index('_id') 48 | if infer_dtype: 49 | dtype = self.infer_dtype(index, df.columns.values) 50 | if len(dtype): 51 | df = df.astype(dtype) 52 | yield df 53 | return 54 | assert type(chunk_size) == type(0) 55 | def map_list_of_dicts_into_df(list_of_dicts, set_index = "_id"): 56 | from collections import defaultdict 57 | req = defaultdict(list) 58 | for dict_ in list_of_dicts: 59 | for k, v in dict_.items(): 60 | req[k].append(v) 61 | req = pd.DataFrame.from_dict(req) 62 | if set_index: 63 | assert set_index in req.columns.tolist() 64 | t_df = req.set_index(set_index) 65 | if infer_dtype: 66 | dtype = self.infer_dtype(index, t_df.columns.values) 67 | if len(dtype): 68 | t_df = t_df.astype(dtype) 69 | return t_df 70 | list_of_dicts = [] 71 | for dict_ in source_iter: 72 | list_of_dicts.append(dict_) 73 | if len(list_of_dicts) >= chunk_size: 74 | yield map_list_of_dicts_into_df(list_of_dicts) 75 | list_of_dicts = [] 76 | if list_of_dicts: 77 | yield map_list_of_dicts_into_df(list_of_dicts) 78 | 79 | 80 | 81 | ep = es_pandas_edit(es_host) 82 | if ep.ic.exists("train_part"): 83 | ep.ic.delete(index = "train_part") 84 | 85 | 86 | ep.init_es_tmpl(train_part.head(1000), "train_part_doc_type", delete=True) 87 | valid_part_tmp = ep.es.indices.get_template("train_part_doc_type") 88 | es_index = valid_part_tmp["train_part_doc_type"] 89 | es_index["mappings"]["properties"]["question"] = { 90 | "type": "text", 91 | } 92 | es_index["mappings"]["properties"]["answer"] = { 93 | "type": "text", 94 | } 95 | 
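# Note on the overrides above (comment added for clarity; behavior unchanged): forcing the
# "question" and "answer" fields to the analyzed "text" type, instead of whatever type the
# es_pandas template inferred, lets the index answer full-text match queries on these
# columns, which is what Elasticsearch-based hard-sample mining relies on.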
es_index = {"mappings": es_index["mappings"]} 96 | ep.es.indices.create(index='train_part', body=es_index, ignore=[400]) 97 | 98 | 99 | chunk_size = 10000 100 | range_list = list(range(0, train_part.shape[0], chunk_size)) 101 | if train_part.shape[0] not in range_list: 102 | range_list.append(train_part.shape[0]) 103 | assert "".join(map(str ,range_list)).startswith("0") and "".join(map(str ,range_list)).endswith("{}".format(train_part.shape[0])) 104 | 105 | for i in range(len(range_list) - 1): 106 | part_tiny = train_part.iloc[range_list[i]:range_list[i+1]] 107 | ep.to_es(part_tiny, "train_part") 108 | 109 | assert reduce(lambda a, b: a + b, map(lambda df: df.shape[0] ,ep.to_pandas_iter("train_part", chunk_size = chunk_size))) == train_part.shape[0] 110 | 111 | if ep.ic.exists("valid_part"): 112 | ep.ic.delete(index = "valid_part") 113 | 114 | 115 | ep.init_es_tmpl(train_part.head(1000), "valid_part_doc_type", delete=True) 116 | valid_part_tmp = ep.es.indices.get_template("valid_part_doc_type") 117 | es_index = valid_part_tmp["valid_part_doc_type"] 118 | es_index["mappings"]["properties"]["question"] = { 119 | "type": "text", 120 | } 121 | es_index["mappings"]["properties"]["answer"] = { 122 | "type": "text", 123 | } 124 | es_index = {"mappings": es_index["mappings"]} 125 | ep.es.indices.create(index='valid_part', body=es_index, ignore=[400]) 126 | 127 | 128 | chunk_size = 10000 129 | range_list = list(range(0, valid_part.shape[0], chunk_size)) 130 | if valid_part.shape[0] not in range_list: 131 | range_list.append(valid_part.shape[0]) 132 | assert "".join(map(str ,range_list)).startswith("0") and "".join(map(str ,range_list)).endswith("{}".format(valid_part.shape[0])) 133 | 134 | for i in range(len(range_list) - 1): 135 | part_tiny = valid_part.iloc[range_list[i]:range_list[i+1]] 136 | ep.to_es(part_tiny, "valid_part") 137 | 138 | assert reduce(lambda a, b: a + b, map(lambda df: df.shape[0] ,ep.to_pandas_iter("valid_part", chunk_size = chunk_size))) == valid_part.shape[0] 139 | 140 | 141 | class SentenceBERTNegativeSampler(): 142 | """ 143 | Sample candidates from a list of candidates using dense embeddings from sentenceBERT. 144 | 145 | Args: 146 | candidates: list of str containing the candidates 147 | num_candidates_samples: int containing the number of negative samples for each query. 148 | embeddings_file: str containing the path to cache the embeddings. 149 | sample_data: int indicating amount of candidates in the index (-1 if all) 150 | pre_trained_model: str containing the pre-trained sentence embedding model, 151 | e.g. bert-base-nli-stsb-mean-tokens. 
152 | """ 153 | def __init__(self, candidates, num_candidates_samples, embeddings_file, sample_data, 154 | pre_trained_model='bert-base-nli-stsb-mean-tokens', seed=42): 155 | random.seed(seed) 156 | self.candidates = candidates 157 | self.num_candidates_samples = num_candidates_samples 158 | self.pre_trained_model = pre_trained_model 159 | 160 | #self.model = SentenceTransformer(self.pre_trained_model) 161 | self.model = SentenceTransformer(self.pre_trained_model, device = "cpu") 162 | #extract the name of the folder with the pre-trained sentence embedding 163 | if os.path.isdir(self.pre_trained_model): 164 | self.pre_trained_model = self.pre_trained_model.split("/")[-1] 165 | 166 | self.name = "SentenceBERTNS_"+self.pre_trained_model 167 | self.sample_data = sample_data 168 | self.embeddings_file = embeddings_file 169 | 170 | self._calculate_sentence_embeddings() 171 | self._build_faiss_index() 172 | 173 | def _calculate_sentence_embeddings(self): 174 | """ 175 | Calculates sentenceBERT embeddings for all candidates. 176 | """ 177 | embeds_file_path = "{}_n_sample_{}_pre_trained_model_{}".format(self.embeddings_file, 178 | self.sample_data, 179 | self.pre_trained_model) 180 | if not os.path.isfile(embeds_file_path): 181 | logging.info("Calculating embeddings for the candidates.") 182 | self.candidate_embeddings = self.model.encode(self.candidates, show_progress_bar=True) 183 | with open(embeds_file_path, 'wb') as f: 184 | pickle.dump(self.candidate_embeddings, f) 185 | else: 186 | with open(embeds_file_path, 'rb') as f: 187 | self.candidate_embeddings = pickle.load(f) 188 | 189 | def _build_faiss_index(self): 190 | """ 191 | Builds the faiss indexes containing all sentence embeddings of the candidates. 192 | """ 193 | self.index = faiss.IndexFlatL2(self.candidate_embeddings[0].shape[0]) # build the index 194 | self.index.add(np.array(self.candidate_embeddings)) 195 | logging.info("There is a total of {} candidates.".format(len(self.candidates))) 196 | logging.info("There is a total of {} candidate embeddings.".format(len(self.candidate_embeddings))) 197 | logging.info("Faiss index has a total of {} candidates".format(self.index.ntotal)) 198 | 199 | def sample(self, query_str, relevant_docs): 200 | """ 201 | Samples from a list of candidates using dot product sentenceBERT similarity. 202 | 203 | If the samples match the relevant doc, then removes it and re-samples randomly. 204 | The method uses faiss index to be efficient. 205 | 206 | Args: 207 | query_str: the str of the query to be used for the dense similarity matching. 208 | relevant_docs: list with the str of the relevant documents, to avoid sampling them as negative sample. 209 | 210 | Returns: 211 | A triplet containing the list of negative samples, 212 | whether the method had retrieved the relevant doc and 213 | if yes its rank in the list. 
214 | """ 215 | query_embedding = self.model.encode([query_str], show_progress_bar=False) 216 | 217 | distances, idxs = self.index.search(np.array(query_embedding), self.num_candidates_samples) 218 | sampled_initial = [self.candidates[idx] for idx in idxs[0]] 219 | 220 | was_relevant_sampled = False 221 | relevant_doc_rank = -1 222 | sampled = [] 223 | for i, d in enumerate(sampled_initial): 224 | if d in relevant_docs: 225 | was_relevant_sampled = True 226 | relevant_doc_rank = i 227 | else: 228 | sampled.append(d) 229 | 230 | while len(sampled) != self.num_candidates_samples: 231 | sampled = sampled + [d for d in random.sample(self.candidates, self.num_candidates_samples-len(sampled)) 232 | if d not in relevant_docs] 233 | return sampled, was_relevant_sampled, relevant_doc_rank 234 | 235 | 236 | 237 | 238 | 239 | #chunk_size = 10000 240 | #train_part = pd.concat(list(ep.to_pandas_iter("train_part", chunk_size = chunk_size)), axis = 0) 241 | candidates = train_part["answer"].tolist() 242 | 243 | 244 | num_candidates_samples = 30 245 | embeddings_file = os.path.join(os.path.abspath(""), "train_sbert_emb_cache") 246 | sample_data = -1 247 | pre_trained_model = os.path.join(os.path.abspath(""), "bi_encoder_save") 248 | sbert_sampler = SentenceBERTNegativeSampler(candidates, num_candidates_samples, embeddings_file, sample_data, 249 | pre_trained_model) 250 | 251 | 252 | def part_gen_constructor(sampler, part_df): 253 | #question_neg_dict = {} 254 | for question, df in part_df.groupby("question"): 255 | pos_answer_list = df["answer"].tolist() 256 | negs = sbert_sampler.sample(question, pos_answer_list) 257 | #negs = sbert_sampler.sample(question, []) 258 | #neg_mg_df = pd.merge(train_part_tiny, pd.DataFrame(np.asarray(negs[0]).reshape([-1, 1]), columns = ["answer"]), on = "answer", how = "inner") 259 | #question_neg_dict[question] = neg_mg_df 260 | for pos_answer in pos_answer_list: 261 | yield InputExample(texts=[question, pos_answer], label=1) 262 | for neg_answer in negs[0]: 263 | yield InputExample(texts=[question, neg_answer], label=0) 264 | 265 | 266 | def json_save(input_collection, path): 267 | assert path.endswith(".json") 268 | assert type(input_collection) in [type({}), type(set([]))] 269 | with open(path, "w", encoding = "utf-8") as f: 270 | if type(input_collection) == type({}): 271 | #json.dump(input_collection, f, encoding = "utf-8") 272 | pass 273 | else: 274 | input_collection = {path.split("/")[-1].replace(".json", ""): list(input_collection)} 275 | json.dump(input_collection, f) 276 | print("save to {}".format(path)) 277 | 278 | 279 | def produce_question_answer_sample_in_file_format(part_gen, chunck_size = 1000, save_times = 1, sub_dir = None): 280 | question_index_dict = {} 281 | answer_index_dict = {} 282 | pos_tuple_set = set([]) 283 | neg_tuple_set = set([]) 284 | have_save = 0 285 | #for idx, item_ in enumerate(part_gen): 286 | idx = 0 287 | while True: 288 | item_ = part_gen.__next__() 289 | idx += 1 290 | question, answer = item_.texts 291 | if question not in question_index_dict: 292 | question_index_dict[question] = len(question_index_dict) 293 | if answer not in answer_index_dict: 294 | answer_index_dict[answer] = len(answer_index_dict) 295 | label = item_.label 296 | assert label in [0, 1] 297 | if label == 1: 298 | pos_tuple_set.add((question_index_dict[question], answer_index_dict[answer])) 299 | else: 300 | neg_tuple_set.add((question_index_dict[question], answer_index_dict[answer])) 301 | if sub_dir is not None and not 
os.path.exists(os.path.join(os.path.abspath(""), sub_dir)): 302 | assert type(sub_dir) == type("") and "/" not in sub_dir 303 | os.mkdir(os.path.join(os.path.abspath(""), sub_dir)) 304 | if (idx + 1) % chunck_size == 0: 305 | for c in ["question_index_dict", "answer_index_dict", "pos_tuple_set", "neg_tuple_set"]: 306 | if sub_dir is None: 307 | exec("json_save({}, '{}.json')".format(c, os.path.join(os.path.abspath(""), c))) 308 | else: 309 | exec("json_save({}, '{}.json')".format(c, os.path.join(os.path.abspath(""), sub_dir, c))) 310 | have_save += 1 311 | print("have_save in {} step".format(idx + 1)) 312 | if have_save >= save_times: 313 | return 314 | 315 | 316 | train_part_gen = part_gen_constructor(sbert_sampler, train_part) 317 | produce_question_answer_sample_in_file_format(train_part_gen, chunck_size = 10000, save_times = 10000, 318 | sub_dir = "train_file_faiss_10") 319 | 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /script/cross_encoder/try_sbert_neg_sampler_valid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import json 4 | import logging 5 | import os 6 | import pickle 7 | import random 8 | import time 9 | import traceback 10 | from functools import reduce 11 | 12 | import faiss 13 | import numpy as np 14 | import pandas as pd 15 | import scipy.spatial 16 | import torch 17 | from elasticsearch import Elasticsearch, helpers 18 | from es_pandas import es_pandas 19 | from IPython import embed 20 | from sentence_transformers import InputExample, SentenceTransformer, util 21 | 22 | es_host = 'localhost:9200' 23 | 24 | class es_pandas_edit(es_pandas): 25 | @staticmethod 26 | def serialize(row, columns, use_pandas_json, iso_dates): 27 | if use_pandas_json: 28 | return json.dumps(dict(zip(columns, row)), iso_dates=iso_dates) 29 | return dict(zip(columns, [None if (all(pd.isna(r)) if (hasattr(r, "__len__") and type(r) != type("")) else pd.isna(r)) else r for r in row])) 30 | def to_pandas_iter(self, index, query_rule=None, heads=[], dtype={}, infer_dtype=False, show_progress=True, 31 | chunk_size = None, **kwargs): 32 | if query_rule is None: 33 | query_rule = {'query': {'match_all': {}}} 34 | count = self.es.count(index=index, body=query_rule)['count'] 35 | if count < 1: 36 | raise Exception('Empty for %s' % index) 37 | query_rule['_source'] = heads 38 | anl = helpers.scan(self.es, query=query_rule, index=index, **kwargs) 39 | source_iter = self.get_source(anl, show_progress = show_progress, count = count) 40 | print(source_iter) 41 | if chunk_size is None: 42 | df = pd.DataFrame(list(source_iter)).set_index('_id') 43 | if infer_dtype: 44 | dtype = self.infer_dtype(index, df.columns.values) 45 | if len(dtype): 46 | df = df.astype(dtype) 47 | yield df 48 | return 49 | assert type(chunk_size) == type(0) 50 | def map_list_of_dicts_into_df(list_of_dicts, set_index = "_id"): 51 | from collections import defaultdict 52 | req = defaultdict(list) 53 | for dict_ in list_of_dicts: 54 | for k, v in dict_.items(): 55 | req[k].append(v) 56 | req = pd.DataFrame.from_dict(req) 57 | if set_index: 58 | assert set_index in req.columns.tolist() 59 | t_df = req.set_index(set_index) 60 | if infer_dtype: 61 | dtype = self.infer_dtype(index, t_df.columns.values) 62 | if len(dtype): 63 | t_df = t_df.astype(dtype) 64 | return t_df 65 | list_of_dicts = [] 66 | for dict_ in source_iter: 67 | list_of_dicts.append(dict_) 68 | if len(list_of_dicts) >= chunk_size: 69 | 
yield map_list_of_dicts_into_df(list_of_dicts) 70 | list_of_dicts = [] 71 | if list_of_dicts: 72 | yield map_list_of_dicts_into_df(list_of_dicts) 73 | 74 | 75 | class SentenceBERTNegativeSampler(): 76 | """ 77 | Sample candidates from a list of candidates using dense embeddings from sentenceBERT. 78 | 79 | Args: 80 | candidates: list of str containing the candidates 81 | num_candidates_samples: int containing the number of negative samples for each query. 82 | embeddings_file: str containing the path to cache the embeddings. 83 | sample_data: int indicating amount of candidates in the index (-1 if all) 84 | pre_trained_model: str containing the pre-trained sentence embedding model, 85 | e.g. bert-base-nli-stsb-mean-tokens. 86 | """ 87 | def __init__(self, candidates, num_candidates_samples, embeddings_file, sample_data, 88 | pre_trained_model='bert-base-nli-stsb-mean-tokens', seed=42): 89 | random.seed(seed) 90 | self.candidates = candidates 91 | self.num_candidates_samples = num_candidates_samples 92 | self.pre_trained_model = pre_trained_model 93 | 94 | self.model = SentenceTransformer(self.pre_trained_model) 95 | #extract the name of the folder with the pre-trained sentence embedding 96 | if os.path.isdir(self.pre_trained_model): 97 | self.pre_trained_model = self.pre_trained_model.split("/")[-1] 98 | 99 | self.name = "SentenceBERTNS_"+self.pre_trained_model 100 | self.sample_data = sample_data 101 | self.embeddings_file = embeddings_file 102 | 103 | self._calculate_sentence_embeddings() 104 | self._build_faiss_index() 105 | 106 | def _calculate_sentence_embeddings(self): 107 | """ 108 | Calculates sentenceBERT embeddings for all candidates. 109 | """ 110 | embeds_file_path = "{}_n_sample_{}_pre_trained_model_{}".format(self.embeddings_file, 111 | self.sample_data, 112 | self.pre_trained_model) 113 | if not os.path.isfile(embeds_file_path): 114 | logging.info("Calculating embeddings for the candidates.") 115 | self.candidate_embeddings = self.model.encode(self.candidates, show_progress_bar=True) 116 | with open(embeds_file_path, 'wb') as f: 117 | pickle.dump(self.candidate_embeddings, f) 118 | else: 119 | with open(embeds_file_path, 'rb') as f: 120 | self.candidate_embeddings = pickle.load(f) 121 | 122 | def _build_faiss_index(self): 123 | """ 124 | Builds the faiss indexes containing all sentence embeddings of the candidates. 125 | """ 126 | self.index = faiss.IndexFlatL2(self.candidate_embeddings[0].shape[0]) # build the index 127 | self.index.add(np.array(self.candidate_embeddings)) 128 | logging.info("There is a total of {} candidates.".format(len(self.candidates))) 129 | logging.info("There is a total of {} candidate embeddings.".format(len(self.candidate_embeddings))) 130 | logging.info("Faiss index has a total of {} candidates".format(self.index.ntotal)) 131 | 132 | def sample(self, query_str, relevant_docs): 133 | """ 134 | Samples from a list of candidates using dot product sentenceBERT similarity. 135 | 136 | If the samples match the relevant doc, then removes it and re-samples randomly. 137 | The method uses faiss index to be efficient. 138 | 139 | Args: 140 | query_str: the str of the query to be used for the dense similarity matching. 141 | relevant_docs: list with the str of the relevant documents, to avoid sampling them as negative sample. 142 | 143 | Returns: 144 | A triplet containing the list of negative samples, 145 | whether the method had retrieved the relevant doc and 146 | if yes its rank in the list. 
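Implementation note: if fewer than num_candidates_samples negatives survive the relevance filter, the list is topped up with randomly drawn candidates.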
147 | """ 148 | query_embedding = self.model.encode([query_str], show_progress_bar=False) 149 | 150 | distances, idxs = self.index.search(np.array(query_embedding), self.num_candidates_samples) 151 | sampled_initial = [self.candidates[idx] for idx in idxs[0]] 152 | 153 | was_relevant_sampled = False 154 | relevant_doc_rank = -1 155 | sampled = [] 156 | for i, d in enumerate(sampled_initial): 157 | if d in relevant_docs: 158 | was_relevant_sampled = True 159 | relevant_doc_rank = i 160 | else: 161 | sampled.append(d) 162 | 163 | while len(sampled) != self.num_candidates_samples: 164 | sampled = sampled + [d for d in random.sample(self.candidates, self.num_candidates_samples-len(sampled)) 165 | if d not in relevant_docs] 166 | return sampled, was_relevant_sampled, relevant_doc_rank 167 | 168 | 169 | ep = es_pandas_edit(es_host) 170 | chunk_size = 10000 171 | valid_part = pd.concat(list(ep.to_pandas_iter("valid_part", chunk_size = chunk_size)), axis = 0) 172 | 173 | 174 | num_candidates_samples = 4 175 | embeddings_file = os.path.join(os.path.abspath(""), "valid_sbert_emb_cache") 176 | sample_data = -1 177 | pre_trained_model = os.path.join(os.path.abspath(""), "bi_encoder_save") 178 | sbert_sampler = SentenceBERTNegativeSampler(candidates, num_candidates_samples, embeddings_file, sample_data, 179 | pre_trained_model) 180 | 181 | def part_gen_constructor(sampler, part_df): 182 | #question_neg_dict = {} 183 | for question, df in part_df.groupby("question"): 184 | pos_answer_list = df["answer"].tolist() 185 | negs = sbert_sampler.sample(question, pos_answer_list) 186 | #negs = sbert_sampler.sample(question, []) 187 | #neg_mg_df = pd.merge(train_part_tiny, pd.DataFrame(np.asarray(negs[0]).reshape([-1, 1]), columns = ["answer"]), on = "answer", how = "inner") 188 | #question_neg_dict[question] = neg_mg_df 189 | for pos_answer in pos_answer_list: 190 | yield InputExample(texts=[question, pos_answer], label=1) 191 | for neg_answer in negs[0]: 192 | yield InputExample(texts=[question, neg_answer], label=0) 193 | 194 | def json_save(input_collection, path): 195 | assert path.endswith(".json") 196 | assert type(input_collection) in [type({}), type(set([]))] 197 | with open(path, "w", encoding = "utf-8") as f: 198 | if type(input_collection) == type({}): 199 | #json.dump(input_collection, f, encoding = "utf-8") 200 | pass 201 | else: 202 | input_collection = {path.split("/")[-1].replace(".json", ""): list(input_collection)} 203 | json.dump(input_collection, f) 204 | print("save to {}".format(path)) 205 | 206 | 207 | 208 | def produce_question_answer_sample_in_file_format(part_gen, chunck_size = 1000, save_times = 1, sub_dir = None): 209 | question_index_dict = {} 210 | answer_index_dict = {} 211 | pos_tuple_set = set([]) 212 | neg_tuple_set = set([]) 213 | have_save = 0 214 | #for idx, item_ in enumerate(part_gen): 215 | idx = 0 216 | while True: 217 | item_ = part_gen.__next__() 218 | idx += 1 219 | question, answer = item_.texts 220 | if question not in question_index_dict: 221 | question_index_dict[question] = len(question_index_dict) 222 | if answer not in answer_index_dict: 223 | answer_index_dict[answer] = len(answer_index_dict) 224 | label = item_.label 225 | assert label in [0, 1] 226 | if label == 1: 227 | pos_tuple_set.add((question_index_dict[question], answer_index_dict[answer])) 228 | else: 229 | neg_tuple_set.add((question_index_dict[question], answer_index_dict[answer])) 230 | if sub_dir is not None and not os.path.exists(os.path.join(os.path.abspath(""), sub_dir)): 231 | assert 
type(sub_dir) == type("") and "/" not in sub_dir 232 | os.mkdir(os.path.join(os.path.abspath(""), sub_dir)) 233 | if (idx + 1) % chunck_size == 0: 234 | for c in ["question_index_dict", "answer_index_dict", "pos_tuple_set", "neg_tuple_set"]: 235 | if sub_dir is None: 236 | exec("json_save({}, '{}.json')".format(c, os.path.join(os.path.abspath(""), c))) 237 | else: 238 | exec("json_save({}, '{}.json')".format(c, os.path.join(os.path.abspath(""), sub_dir, c))) 239 | have_save += 1 240 | print("have_save in {} step".format(idx + 1)) 241 | if have_save >= save_times: 242 | return 243 | 244 | 245 | valid_part_gen = part_gen_constructor(sbert_sampler, valid_part) 246 | 247 | 248 | 249 | produce_question_answer_sample_in_file_format(valid_part_gen, chunck_size = 3000, save_times = 10000, 250 | sub_dir = "valid_file_faiss") 251 | 252 | 253 | 254 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /script/cross_encoder/valid_cross_encoder_on_bi_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import gzip 4 | import json 5 | import logging 6 | import os 7 | import tarfile 8 | import time 9 | from datetime import datetime 10 | from functools import partial, reduce 11 | from glob import glob 12 | from typing import Callable, Dict, List, Type 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import seaborn as sns 17 | import torch 18 | import tqdm 19 | from elasticsearch import Elasticsearch, helpers 20 | from es_pandas import es_pandas 21 | from sentence_transformers import (InputExample, LoggingHandler, 22 | SentenceTransformer, util) 23 | from sentence_transformers.cross_encoder import CrossEncoder 24 | from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator 25 | from sklearn.metrics import balanced_accuracy_score 26 | from torch.utils.data import DataLoader 27 | 28 | pd.set_option("display.max_rows", 200) 29 | es_host = 'localhost:9200' 30 | 31 | 32 | bi_model_path = os.path.join(os.path.dirname("__file__"), os.path.pardir, "bi_encoder_save/") 33 | bi_model = SentenceTransformer(bi_model_path, device = "cpu") 34 | 35 | 36 | cross_model_path = "output/training_ms-marco_cross-encoder-xlm-roberta-base-2021-01-17_14-43-23_map-train-eval" 37 | cross_model = CrossEncoder(cross_model_path, num_labels=1, max_length=512, device = "cpu") 38 | 39 | 40 | class es_pandas_edit(es_pandas): 41 | @staticmethod 42 | def serialize(row, columns, use_pandas_json, iso_dates): 43 | if use_pandas_json: 44 | return json.dumps(dict(zip(columns, row)), iso_dates=iso_dates) 45 | return dict(zip(columns, [None if (all(pd.isna(r)) if (hasattr(r, "__len__") and type(r) != type("")) else pd.isna(r)) else r for r in row])) 46 | def to_pandas_iter(self, index, query_rule=None, heads=[], dtype={}, infer_dtype=False, show_progress=True, 47 | chunk_size = None, **kwargs): 48 | if query_rule is None: 49 | query_rule = {'query': {'match_all': {}}} 50 | count = self.es.count(index=index, body=query_rule)['count'] 51 | if count < 1: 52 | raise Exception('Empty for %s' % index) 53 | query_rule['_source'] = heads 54 | anl = helpers.scan(self.es, query=query_rule, index=index, **kwargs) 55 | source_iter = self.get_source(anl, show_progress = show_progress, count = count) 56 | print(source_iter) 57 | if chunk_size is None: 58 | df = pd.DataFrame(list(source_iter)).set_index('_id') 59 | if infer_dtype: 60 | dtype = self.infer_dtype(index, df.columns.values) 61 | if 
len(dtype): 62 | df = df.astype(dtype) 63 | yield df 64 | return 65 | assert type(chunk_size) == type(0) 66 | def map_list_of_dicts_into_df(list_of_dicts, set_index = "_id"): 67 | from collections import defaultdict 68 | req = defaultdict(list) 69 | for dict_ in list_of_dicts: 70 | for k, v in dict_.items(): 71 | req[k].append(v) 72 | req = pd.DataFrame.from_dict(req) 73 | if set_index: 74 | assert set_index in req.columns.tolist() 75 | t_df = req.set_index(set_index) 76 | if infer_dtype: 77 | dtype = self.infer_dtype(index, t_df.columns.values) 78 | if len(dtype): 79 | t_df = t_df.astype(dtype) 80 | return t_df 81 | list_of_dicts = [] 82 | for dict_ in source_iter: 83 | list_of_dicts.append(dict_) 84 | if len(list_of_dicts) >= chunk_size: 85 | yield map_list_of_dicts_into_df(list_of_dicts) 86 | list_of_dicts = [] 87 | if list_of_dicts: 88 | yield map_list_of_dicts_into_df(list_of_dicts) 89 | 90 | 91 | ep = es_pandas_edit(es_host) 92 | ep.ic.get_alias("*") 93 | 94 | chunk_size = 1000 95 | valid_part_from_es_iter = ep.to_pandas_iter(index = "valid_part", chunk_size = chunk_size) 96 | 97 | 98 | valid_part_tiny = None 99 | for ele in valid_part_from_es_iter: 100 | valid_part_tiny = ele 101 | break 102 | del valid_part_from_es_iter 103 | 104 | 105 | if ep.ic.exists("valid_part_tiny"): 106 | ep.ic.delete(index = "valid_part_tiny") 107 | 108 | 109 | ep.init_es_tmpl(valid_part_tiny, "valid_part_tiny_doc_type", delete=True) 110 | valid_part_tmp = ep.es.indices.get_template("valid_part_tiny_doc_type") 111 | 112 | 113 | es_index = valid_part_tmp["valid_part_tiny_doc_type"] 114 | es_index["mappings"]["properties"]["question_emb"] = { 115 | "type": "dense_vector", 116 | "dims": 768 117 | } 118 | es_index["mappings"]["properties"]["answer_emb"] = { 119 | "type": "dense_vector", 120 | "dims": 768 121 | } 122 | es_index["mappings"]["properties"]["question"] = { 123 | "type": "text", 124 | } 125 | es_index["mappings"]["properties"]["answer"] = { 126 | "type": "text", 127 | } 128 | es_index = {"mappings": es_index["mappings"]} 129 | 130 | 131 | ep.es.indices.create(index='valid_part_tiny', body=es_index, ignore=[400]) 132 | question_embeddings = bi_model.encode(valid_part_tiny["question"].tolist(), convert_to_tensor=True, show_progress_bar=True) 133 | answer_embeddings = bi_model.encode(valid_part_tiny["answer"].tolist(), convert_to_tensor=True, show_progress_bar=True) 134 | 135 | valid_part_tiny["question_emb"] = question_embeddings.cpu().numpy().tolist() 136 | valid_part_tiny["answer_emb"] = answer_embeddings.cpu().numpy().tolist() 137 | 138 | ep.to_es(valid_part_tiny, "valid_part_tiny") 139 | 140 | chunk_size = 1000 141 | valid_part_tiny = list(ep.to_pandas_iter(index = "valid_part_tiny", chunk_size = None))[0] 142 | 143 | 144 | def search_by_embedding_in_es(index = "valid_part" ,embedding = np.asarray(valid_part_tiny["question_emb"].iloc[0]), on_column = "answer_emb"): 145 | vector_search_one = ep.es.search(index=index, body={ 146 | "query": { 147 | "script_score": { 148 | "query": { 149 | "match_all": {} 150 | }, 151 | "script": { 152 | "source": "cosineSimilarity(params.queryVector, doc['{}']) + 1.0".format(on_column), 153 | "params": { 154 | "queryVector": embedding 155 | } 156 | } 157 | } 158 | } 159 | }, ignore = [400]) 160 | req = list(map(lambda x: (x["_source"]["question"], x["_source"]["answer"], x["_score"]) ,vector_search_one["hits"]["hits"])) 161 | req_df = pd.DataFrame(req, columns = ["question", "answer", "score"]) 162 | return req_df 163 | 164 | 165 | def search_by_text_in_es(index = 
"valid_part" ,text = valid_part_tiny["question"].iloc[0], on_column = "answer", 166 | analyzer = "smartcn"): 167 | if analyzer is not None: 168 | bm25 = es.search(index = index, 169 | body={"query": 170 | { 171 | "match": {on_column:{"query" :text, "analyzer": analyzer} }, 172 | 173 | } 174 | }, 175 | ) 176 | else: 177 | bm25 = ep.es.search(index=index, body={"query": {"match": {on_column: text}}}) 178 | req = list(map(lambda x: (x["_source"]["question"], x["_source"]["answer"], x["_score"]) ,bm25["hits"]["hits"])) 179 | req_df = pd.DataFrame(req, columns = ["question", "answer", "score"]) 180 | return req_df 181 | 182 | 183 | def valid_two_model(cross_model, ep, index, question, question_embedding, on_column = "answer_emb", size = 10): 184 | def search_by_embedding(ep ,index = "valid_part" ,embedding = np.asarray(valid_part_tiny["question_emb"].iloc[0]), on_column = "answer_emb"): 185 | vector_search_one = ep.es.search(index=index, body={ 186 | "size": size, 187 | "query": { 188 | "script_score": { 189 | "query": { 190 | "match_all": {} 191 | }, 192 | "script": { 193 | "source": "cosineSimilarity(params.queryVector, doc['{}']) + 1.0".format(on_column), 194 | "params": { 195 | "queryVector": embedding 196 | } 197 | } 198 | } 199 | } 200 | }, ignore = [400]) 201 | req = list(map(lambda x: (x["_source"]["question"], x["_source"]["answer"], x["_score"]) ,vector_search_one["hits"]["hits"])) 202 | req_df = pd.DataFrame(req, columns = ["question", "answer", "score"]) 203 | return req_df 204 | search_by_emb = search_by_embedding(ep ,index = index, embedding = question_embedding, on_column = on_column) 205 | print("question : {}".format(question)) 206 | preds = cross_model.predict(search_by_emb.apply(lambda r: [question, r["answer"]], axis = 1).tolist()) 207 | search_by_emb["cross_score"] = preds.tolist() 208 | return search_by_emb 209 | def produce_df(question, size = 10): 210 | question, question_embedding = valid_part_tiny[valid_part_tiny["question"] == question].iloc[0][["question", "question_emb"]] 211 | valid_df = valid_two_model(cross_model, ep, index = "valid_part_tiny", question = question, question_embedding = question_embedding, size = size) 212 | return valid_df 213 | 214 | 215 | class ScoreCalculator(object): 216 | def __init__(self, 217 | queries_ids, 218 | relevant_docs, 219 | mrr_at_k: List[int] = [10], 220 | ndcg_at_k: List[int] = [10], 221 | accuracy_at_k: List[int] = [1, 3, 5, 10], 222 | precision_recall_at_k: List[int] = [1, 3, 5, 10], 223 | map_at_k: List[int] = [100], 224 | ): 225 | "queries_ids list of query, relevant_docs key query value set or list of relevant_docs" 226 | self.queries_ids = queries_ids 227 | self.relevant_docs = relevant_docs 228 | 229 | self.mrr_at_k = mrr_at_k 230 | self.ndcg_at_k = ndcg_at_k 231 | self.accuracy_at_k = accuracy_at_k 232 | self.precision_recall_at_k = precision_recall_at_k 233 | self.map_at_k = map_at_k 234 | def compute_metrics(self, queries_result_list: List[object]): 235 | # Init score computation values 236 | num_hits_at_k = {k: 0 for k in self.accuracy_at_k} 237 | precisions_at_k = {k: [] for k in self.precision_recall_at_k} 238 | recall_at_k = {k: [] for k in self.precision_recall_at_k} 239 | MRR = {k: 0 for k in self.mrr_at_k} 240 | ndcg = {k: [] for k in self.ndcg_at_k} 241 | AveP_at_k = {k: [] for k in self.map_at_k} 242 | 243 | # Compute scores on results 244 | #### elements with hits one hit is a dict : {"corpus_id": corpus_text, "score": score} 245 | #### corpus_id replace by corpus text 246 | for query_itr in 
range(len(queries_result_list)): 247 | query_id = self.queries_ids[query_itr] 248 | 249 | # Sort scores 250 | top_hits = sorted(queries_result_list[query_itr], key=lambda x: x['score'], reverse=True) 251 | query_relevant_docs = self.relevant_docs[query_id] 252 | 253 | # Accuracy@k - count the result as correct if at least one relevant doc is among the top-k documents 254 | for k_val in self.accuracy_at_k: 255 | for hit in top_hits[0:k_val]: 256 | if hit['corpus_id'] in query_relevant_docs: 257 | num_hits_at_k[k_val] += 1 258 | break 259 | 260 | # Precision and Recall@k 261 | for k_val in self.precision_recall_at_k: 262 | num_correct = 0 263 | for hit in top_hits[0:k_val]: 264 | if hit['corpus_id'] in query_relevant_docs: 265 | num_correct += 1 266 | 267 | precisions_at_k[k_val].append(num_correct / k_val) 268 | recall_at_k[k_val].append(num_correct / len(query_relevant_docs)) 269 | 270 | # MRR@k 271 | for k_val in self.mrr_at_k: 272 | for rank, hit in enumerate(top_hits[0:k_val]): 273 | if hit['corpus_id'] in query_relevant_docs: 274 | MRR[k_val] += 1.0 / (rank + 1) 275 | break # MRR only credits the first relevant hit 276 | 277 | # NDCG@k 278 | for k_val in self.ndcg_at_k: 279 | predicted_relevance = [1 if top_hit['corpus_id'] in query_relevant_docs else 0 for top_hit in top_hits[0:k_val]] 280 | true_relevances = [1] * len(query_relevant_docs) 281 | 282 | ndcg_value = self.compute_dcg_at_k(predicted_relevance, k_val) / self.compute_dcg_at_k(true_relevances, k_val) 283 | ndcg[k_val].append(ndcg_value) 284 | 285 | # MAP@k 286 | for k_val in self.map_at_k: 287 | num_correct = 0 288 | sum_precisions = 0 289 | 290 | for rank, hit in enumerate(top_hits[0:k_val]): 291 | if hit['corpus_id'] in query_relevant_docs: 292 | num_correct += 1 293 | sum_precisions += num_correct / (rank + 1) 294 | 295 | avg_precision = sum_precisions / min(k_val, len(query_relevant_docs)) 296 | AveP_at_k[k_val].append(avg_precision) 297 | 298 | # Compute averages 299 | for k in num_hits_at_k: 300 | #num_hits_at_k[k] /= len(self.queries) 301 | num_hits_at_k[k] /= len(queries_result_list) 302 | 303 | for k in precisions_at_k: 304 | precisions_at_k[k] = np.mean(precisions_at_k[k]) 305 | 306 | for k in recall_at_k: 307 | recall_at_k[k] = np.mean(recall_at_k[k]) 308 | 309 | for k in ndcg: 310 | ndcg[k] = np.mean(ndcg[k]) 311 | 312 | for k in MRR: 313 | #MRR[k] /= len(self.queries) 314 | MRR[k] /= len(queries_result_list) 315 | 316 | for k in AveP_at_k: 317 | AveP_at_k[k] = np.mean(AveP_at_k[k]) 318 | return {'accuracy@k': num_hits_at_k, 'precision@k': precisions_at_k, 'recall@k': recall_at_k, 'ndcg@k': ndcg, 'mrr@k': MRR, 'map@k': AveP_at_k} 319 | @staticmethod 320 | def compute_dcg_at_k(relevances, k): 321 | dcg = 0 322 | for i in range(min(len(relevances), k)): 323 | dcg += relevances[i] / np.log2(i + 2) #+2 as we start our idx at 0 324 | return dcg 325 | 326 | 327 | def map_dev_samples_to_score_calculator_format(dev_samples): 328 | if isinstance(dev_samples, dict): 329 | dev_samples = list(dev_samples.values()) 330 | queries_ids = list(map(lambda x: x["query"], dev_samples)) 331 | relevant_docs = dict(map(lambda idx: (dev_samples[idx]["query"], dev_samples[idx]["positive"]), range(len(dev_samples)))) 332 | return ScoreCalculator(queries_ids, relevant_docs) 333 | 334 | def map_valid_df_to_score_calculator_format(query, valid_df): 335 | queries_ids = [query] 336 | relevant_docs = {query: valid_df[valid_df["question"] == query]["answer"].tolist()} 337 | return ScoreCalculator(queries_ids, relevant_docs) 338 | 339 | 340 | def df_to_mrr_score(df, query, score_col, 
mrr_at_k = 10): 341 | #model_input = [[query, doc] for doc in docs] 342 | #pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False) 343 | is_relevant = list(map(lambda t2: True if t2[1]["question"] == query else False, df.iterrows())) 344 | pred_scores = df[score_col].values 345 | pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order 346 | mrr_score = 0 347 | for rank, index in enumerate(pred_scores_argsort[0:mrr_at_k]): 348 | if is_relevant[index]: 349 | mrr_score = 1 / (rank+1) 350 | #mrr_score += 1 / (rank+1) 351 | break 352 | return mrr_score 353 | 354 | 355 | question_list = valid_part_tiny["question"].value_counts().index.tolist() 356 | valid_df = produce_df(question_list[10], size = 100) 357 | 358 | def produce_score_dict(query ,valid_df, column = "score"): 359 | queries_result_list = valid_df[["answer", column]].apply(lambda x: {"corpus_id": x["answer"], "score": x[column]}, axis = 1).tolist() 360 | score_dict = map_valid_df_to_score_calculator_format(query, valid_df).compute_metrics([queries_result_list]) 361 | return score_dict 362 | 363 | produce_score_dict(question_list[10] ,valid_df, "score") 364 | produce_score_dict(question_list[10] ,valid_df, "cross_score") 365 | produce_score_dict(question_list[10] ,valid_df.head(20), "score") 366 | produce_score_dict(question_list[10] ,valid_df.head(20), "cross_score") 367 | 368 | valid_df.head(20) 369 | valid_df.head(20).sort_values(by = "cross_score", ascending = False) 370 | valid_df.sort_values(by = "cross_score", ascending = False).head(10) 371 | 372 | sns.distplot(valid_df["cross_score"]) 373 | 374 | 375 | --------------------------------------------------------------------------------