├── .spyproject
│   ├── codestyle.ini
│   ├── encoding.ini
│   ├── vcs.ini
│   └── workspace.ini
├── README.md
├── README_EN.md
├── requirements.txt
└── script
    ├── bi_encoder
    │   ├── bi-encoder-batch.py
    │   └── bi-encoder-data.py
    └── cross_encoder
        ├── cross_encoder_random_on_multi_eval.py
        ├── try_sbert_neg_sampler.py
        ├── try_sbert_neg_sampler_valid.py
        └── valid_cross_encoder_on_bi_encoder.py
/.spyproject/codestyle.ini:
--------------------------------------------------------------------------------
1 | [codestyle]
2 | indentation = True
3 |
4 | [main]
5 | version = 0.1.0
6 |
7 |
--------------------------------------------------------------------------------
/.spyproject/encoding.ini:
--------------------------------------------------------------------------------
1 | [encoding]
2 | text_encoding = utf-8
3 |
4 | [main]
5 | version = 0.1.0
6 |
7 |
--------------------------------------------------------------------------------
/.spyproject/vcs.ini:
--------------------------------------------------------------------------------
1 | [vcs]
2 | use_version_control = False
3 | version_control_system =
4 |
5 | [main]
6 | version = 0.1.0
7 |
8 |
--------------------------------------------------------------------------------
/.spyproject/workspace.ini:
--------------------------------------------------------------------------------
1 | [workspace]
2 | restore_data_on_startup = True
3 | save_data_on_exit = True
4 | save_history = True
5 | save_non_project_files = False
6 |
7 | [main]
8 | version = 0.1.0
9 | recent_files = ['/home/svjack/.config/spyder-py3/temp.py', '/home/svjack/temp_dir/bi_cross_model/script/bi-encoder-batch.py', '/home/svjack/temp_dir/bi_cross_model/script/bi-encoder-data.py', '/home/svjack/temp_dir/bi_cross_model/script/bi_encoder/bi-encoder-batch.py', '/home/svjack/temp_dir/bi_cross_model/script/bi_encoder/bi-encoder-data.py', '/home/svjack/temp_dir/bi_cross_model/download.sh', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_data_prepare_train.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_data_prepare.py', '/home/svjack/temp_dir/bi_cross_model/script/try_sbert_neg_sampler.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_enccoder_v0.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/try_sbert_neg_sampler.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/try_sbert_neg_sampler_valid.py', '/home/svjack/temp_dir/bi_cross_model/script/choose_right_params.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder_random_on_multi_eval.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/cross_encoder_random_on_multi_eval.py', '/home/svjack/temp_dir/bi_cross_model/script/cross_encoder/valid_cross_encoder_on_bi_encoder.py']
10 |
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sbert-ChineseExample
42 |
43 |
44 | Sentence-Transformers 中文信息检索例子
45 |
46 |
47 |
48 |
49 |
50 |
51 | [In English](README_EN.md)
52 |
53 | ## 内容提要
54 | * [关于这个工程](#about-the-project)
55 | * [构建信息](#built-with)
56 | * [开始](#getting-started)
57 | * [安装](#installation)
58 | * [使用](#usage)
59 | * [引导](#roadmap)
60 | * [贡献](#contributing)
61 | * [License](#license)
62 | * [Contact](#contact)
63 | * [Acknowledgements](#acknowledgements)
64 |
65 |
66 | ## 关于这个工程
67 | ## About The Project
68 |
69 |
72 |
73 |
74 |
77 |
102 | Sentence Transformers是一个多语言、多模态句子向量生成框架,可以根据Huggingface Transformers框架简单地生成句子及文本段落的分布式向量表征。
103 |
104 | 这个工程的目的是通过训练bi_encoder和cross_encoder实现类似于 MS MARCO 任务的中文数据集信息检索,并搭配定制化的pandas形式的elasticsearch接口使得结果产出(文本、向量)可以方便地序列化。
105 |
106 |
116 | ### 构建信息
117 | ### Built With
118 |
131 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers)
132 | * [Elasticsearch](https://github.com/elastic/elasticsearch)
133 | * [Faiss](https://github.com/facebookresearch/faiss)
134 |
135 |
136 | ## 开始
137 | ## Getting Started
138 |
142 |
143 |
152 | ### 安装
153 | ### Installation
154 | * pip
155 | ```sh
156 | pip install -r requirements.txt
157 | ```
158 | * 安装Elasticsearch并启动服务
159 | * install Elasticsearch and start service
160 |
161 |
162 |
163 | ## 使用
164 | ## Usage
165 |
170 |
171 |
172 |
173 | 1. 从 google drive 下载数据集
174 |
175 |
176 |
177 | 2. bi_encoder 数据准备
178 |
179 |
180 |
181 | 3. 训练 bi_encoder
182 |
183 |
184 |
185 | 4. cross_encoder 训练数据准备
186 |
187 |
188 |
189 | 5. cross_encoder 检测数据准备
190 |
191 |
192 |
193 | 6. 训练 cross_encoder
194 |
195 |
196 |
197 | 7. 展示 bi_encoder cross_encoder 的推断过程
198 |
199 |
200 |
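    | 下面给出一个可能的运行顺序示例(各步骤与脚本的对应关系是根据脚本名推测的; 第 1 步下载的数据需放到 bi-encoder-data.py 中 data_dir 所指向的位置):
    |
    | ```sh
    | python script/bi_encoder/bi-encoder-data.py                        # 2. bi_encoder 数据准备
    | python script/bi_encoder/bi-encoder-batch.py                       # 3. 训练 bi_encoder
    | python script/cross_encoder/try_sbert_neg_sampler.py               # 4. cross_encoder 训练数据准备
    | python script/cross_encoder/try_sbert_neg_sampler_valid.py         # 5. cross_encoder 检测数据准备
    | python script/cross_encoder/cross_encoder_random_on_multi_eval.py  # 6. 训练 cross_encoder
    | python script/cross_encoder/valid_cross_encoder_on_bi_encoder.py   # 7. 展示推断过程
    | ```
    |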
201 | ## 引导
202 | ## Roadmap
203 |
206 |
212 |
213 | * 1 这个工程使用自定义重载的 es-pandas 接口(支持向量存储), 通过 pandas 对 elasticsearch 进行简单操作。
214 |
215 | * 2 try_sbert_neg_sampler.py 抽取困难样本(模型识别困难的样本)的功能来自于
216 | https://guzpenha.github.io/transformer_rankers/,
217 | 也可以使用 elasticsearch 生成困难样本, 相应的功能在 valid_cross_encoder_on_bi_encoder.py 中定义。
218 |
219 | * 3 在 cross_encoder 上训练之前, 需要预先检查不同问句之间的语义相似程度,
220 | 组合语义相似的样本对模型训练是有帮助的。
221 |
222 | * 4 增加了一些对Sentence-Transformers多类别结果比较的工具。
223 |
224 |
225 |
226 | ## 贡献
227 | ## Contributing
228 |
237 |
238 |
239 |
240 | ## License
241 |
242 | Distributed under the MIT License. See `LICENSE` for more information.
243 |
244 |
245 |
246 |
247 | ## Contact
248 |
249 |
252 | svjack - svjackbt@gmail.com
253 | ehangzhou@outlook.com
254 |
255 |
258 | Project Link: [https://github.com/svjack/Sbert-ChineseExample](https://github.com/svjack/Sbert-ChineseExample)
259 |
260 |
261 |
262 | ## Acknowledgements
263 |
276 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers)
277 | * [Elasticsearch](https://github.com/elastic/elasticsearch)
278 | * [Faiss](https://github.com/facebookresearch/faiss)
279 | * [Transformer Rankers](https://github.com/Guzpenha/transformer_rankers)
280 | * [es_pandas](https://github.com/fuyb1992/es_pandas)
281 |
282 |
283 |
284 |
298 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
# Sbert-ChineseExample
52 |
53 |
54 |
55 |
56 |
57 | A Sentence-Transformers information retrieval example on Chinese data
58 |
59 |
62 |
63 |
64 |
68 |
74 |
75 |
76 |
77 | [中文介绍](README.md)
78 |
79 |
80 | ## Table of Contents
81 |
82 | * [About the Project](#about-the-project)
83 | * [Built With](#built-with)
84 | * [Getting Started](#getting-started)
85 |
88 | * [Installation](#installation)
89 | * [Usage](#usage)
90 | * [Roadmap](#roadmap)
91 | * [Contributing](#contributing)
92 | * [License](#license)
93 | * [Contact](#contact)
94 | * [Acknowledgements](#acknowledgements)
95 |
96 |
97 |
98 |
99 | ## About The Project
100 |
101 |
104 |
105 |
106 |
109 |
134 | Sentence Transformers is a multilingual sentence embedding framework that provides an easy way to compute dense
135 | vector representations for sentences and paragraphs (based on HuggingFace Transformers).
136 |
137 | This repository targets an MS MARCO-like information retrieval task on a Chinese dataset: it trains a bi_encoder and a cross_encoder,
138 | and uses a simple pandas interface to Elasticsearch so that the resulting outputs (text and vectors) can be serialized easily.
139 |
140 |
150 | ### Built With
151 |
164 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers)
165 | * [Elasticsearch](https://github.com/elastic/elasticsearch)
166 | * [Faiss](https://github.com/facebookresearch/faiss)
167 |
168 |
169 | ## Getting Started
170 |
174 |
175 |
184 |
185 | ### Installation
186 | * pip
187 | ```sh
188 | pip install -r requirements.txt
189 | ```
190 | * install Elasticsearch and start service
191 |
206 |
207 |
208 |
209 | ## Usage
210 |
215 |
216 |
217 |
218 | 1. Download the dataset from Google Drive
219 |
220 |
221 |
222 | 2. Prepare the bi_encoder data
223 |
224 |
225 |
226 | 3. Train the bi_encoder
227 |
228 |
229 |
230 | 4. Prepare the cross_encoder training data
231 |
232 |
233 |
234 | 5. Prepare the cross_encoder validation data
235 |
236 |
237 |
238 | 6. Train the cross_encoder
239 |
240 |
241 |
242 | 7. Show bi_encoder / cross_encoder inference
243 |
244 |
245 |
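    | A possible run order (the step-to-script mapping below is inferred from the script names; the data from step 1 should be placed where data_dir in bi-encoder-data.py points):
    |
    | ```sh
    | python script/bi_encoder/bi-encoder-data.py                        # 2. bi_encoder data preparation
    | python script/bi_encoder/bi-encoder-batch.py                       # 3. train the bi_encoder
    | python script/cross_encoder/try_sbert_neg_sampler.py               # 4. cross_encoder training data preparation
    | python script/cross_encoder/try_sbert_neg_sampler_valid.py         # 5. cross_encoder validation data preparation
    | python script/cross_encoder/cross_encoder_random_on_multi_eval.py  # 6. train the cross_encoder
    | python script/cross_encoder/valid_cross_encoder_on_bi_encoder.py   # 7. show bi_encoder / cross_encoder inference
    | ```
    |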
246 | ## Roadmap
247 |
250 |
256 |
257 | * 1 This repository uses a modified es-pandas interface (with support for serializing vectors) to allow simple manipulation of Elasticsearch through pandas.
258 |
259 | * 2 try_sbert_neg_sampler.py samples hard negatives (samples the model finds difficult) using a class derived from
260 | https://guzpenha.github.io/transformer_rankers/.
261 | Elasticsearch can also be used to generate hard samples; the related functions are defined in valid_cross_encoder_on_bi_encoder.py.
262 |
263 | * 3 Before training the cross_encoder on your dataset, you should look at the semantic similarity between different questions.
264 | Combining samples with similar semantics may help training.
265 |
266 | * 4 Adds some tooling to Sentence-Transformers to support multi-metric evaluation (results returned as a dictionary).
267 |
268 | ## Contributing
269 |
278 |
279 |
280 |
281 | ## License
282 |
283 | Distributed under the MIT License. See `LICENSE` for more information.
284 |
285 |
286 |
287 |
288 | ## Contact
289 |
290 |
293 | svjack - svjackbt@gmail.com
294 |
295 |
298 | Project Link: [https://github.com/svjack/Sbert-ChineseExample](https://github.com/svjack/Sbert-ChineseExample)
299 |
300 |
301 |
302 | ## Acknowledgements
303 |
316 | * [Sentence Transformers](https://github.com/UKPLab/sentence-transformers)
317 | * [Elasticsearch](https://github.com/elastic/elasticsearch)
318 | * [Faiss](https://github.com/facebookresearch/faiss)
319 | * [Transformer Rankers](https://github.com/Guzpenha/transformer_rankers)
320 | * [es_pandas](https://github.com/fuyb1992/es_pandas)
321 |
322 |
323 |
324 |
338 |
339 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | editdistance==0.5.3
2 | elasticsearch==7.8.1
3 | elasticsearch-dbapi==0.1.3
4 | es-pandas==0.0.16
5 | progressbar2==3.53.1
6 | seaborn==0.10.1
7 | sentence-transformers==0.3.9
8 |
--------------------------------------------------------------------------------
/script/bi_encoder/bi-encoder-batch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import gzip
4 | import logging
5 | import math
6 | import os
7 | import random
8 | import tarfile
9 | from collections import defaultdict
10 | from datetime import datetime
11 | from glob import glob
12 |
13 | import numpy as np
14 | import pandas as pd
15 | import torch
16 | from sentence_transformers import (LoggingHandler, SentenceTransformer,
17 | evaluation, losses, models, util)
18 | from torch.utils.data import DataLoader, Dataset, IterableDataset
19 |
20 | #train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join(os.path.abspath(""), "{}_part.csv".format(save_type))).dropna(), ["train", "test", "valid"])
21 | train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join("../data/", "{}_part.csv".format(save_type))).dropna(), ["train", "test", "valid"])
22 |
23 | from sentence_transformers import InputExample
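    | # Triplet dataset: each item pairs a question with its own answer (positive) and an answer
    | # sampled from a different q_idx group (negative), wrapped in an InputExample.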
24 | class TripletsDataset(Dataset):
25 | def __init__(self, model, qa_df):
26 | assert set(["question", "answer", "q_idx"]).intersection(set(qa_df.columns.tolist())) == set(["question", "answer", "q_idx"])
27 | self.model = model
28 | self.qa_df = qa_df
29 | self.q_idx_set = set(qa_df["q_idx"].value_counts().index.tolist())
30 |
31 | def __getitem__(self, index):
32 | #raise NotImplementedError
33 | label = torch.tensor(1, dtype=torch.long)
34 | choice_s = self.qa_df.iloc[index]
35 |         query_text, pos_text, q_idx = choice_s.loc["question"], choice_s.loc["answer"], choice_s.loc["q_idx"]
37 | neg_q_idx = np.random.choice(list(self.q_idx_set.difference(set([q_idx]))))
38 | neg_text = self.qa_df[self.qa_df["q_idx"] == neg_q_idx].sample()["answer"].iloc[0]
39 | #### InputExample(texts=['I can\'t log in to my account.',
40 | #'Unable to access my account.',
41 | #'I need help with the payment process.'],
42 | #label=1),
43 | return InputExample(texts = [query_text, pos_text, neg_text], label = 1)
44 | '''
45 | return [self.model.tokenize(query_text),
46 | self.model.tokenize(pos_text),
47 | self.model.tokenize(neg_text)], label
48 | '''
49 | #return (query_text, pos_text, q_idx)
50 |
51 | def __len__(self):
52 | return self.qa_df.shape[0]
53 |
54 |
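    | # Batch sampler that never places two examples with the same q_idx in one batch, so the
    | # in-batch negatives used by MultipleNegativesRankingLoss always come from different questions.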
55 | class NoSameLabelsBatchSampler:
56 | def __init__(self, dataset, batch_size):
57 | self.dataset = dataset
58 | self.idx_org = list(range(len(dataset)))
59 | random.shuffle(self.idx_org)
60 | self.idx_copy = self.idx_org.copy()
61 | self.batch_size = batch_size
62 |
63 | def __iter__(self):
64 | batch = []
65 | labels = set()
66 | num_miss = 0
67 |
68 | num_batches_returned = 0
69 | while num_batches_returned < self.__len__():
70 | if len(self.idx_copy) == 0:
71 | random.shuffle(self.idx_org)
72 | self.idx_copy = self.idx_org.copy()
73 |
74 | idx = self.idx_copy.pop(0)
75 | #label = self.dataset[idx][1].cpu().tolist()
76 | label = self.dataset.qa_df["q_idx"].iloc[idx]
77 |
78 | if label not in labels:
79 | num_miss = 0
80 | batch.append(idx)
81 | labels.add(label)
82 | if len(batch) == self.batch_size:
83 | yield batch
84 | batch = []
85 | labels = set()
86 | num_batches_returned += 1
87 | else:
88 | num_miss += 1
89 | self.idx_copy.append(idx) #Add item again to the end
90 |
91 |                 if num_miss >= len(self.idx_copy): #Too many failures: flush idx_copy and start with a clean copy
92 | self.idx_copy = []
93 |
94 | def __len__(self):
95 | return math.ceil(len(self.dataset) / self.batch_size)
96 |
97 |
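    | # Convert a (question, answer) dataframe into the dicts expected by InformationRetrievalEvaluator:
    | # queries {qid: question}, corpus {cid: answer} and relevant_docs {qid: set of relevant cids}.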
98 | def transform_part_df_into_Evaluator_format(part_df):
99 | req = part_df.copy()
100 | req["qid"] = req["question"].fillna("").map(hash).map(str)
101 | req["cid"] = req["answer"].fillna("").map(hash).map(str)
102 | queries = dict(map(tuple ,req[["qid", "question"]].drop_duplicates().values.tolist()))
103 | corpus = dict(map(tuple ,req[["cid", "answer"]].drop_duplicates().values.tolist()))
104 | qid_cid_set_df = req[["qid", "cid"]].groupby("qid")["cid"].apply(set).apply(sorted).apply(tuple).reset_index()
105 | qid_cid_set_df.columns = ["qid", "cid_set"]
106 | relevant_docs = dict(map(tuple ,qid_cid_set_df.drop_duplicates().values.tolist()))
107 | relevant_docs = dict(map(lambda t2: (t2[0], set(t2[1])) ,relevant_docs.items()))
108 | return queries, corpus, relevant_docs
109 |
110 |
111 | dev_queries, dev_corpus, dev_rel_docs = transform_part_df_into_Evaluator_format(valid_part.sample(frac=0.1))
112 | ir_evaluator = evaluation.InformationRetrievalEvaluator(dev_queries, dev_corpus, dev_rel_docs, name='ms-marco-train_eval', batch_size=2)
113 |
114 |
115 |
116 | model_str = "xlm-roberta-base"
117 | #word_embedding_model = models.Transformer(model_str, max_seq_length=512)
118 | word_embedding_model = models.Transformer(model_str, max_seq_length=256)
119 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
120 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
121 |
122 |
123 |
124 | train_dataset = TripletsDataset(model=model, qa_df = train_part.sample(frac = 1.0, replace=False))
125 | bs_obj = NoSameLabelsBatchSampler(train_dataset, batch_size=8)
126 | train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=1, batch_sampler=bs_obj, num_workers=1)
127 | train_loss = losses.MultipleNegativesRankingLoss(model=model)
128 |
129 |
130 | model_save_path = os.path.join(os.path.abspath(""), "bi_encoder_save")
131 | if not os.path.exists(model_save_path):
132 | os.mkdir(model_save_path)
133 |
134 |
135 | model.fit(train_objectives=[(train_dataloader, train_loss)],
136 | evaluator=ir_evaluator,
137 | epochs=10,
138 | warmup_steps=1000,
139 | output_path=model_save_path,
140 | evaluation_steps=5000,
141 | use_amp=True
142 | )
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/script/bi_encoder/bi-encoder-data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import os
4 | from copy import deepcopy
5 | from functools import reduce
6 | from glob import glob
7 |
8 | import editdistance
9 | import numpy as np
10 | import pandas as pd
11 |
12 | ###https://github.com/brightmart/nlp_chinese_corpus
13 | ###https://github.com/brightmart/nlp_chinese_corpus#4%E7%A4%BE%E5%8C%BA%E9%97%AE%E7%AD%94json%E7%89%88webtext2019zh-%E5%A4%A7%E8%A7%84%E6%A8%A1%E9%AB%98%E8%B4%A8%E9%87%8F%E6%95%B0%E6%8D%AE%E9%9B%86
14 | ###https://drive.google.com/open?id=1u2yW_XohbYL2YAK6Bzc5XrngHstQTf0v
15 |
16 | data_dir = r"/home/svjack/temp_dir/webtext2019zh"
17 | json_files = glob(os.path.join(data_dir, "*.json"))
18 | train_json = list(filter(lambda path: "train" in path.lower(), json_files))[0]
19 | def json_reader(path, chunksize = 100):
20 | assert path.endswith(".json")
21 | return pd.read_json(path, lines = True, chunksize=chunksize)
22 |
23 | train_reader = json_reader(train_json, chunksize=10000)
24 | times = 100
25 | df_list = []
26 | for i, df in enumerate(train_reader):
27 | df_list.append(df)
28 | if i + 1 >= times:
29 | break
30 |
31 | train_head_df = pd.concat(df_list, axis = 0)
32 | content_len_df = pd.concat([train_head_df["content"], train_head_df["content"].map(len)], axis = 1)
33 | content_len_df.columns = ["content", "c_len"]
34 |
35 |
36 | qa_df = train_head_df[["title", "content"]].copy()
37 | qa_df = qa_df.rename(columns = {"title": "question", "content": "answer"}).fillna("")
38 |
39 | qa_df = qa_df[qa_df["question"].map(len) <= 500]
40 | qa_df = qa_df[qa_df["answer"].map(len) <= 500]
41 |
42 |
43 | quests = deepcopy(qa_df["question"])
44 | question_cmp = pd.concat([quests.sort_values().shift(1), quests.sort_values()], axis = 1)
45 | question_cmp["edit_val"] = question_cmp.fillna("").apply(lambda s: editdistance.eval(s.iloc[0], s.iloc[1]) / (len(s.iloc[0]) + len(s.iloc[1])), axis = 1)
46 | question_cmp.columns = ["q0", "q1", "edit_val"]
47 |
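    | # Group near-duplicate questions: walk the questions in sorted order and put neighbours whose
    | # normalized edit distance is below the threshold into the same group (they will share one q_idx).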
48 | threshold = 0.2
49 | question_nest_list = [[]]
50 | for idx ,r in question_cmp.iterrows():
51 | q0, q1, v = r.iloc[0], r.iloc[1], r.iloc[2]
52 | if v < threshold:
53 | question_nest_list[-1].append(q0)
54 | question_nest_list[-1].append(q1)
55 | else:
56 | question_nest_list.append([])
57 |
58 |
59 | idx_question_df_zip = pd.DataFrame(list(map(lambda x: [x] ,question_nest_list)))
60 |
61 | idx_question_df_zip = idx_question_df_zip[idx_question_df_zip.iloc[:, 0].map(len) > 0]
62 | idx_question_df_zip.columns = ["question"]
63 | idx_question_df_zip["q_idx"] = np.arange(idx_question_df_zip.shape[0]).tolist()
64 |
65 | idx_question_df = idx_question_df_zip.explode("question")
66 |
67 | #idx_question_df = pd.DataFrame(reduce(lambda a, b: a + b, map(lambda idx: list(map(lambda q: (idx, q), question_nest_list[idx])), range(len(question_nest_list)))))
68 | #idx_question_df.columns = ["q_idx", "question"]
69 | #idx_question_df.drop_duplicates().to_csv(os.path.join("/home/svjack/temp_dir/", "idx_question_df.csv"), index = False)
70 |
71 | idx_question_df_dd = idx_question_df.drop_duplicates()
72 |
73 |
74 |
75 | qa_df_dd = qa_df.drop_duplicates()
76 | cat_qa_df_with_idx = pd.merge(qa_df_dd, idx_question_df_dd, on = "question", how = "inner")
77 | q_idx_set = set(cat_qa_df_with_idx["q_idx"].value_counts().index.tolist())
78 |
79 | q_idx_size_bigger_or_eql_3 = ((cat_qa_df_with_idx["q_idx"].value_counts() >= 3).reset_index()).groupby("q_idx")["index"].apply(set).apply(list)[True]
80 | q_idx_size_bigger_or_eql_3_df = cat_qa_df_with_idx[cat_qa_df_with_idx["q_idx"].isin(q_idx_size_bigger_or_eql_3)].copy()
81 |
82 |
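    | # Helper producing a shuffled array of split labels (0/1/2) in roughly the p_list proportions;
    | # below, every q_idx group with >= 3 QA pairs is labelled per answer, where 2 -> train, 0 -> valid, 1 -> test.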
83 | def produce_label_list(length = 10, p_list = [0.1, 0.1, 0.8]):
84 | from functools import reduce
85 | assert sum(p_list) == 1
86 | p_array = np.asarray(p_list)
87 | assert all((p_array[:-1] <= p_array[1:]).astype(bool).tolist())
88 | num_array = (p_array * length).astype(np.int32)
89 | num_list = num_array.tolist()
90 | num_list = list(map(lambda x: max(x, 1), num_list))
91 | num_list[-1] = length - sum(num_list[:-1])
92 | return np.random.permutation(reduce(lambda a, b: a + b ,map(lambda idx: [idx] * num_list[idx], range(len(p_list)))))
93 |
94 | q_idx_size_bigger_or_eql_3_df["r_idx"] = q_idx_size_bigger_or_eql_3_df.index.tolist()
95 |
96 | def map_r_idx_list_to_split_label_zip(r_idx_list):
97 | split_label_list = produce_label_list(len(r_idx_list))
98 | assert len(split_label_list) == len(r_idx_list)
99 | return zip(*[r_idx_list, split_label_list])
100 |
101 | r_idx_split_label_items = reduce(lambda a, b: a + b ,q_idx_size_bigger_or_eql_3_df.groupby("q_idx")["r_idx"].apply(set).apply(list).apply(map_r_idx_list_to_split_label_zip).apply(list).tolist())
102 | r_idx_split_label_df = pd.DataFrame(r_idx_split_label_items)
103 | r_idx_split_label_df.columns = ["r_idx", "split_label"]
104 | assert r_idx_split_label_df.shape[0] == pd.merge(q_idx_size_bigger_or_eql_3_df, r_idx_split_label_df, on = "r_idx", how = "inner").shape[0]
105 |
106 | q_idx_size_bigger_or_eql_3_df_before_split = pd.merge(q_idx_size_bigger_or_eql_3_df, r_idx_split_label_df, on = "r_idx", how = "inner")
107 | train_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 2].copy()
108 | train_part = pd.concat([train_part, cat_qa_df_with_idx[(1 - cat_qa_df_with_idx["q_idx"].isin(q_idx_size_bigger_or_eql_3)).astype(bool)].copy()], axis = 0)
109 | valid_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 0].copy()
110 | test_part = q_idx_size_bigger_or_eql_3_df_before_split[q_idx_size_bigger_or_eql_3_df_before_split["split_label"] == 1].copy()
111 |
112 | assert set(valid_part["q_idx"].tolist()) == set(test_part["q_idx"].tolist())
113 | assert set(valid_part["q_idx"].tolist()) == set(valid_part["q_idx"].tolist()).intersection(train_part["q_idx"].tolist())
114 |
115 | train_part.to_csv(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, "data", "train_part.csv"), index = False)
116 | test_part.to_csv(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, "data", "test_part.csv"), index = False)
117 | valid_part.to_csv(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, "data", "valid_part.csv"), index = False)
118 |
--------------------------------------------------------------------------------
/script/cross_encoder/cross_encoder_random_on_multi_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #es_host = 'localhost:9200'
4 | import csv
5 | import gzip
6 | import json
7 | import logging
8 | import os
9 | import tarfile
10 | import time
11 | from datetime import datetime
12 | #from es_pandas import es_pandas
13 | from functools import reduce
14 | from glob import glob
15 | from typing import Callable, Dict, Iterable, List, Type
16 |
17 | import numpy as np
18 | import pandas as pd
19 | import seaborn as sns
20 | import torch
21 | import transformers
22 | #from elasticsearch import Elasticsearch, helpers
23 | from sentence_transformers import (InputExample, LoggingHandler,
24 | SentenceTransformer, util)
25 | from sentence_transformers.cross_encoder import CrossEncoder
26 | from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
27 | from sentence_transformers.evaluation import (SentenceEvaluator,
28 | SequentialEvaluator)
29 | from torch import nn
30 | from torch.optim import Optimizer
31 | from torch.utils.data import DataLoader
32 | from tqdm.autonotebook import tqdm, trange
33 | from transformers import (AutoConfig, AutoModelForSequenceClassification,
34 | AutoTokenizer)
35 |
36 |
37 | logger = logging.getLogger(__name__)
38 | pd.set_option("display.max_rows", 200)
39 |
40 |
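    | # SequentialEvaluator variant that keys each sub-evaluator by its name, writes each one's results
    | # into its own "<output_path>_<name>" directory and returns the scores as a dictionary.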
41 | class DictionaryEvaluator(SequentialEvaluator):
42 | def __init__(self, evaluators: Iterable[SentenceEvaluator], main_score_function = lambda x: x):
43 | super(DictionaryEvaluator, self).__init__(evaluators, main_score_function)
44 | #self.eval_name_ext_dict = dict(map(lambda t2: (t2[0].lower()[:t2[0].lower().find("Evaluator".lower())], t2[1]) ,map(lambda eval_ext: (eval_ext.name, eval_ext) ,self.evaluators)))
45 | self.eval_name_ext_dict = dict(map(lambda t2: (t2[0], t2[1]) ,map(lambda eval_ext: (eval_ext.name, eval_ext) ,self.evaluators)))
46 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
47 | scores = {}
48 | #for evaluator in self.evaluators:
49 | for eval_ext_name, evaluator in self.eval_name_ext_dict.items():
50 | #scores.append(evaluator(model, output_path, epoch, steps))
51 | #eval_output_path = output_path + eval_ext_name
52 | eval_output_path = output_path + "_" + eval_ext_name
53 | #scores[eval_ext_name] = evaluator(model, output_path, epoch, steps)
54 | if eval_output_path is not None and not(os.path.exists(eval_output_path)):
55 | os.makedirs(eval_output_path, exist_ok=True)
56 | scores[eval_ext_name] = evaluator(model, eval_output_path, epoch, steps)
57 | return self.main_score_function(scores)
58 |
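    | # CrossEncoder whose evaluation hook understands DictionaryEvaluator: it keeps one best score per
    | # sub-evaluator and saves a separate best checkpoint per metric.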
59 | class CrossEncoder_Dict_Eval(CrossEncoder):
60 | def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback):
61 | """Runs evaluation during the training"""
62 | if evaluator is not None:
63 | if isinstance(evaluator, DictionaryEvaluator):
64 | if type(self.best_score) != type({}):
65 | self.best_score = dict(map(lambda eval_ext_name: (eval_ext_name, self.best_score), evaluator.eval_name_ext_dict.keys()))
66 | score_dict = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
67 | if callback is not None:
68 |                 callback(score_dict, epoch, steps)
69 | for eval_ext_name, eval_score in score_dict.items():
70 | if eval_score > self.best_score[eval_ext_name]:
71 | self.best_score[eval_ext_name] = eval_score
72 | if save_best_model:
73 | #eval_output_path = output_path + eval_ext_name
74 | eval_output_path = output_path + "_" + eval_ext_name
75 | self.save(eval_output_path)
76 | else:
77 | score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
78 | if callback is not None:
79 | callback(score, epoch, steps)
80 | if score > self.best_score:
81 | self.best_score = score
82 | if save_best_model:
83 | self.save(output_path)
84 |
85 |
86 | class CERerankingEvaluatorSUM:
87 | """
88 | This class evaluates a CrossEncoder model for the task of re-ranking.
89 |
90 | Given a query and a list of documents, it computes the score [query, doc_i] for all possible
91 |     documents and sorts them in decreasing order. Then, MRR@10 is computed to measure the quality of the ranking.
92 |
93 | :param samples: Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query,
94 | positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents.
95 | """
96 | def __init__(self, samples, mrr_at_k: int = 10, name: str = '', num_dev_queries = 600):
97 | self.samples = samples
98 | self.name = name
99 | self.mrr_at_k = mrr_at_k
100 |
101 | self.num_dev_queries = num_dev_queries
102 |
103 | if isinstance(self.samples, dict):
104 | self.samples = list(self.samples.values())
105 | #output/training_ms-marco_cross-encoder-xlm-roberta-base-2021-01-12_21-10-39mrr-train-eva/CERerankingEvaluator_mrr-train-eval_results.csv
106 | self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
107 | self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
108 |
109 | self.score_json_file = self.csv_file.replace(".csv", ".json")
110 |
111 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
112 | if epoch != -1:
113 | if steps == -1:
114 | out_txt = " after epoch {}:".format(epoch)
115 | else:
116 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
117 | else:
118 | out_txt = ":"
119 |
120 | logger.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
121 |
122 | all_mrr_scores = []
123 | num_queries = 0
124 | num_positives = []
125 | num_negatives = []
126 | scores_list = []
127 |
128 | samples = list(self.samples)
129 | samples_indexes = np.random.permutation(np.arange(len(samples)))
130 | samples = list(map(lambda idx: samples[idx], samples_indexes[:self.num_dev_queries]))
131 | #for instance in self.samples:
132 | for instance in samples:
133 | query = instance['query']
134 | positive = list(instance['positive'])
135 | negative = list(instance['negative'])
136 | docs = positive + negative
137 | is_relevant = [True]*len(positive) + [False]*len(negative)
138 |
139 | if len(positive) == 0 or len(negative) == 0:
140 | continue
141 |
142 | num_queries += 1
143 | num_positives.append(len(positive))
144 | num_negatives.append(len(negative))
145 |
146 | model_input = [[query, doc] for doc in docs]
147 | pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
148 | scores_list.extend(list(pred_scores))
149 | pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
150 |
151 | mrr_score = 0
152 | for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
153 | if is_relevant[index]:
154 | mrr_score += 1 / (rank+1)
155 |
156 | all_mrr_scores.append(mrr_score)
157 |
158 | mean_mrr = np.mean(all_mrr_scores)
159 | logger.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
160 | logger.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
161 |
162 | if output_path is not None:
163 | csv_path = os.path.join(output_path, self.csv_file)
164 | output_file_exists = os.path.isfile(csv_path)
165 | with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
166 | writer = csv.writer(f)
167 | if not output_file_exists:
168 | writer.writerow(self.csv_headers)
169 | writer.writerow([epoch, steps, mean_mrr])
170 | json_path = os.path.join(output_path, self.score_json_file)
171 | output_file_exists = os.path.isfile(json_path)
172 | with open(json_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
173 | writer = csv.writer(f)
174 | if not output_file_exists:
175 | writer.writerow(["epoch", "steps", "score@"])
176 | writer.writerow([epoch, steps, json.dumps({"scores_list": list(map(float ,scores_list))})])
177 | return mean_mrr
178 |
179 | class ScoreCalculator(object):
180 | def __init__(self,
181 | queries_ids,
182 | relevant_docs,
183 | mrr_at_k: List[int] = [10],
184 | ndcg_at_k: List[int] = [10],
185 | accuracy_at_k: List[int] = [1, 3, 5, 10],
186 | precision_recall_at_k: List[int] = [1, 3, 5, 10],
187 | map_at_k: List[int] = [100],
188 | ):
189 | "queries_ids list of query, relevant_docs key query value set or list of relevant_docs"
190 | self.queries_ids = queries_ids
191 | self.relevant_docs = relevant_docs
192 |
193 | self.mrr_at_k = mrr_at_k
194 | self.ndcg_at_k = ndcg_at_k
195 | self.accuracy_at_k = accuracy_at_k
196 | self.precision_recall_at_k = precision_recall_at_k
197 | self.map_at_k = map_at_k
198 | def compute_metrics(self, queries_result_list: List[object]):
199 | # Init score computation values
200 | num_hits_at_k = {k: 0 for k in self.accuracy_at_k}
201 | precisions_at_k = {k: [] for k in self.precision_recall_at_k}
202 | recall_at_k = {k: [] for k in self.precision_recall_at_k}
203 | MRR = {k: 0 for k in self.mrr_at_k}
204 | ndcg = {k: [] for k in self.ndcg_at_k}
205 | AveP_at_k = {k: [] for k in self.map_at_k}
206 |
207 | # Compute scores on results
208 | #### elements with hits one hit is a dict : {"corpus_id": corpus_text, "score": score}
209 | #### corpus_id replace by corpus text
210 | for query_itr in range(len(queries_result_list)):
211 | query_id = self.queries_ids[query_itr]
212 |
213 | # Sort scores
214 | top_hits = sorted(queries_result_list[query_itr], key=lambda x: x['score'], reverse=True)
215 | query_relevant_docs = self.relevant_docs[query_id]
216 |
217 |             # Accuracy@k - We count the result as correct if at least one relevant doc is among the top-k documents
218 | for k_val in self.accuracy_at_k:
219 | for hit in top_hits[0:k_val]:
220 | if hit['corpus_id'] in query_relevant_docs:
221 | num_hits_at_k[k_val] += 1
222 | break
223 |
224 | # Precision and Recall@k
225 | for k_val in self.precision_recall_at_k:
226 | num_correct = 0
227 | for hit in top_hits[0:k_val]:
228 | if hit['corpus_id'] in query_relevant_docs:
229 | num_correct += 1
230 |
231 | precisions_at_k[k_val].append(num_correct / k_val)
232 | recall_at_k[k_val].append(num_correct / len(query_relevant_docs))
233 |
234 | # MRR@k
235 | for k_val in self.mrr_at_k:
236 | for rank, hit in enumerate(top_hits[0:k_val]):
237 | if hit['corpus_id'] in query_relevant_docs:
238 | MRR[k_val] += 1.0 / (rank + 1)
239 | break
240 |
241 | # NDCG@k
242 | for k_val in self.ndcg_at_k:
243 | predicted_relevance = [1 if top_hit['corpus_id'] in query_relevant_docs else 0 for top_hit in top_hits[0:k_val]]
244 | true_relevances = [1] * len(query_relevant_docs)
245 |
246 | ndcg_value = self.compute_dcg_at_k(predicted_relevance, k_val) / self.compute_dcg_at_k(true_relevances, k_val)
247 | ndcg[k_val].append(ndcg_value)
248 |
249 | # MAP@k
250 | for k_val in self.map_at_k:
251 | num_correct = 0
252 | sum_precisions = 0
253 |
254 | for rank, hit in enumerate(top_hits[0:k_val]):
255 | if hit['corpus_id'] in query_relevant_docs:
256 | num_correct += 1
257 | sum_precisions += num_correct / (rank + 1)
258 |
259 | avg_precision = sum_precisions / min(k_val, len(query_relevant_docs))
260 | AveP_at_k[k_val].append(avg_precision)
261 |
262 | # Compute averages
263 | for k in num_hits_at_k:
264 | #num_hits_at_k[k] /= len(self.queries)
265 | num_hits_at_k[k] /= len(queries_result_list)
266 |
267 | for k in precisions_at_k:
268 | precisions_at_k[k] = np.mean(precisions_at_k[k])
269 |
270 | for k in recall_at_k:
271 | recall_at_k[k] = np.mean(recall_at_k[k])
272 |
273 | for k in ndcg:
274 | ndcg[k] = np.mean(ndcg[k])
275 |
276 | for k in MRR:
277 | #MRR[k] /= len(self.queries)
278 | MRR[k] /= len(queries_result_list)
279 |
280 | for k in AveP_at_k:
281 | AveP_at_k[k] = np.mean(AveP_at_k[k])
282 | return {'accuracy@k': num_hits_at_k, 'precision@k': precisions_at_k, 'recall@k': recall_at_k, 'ndcg@k': ndcg, 'mrr@k': MRR, 'map@k': AveP_at_k}
283 | @staticmethod
284 | def compute_dcg_at_k(relevances, k):
285 | dcg = 0
286 | for i in range(min(len(relevances), k)):
287 | dcg += relevances[i] / np.log2(i + 2) #+2 as we start our idx at 0
288 | return dcg
289 |
290 | def map_dev_samples_to_score_calculator_format(dev_samples):
291 | if isinstance(dev_samples, dict):
292 | dev_samples = list(dev_samples.values())
293 | queries_ids = list(map(lambda x: x["query"] ,dev_samples))
294 | relevant_docs = dict(map(lambda idx: (dev_samples[idx]["query"], dev_samples[idx]["positive"]), range(len(dev_samples))))
295 | return ScoreCalculator(queries_ids, relevant_docs)
296 |
297 | class ScoreEvaluator:
298 | """
299 | This class evaluates a CrossEncoder model for the task of re-ranking.
300 |
301 | Given a query and a list of documents, it computes the score [query, doc_i] for all possible
302 |     documents and sorts them in decreasing order. Then, MAP@100 is computed to measure the quality of the ranking.
303 |
304 | :param samples: Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query,
305 | positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents.
306 | """
307 | def __init__(self, samples, name: str = '', num_dev_queries = 600):
308 | self.samples = samples
309 | self.name = name
310 | self.num_dev_queries = num_dev_queries
311 |
312 | if isinstance(self.samples, dict):
313 | self.samples = list(self.samples.values())
314 |
315 | #self.score_calculator_ext = map_dev_samples_to_score_calculator_format(self.samples)
316 |
317 | self.csv_file = "ScoreEvaluator" + ("_" + name if name else '') + "_results.csv"
318 | self.csv_headers = ["epoch", "steps", "MAP@"]
319 |
320 | self.score_json_file = self.csv_file.replace(".csv", ".json")
321 |
322 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
323 | if epoch != -1:
324 | if steps == -1:
325 | out_txt = " after epoch {}:".format(epoch)
326 | else:
327 | out_txt = " in epoch {} after {} steps:".format(epoch, steps)
328 | else:
329 | out_txt = ":"
330 |
331 | logger.info("ScoreEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
332 |
333 | all_scores = []
334 | num_queries = 0
335 | num_positives = []
336 | num_negatives = []
337 |
338 | scores_list = []
339 |
340 | samples = list(self.samples)
341 | samples_indexes = np.random.permutation(np.arange(len(samples)))
342 | samples = list(map(lambda idx: samples[idx], samples_indexes[:self.num_dev_queries]))
343 | #for instance in self.samples:
344 | for instance in samples:
345 | query = instance['query']
346 | positive = list(instance['positive'])
347 | negative = list(instance['negative'])
348 | docs = positive + negative
349 | is_relevant = [True]*len(positive) + [False]*len(negative)
350 |
351 | if len(positive) == 0 or len(negative) == 0:
352 | continue
353 |
354 | num_queries += 1
355 | num_positives.append(len(positive))
356 | num_negatives.append(len(negative))
357 |
358 | model_input = [[query, doc] for doc in docs]
359 | pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
360 |
361 | scores_list.extend(list(pred_scores))
362 |
363 | #### elements with hits one hit is a dict : {"corpus_id": corpus_text, "score": score}
364 | #### corpus_id replace by corpus text
365 | queries_result_list = list(map(lambda idx: {"corpus_id": docs[idx], "score": pred_scores[idx]}, range(len(docs))))
366 | score_calculator_ext = map_dev_samples_to_score_calculator_format({0: instance})
367 | score_dict = score_calculator_ext.compute_metrics([queries_result_list])
368 | all_scores.append(score_dict["map@k"][100])
369 |
370 | mean_map = np.mean(all_scores)
371 | logger.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
372 | logger.info("MAP@: {:.2f}".format(mean_map*100))
373 |
374 | if output_path is not None:
375 | csv_path = os.path.join(output_path, self.csv_file)
376 | output_file_exists = os.path.isfile(csv_path)
377 | with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
378 | writer = csv.writer(f)
379 | if not output_file_exists:
380 | writer.writerow(self.csv_headers)
381 | #writer.writerow([epoch, steps, mean_mrr])
382 | writer.writerow([epoch, steps, mean_map])
383 | json_path = os.path.join(output_path, self.score_json_file)
384 | output_file_exists = os.path.isfile(json_path)
385 | with open(json_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
386 | writer = csv.writer(f)
387 | if not output_file_exists:
388 | writer.writerow(["epoch", "steps", "score@"])
389 | writer.writerow([epoch, steps, json.dumps({"scores_list": list(map(float ,scores_list))})])
390 | #return mean_mrr
391 | return mean_map
392 |
393 |
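    | # Load the four JSON files of a prepared split: the positive/negative (q_index, a_index) tuple sets plus
    | # the index->question and index->answer dictionaries; optionally subsample to a target negative:positive ratio.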
394 | def read_part_file(path, adjust_neg_pos_ration = None):
395 | json_full_paths = glob(os.path.join(path, "*.json"))
396 | assert len(json_full_paths) == 4
397 | req = {}
398 | for path in json_full_paths:
399 | val_name = path.split("/")[-1].replace(".json", "")
400 | with open(path, "r", encoding = "utf-8") as f:
401 | j_obj = json.load(f)
402 | if len(j_obj) == 1:
403 | assert type(j_obj[list(j_obj.keys())[0]]) == type([])
404 | j_obj = j_obj[list(j_obj.keys())[0]]
405 | else:
406 | assert val_name.endswith("_index_dict") and len(val_name.split("_")) == 3
407 | val_name = "_".join(np.asarray(val_name.split("_"))[[1, 0, 2]].tolist())
408 | j_obj = dict(map(lambda t2: (t2[1], t2[0]) ,j_obj.items()))
409 | req[val_name] = j_obj
410 | if adjust_neg_pos_ration is not None:
411 | assert type(adjust_neg_pos_ration) == type(0)
412 | neg_tuple_set = req["neg_tuple_set"]
413 | pos_tuple_set = req["pos_tuple_set"]
414 | print("ori pos {} neg {}".format(len(pos_tuple_set), len(neg_tuple_set)))
415 | if len(pos_tuple_set) < len(neg_tuple_set):
416 | if len(neg_tuple_set) / len(pos_tuple_set) >= adjust_neg_pos_ration:
417 | pos_num, neg_num = len(pos_tuple_set), len(pos_tuple_set) * adjust_neg_pos_ration
418 | else:
419 | pos_num, neg_num = int(len(neg_tuple_set) / adjust_neg_pos_ration), len(neg_tuple_set)
420 | else:
421 | min_size = min(map(len, [neg_tuple_set, pos_tuple_set]))
422 | snip_size = int(max(1, min_size / 10000))
423 | snip_num = int(min_size / snip_size)
424 | pos_num = int(min((snip_num / adjust_neg_pos_ration) * snip_size, min_size))
425 | neg_num = int(min((snip_num / adjust_neg_pos_ration) * snip_size * adjust_neg_pos_ration, min_size))
426 | neg_tuple_set = set(list(map(tuple ,neg_tuple_set))[:neg_num])
427 | pos_tuple_set = set(list(map(tuple ,pos_tuple_set))[:pos_num])
428 | req["neg_tuple_set"] = neg_tuple_set
429 | req["pos_tuple_set"] = pos_tuple_set
430 | return req
431 |
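    | # Turn the tuple sets into a shuffled list of InputExample pairs (label 1 for positives, 0 for negatives);
    | # with neg_random=True the negative answer is drawn at random instead of using the stored a_index.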
432 | def construct_train_samples(json_obj, neg_random = False):
433 | train_samples = []
434 | label = 1
435 | for t2 in json_obj["pos_tuple_set"]:
436 | q_index, a_index = t2[0], t2[1]
437 | q, a = json_obj["index_question_dict"][q_index], json_obj["index_answer_dict"][a_index]
438 | #q = q + ""
439 | train_samples.append(InputExample(texts=[q, a], label=label))
440 | if neg_random:
441 | neg_len = len(json_obj["index_answer_dict"])
442 | label = 0
443 | for t2 in json_obj["neg_tuple_set"]:
444 | q_index, a_index = t2[0], t2[1]
445 | if neg_random:
446 | q, a = json_obj["index_question_dict"][q_index], json_obj["index_answer_dict"][np.random.randint(0, neg_len)]
447 | else:
448 | q, a = json_obj["index_question_dict"][q_index], json_obj["index_answer_dict"][a_index]
449 | #q = q + ""
450 | train_samples.append(InputExample(texts=[q, a], label=label))
451 | train_indexes = np.random.permutation(np.arange(len(train_samples)))
452 | return list(map(lambda idx: train_samples[idx], train_indexes))
453 |
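    | # Regroup the generated pairs by query into the {'query', 'positive', 'negative'} dict format expected by
    | # the re-ranking evaluators above, capping the number of queries and negatives per query.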
454 | def construct_dev_samples(json_obj ,num_dev_queries = int(2e3), num_max_dev_negatives = 200, neg_random = False):
455 | dev_samples_list = construct_train_samples(json_obj, neg_random)
456 | dev_samples_df = pd.DataFrame(list(map(lambda item_:(item_.texts[0], item_.texts[1], item_.label) ,dev_samples_list)), columns = ["q", "a", "l"])
457 | dev_q_qid_dict = dict(map(lambda t2: (t2[1], t2[0]), enumerate(dev_samples_df["q"].drop_duplicates().tolist())))
458 | print("dev_samples_df shape {}, dev_q_qid_dict len {}".format(dev_samples_df.shape, len(dev_q_qid_dict)))
459 | dev_samples = {}
460 | for q, qid in dev_q_qid_dict.items():
461 | if qid not in dev_samples and len(dev_samples) < num_dev_queries:
462 | dev_samples[qid] = {'query': q, 'positive': set(), 'negative': set()}
463 | for qid in dev_samples.keys():
464 | qid_relate_df = dev_samples_df[dev_samples_df["q"] == dev_samples[qid]["query"]]
465 | dev_samples[qid]["positive"] = set(qid_relate_df[qid_relate_df["l"] == 1]["a"].tolist())
466 | for neg_a in set(qid_relate_df[qid_relate_df["l"] == 0]["a"].tolist()):
467 | if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
468 | dev_samples[qid]['negative'].add(neg_a)
469 | return dev_samples
470 |
471 |
472 | def merge_dev_samples(dev_samples_list):
473 | assert len(dev_samples_list) > 1
474 | query_pos_set_dict = {}
475 | query_neg_set_dict = {}
476 | for dev_samples in dev_samples_list:
477 | for qid, item_ in dev_samples.items():
478 | # item_ {'query': q, 'positive': set(), 'negative': set()}
479 | query = item_["query"]
480 | positive = item_["positive"]
481 | negative = item_["negative"]
482 | if query not in query_pos_set_dict:
483 | query_pos_set_dict[query] = positive
484 | else:
485 | for ele in positive:
486 | query_pos_set_dict[query].add(ele)
487 | if query not in query_neg_set_dict:
488 | query_neg_set_dict[query] = negative
489 | else:
490 | for ele in negative:
491 | query_neg_set_dict[query].add(ele)
492 | assert set(query_pos_set_dict.keys()) == set(query_neg_set_dict.keys())
493 | merge_dev_samples = {}
494 | for qid, query in enumerate(query_pos_set_dict.keys()):
495 | merge_dev_samples[qid] = {'query': query, 'positive': query_pos_set_dict[query], 'negative': query_neg_set_dict[query]}
496 | return merge_dev_samples
497 |
498 |
499 | def merge_dev_samples_add_neg(dev_samples_list, add_num = None, after_size = 500):
500 | assert len(dev_samples_list) > 1
501 | query_pos_set_dict = {}
502 | query_neg_set_dict = {}
503 | all_negs = reduce(lambda a, b: a.union(b) ,map(lambda dev_samples: reduce(lambda a, b: a.union(b) ,map(lambda item_:set(item_["negative"]) ,dev_samples.values())), dev_samples_list))
504 | assert type(all_negs) == type(set([]))
505 | assert set(map(type, all_negs)) == set([type("")])
506 | all_negs = list(all_negs)
507 | for dev_samples in dev_samples_list:
508 | for qid, item_ in dev_samples.items():
509 | # item_ {'query': q, 'positive': set(), 'negative': set()}
510 | query = item_["query"]
511 | positive = item_["positive"]
512 | negative = item_["negative"]
513 | if query not in query_pos_set_dict:
514 | query_pos_set_dict[query] = positive
515 | else:
516 | for ele in positive:
517 | query_pos_set_dict[query].add(ele)
518 | if query not in query_neg_set_dict:
519 | query_neg_set_dict[query] = negative
520 | else:
521 | for ele in negative:
522 | query_neg_set_dict[query].add(ele)
523 | if add_num is not None:
524 | assert type(add_num) == type(0)
525 | all_negs_indexes = np.random.permutation(np.arange(len(all_negs)))
526 | all_negs_add = list(map(lambda idx: all_negs[idx], all_negs_indexes[:add_num]))
527 | for ele in all_negs_add:
528 | if ele not in positive:
529 | query_neg_set_dict[query].add(ele)
530 | assert set(query_pos_set_dict.keys()) == set(query_neg_set_dict.keys())
531 | merge_dev_samples = {}
532 | for qid, query in enumerate(query_pos_set_dict.keys()):
533 | if len(merge_dev_samples) >= after_size:
534 | break
535 | merge_dev_samples[qid] = {'query': query, 'positive': query_pos_set_dict[query], 'negative': query_neg_set_dict[query]}
536 | return merge_dev_samples
537 |
538 |
539 | train_json_obj = read_part_file("train_file_faiss_10", adjust_neg_pos_ration = None)
540 | train_samples = construct_train_samples(train_json_obj, neg_random = True)
541 |
542 |
543 | valid_json_obj = read_part_file("valid_file_faiss", adjust_neg_pos_ration = None)
544 |
545 |
546 | dev_samples = construct_dev_samples(valid_json_obj, neg_random = True)
547 |
548 |
549 |
550 | model_name = 'xlm-roberta-base'
551 | model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
552 | model = CrossEncoder_Dict_Eval(model_name, num_labels=1, max_length=512)
553 |
554 |
555 |
556 | train_batch_size = 10
557 | num_epochs = 10
558 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
559 |
560 |
561 |
562 | mrr_sum_evaluator = CERerankingEvaluatorSUM(dev_samples, name='mrr-train-eval')
563 | map_evaluator = ScoreEvaluator(dev_samples, name = "map-train-eval")
564 | dict_evaluator = DictionaryEvaluator([mrr_sum_evaluator, map_evaluator])
565 |
566 |
567 | warmup_steps = 1000
568 | logging.info("Warmup-steps: {}".format(warmup_steps))
569 |
570 |
571 | # Train the model
572 | model.fit(train_dataloader=train_dataloader,
573 | evaluator=dict_evaluator,
574 | epochs=num_epochs,
575 | evaluation_steps=5000,
576 | warmup_steps=warmup_steps,
577 | output_path=model_save_path,
578 | use_amp=False)
579 |
580 | #Save latest model
581 | model.save(model_save_path+'-latest')
582 |
583 |
584 |
585 |
--------------------------------------------------------------------------------
/script/cross_encoder/try_sbert_neg_sampler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[1]:
5 | import json
6 | import logging
7 | import os
8 | import pickle
9 | import random
10 | import time
11 | import traceback
12 | from functools import reduce
13 |
14 | import faiss
15 | import numpy as np
16 | import pandas as pd
17 | import scipy.spatial
18 | import torch
19 | from elasticsearch import Elasticsearch, helpers
20 | from es_pandas import es_pandas
21 | from IPython import embed
22 | from sentence_transformers import InputExample, SentenceTransformer, util
23 |
24 | es_host = 'localhost:9200'
25 | train_part, test_part, valid_part = map(lambda save_type: pd.read_csv(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, "data", "{}_part.csv".format(save_type))
26 | ).dropna(), ["train", "test", "valid"])
27 |
28 |
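    | # es_pandas subclass: serialize() keeps list/array-valued cells (e.g. embedding vectors) intact during the
    | # NaN check, and to_pandas_iter() adds chunked reads on top of helpers.scan.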
29 | class es_pandas_edit(es_pandas):
30 | @staticmethod
31 | def serialize(row, columns, use_pandas_json, iso_dates):
32 | if use_pandas_json:
33 | return json.dumps(dict(zip(columns, row)), iso_dates=iso_dates)
34 | return dict(zip(columns, [None if (all(pd.isna(r)) if (hasattr(r, "__len__") and type(r) != type("")) else pd.isna(r)) else r for r in row]))
35 | def to_pandas_iter(self, index, query_rule=None, heads=[], dtype={}, infer_dtype=False, show_progress=True,
36 | chunk_size = None, **kwargs):
37 | if query_rule is None:
38 | query_rule = {'query': {'match_all': {}}}
39 | count = self.es.count(index=index, body=query_rule)['count']
40 | if count < 1:
41 | raise Exception('Empty for %s' % index)
42 | query_rule['_source'] = heads
43 | anl = helpers.scan(self.es, query=query_rule, index=index, **kwargs)
44 | source_iter = self.get_source(anl, show_progress = show_progress, count = count)
45 | print(source_iter)
46 | if chunk_size is None:
47 | df = pd.DataFrame(list(source_iter)).set_index('_id')
48 | if infer_dtype:
49 | dtype = self.infer_dtype(index, df.columns.values)
50 | if len(dtype):
51 | df = df.astype(dtype)
52 | yield df
53 | return
54 | assert type(chunk_size) == type(0)
55 | def map_list_of_dicts_into_df(list_of_dicts, set_index = "_id"):
56 | from collections import defaultdict
57 | req = defaultdict(list)
58 | for dict_ in list_of_dicts:
59 | for k, v in dict_.items():
60 | req[k].append(v)
61 | req = pd.DataFrame.from_dict(req)
62 | if set_index:
63 | assert set_index in req.columns.tolist()
64 | t_df = req.set_index(set_index)
65 | if infer_dtype:
66 | dtype = self.infer_dtype(index, t_df.columns.values)
67 | if len(dtype):
68 | t_df = t_df.astype(dtype)
69 | return t_df
70 | list_of_dicts = []
71 | for dict_ in source_iter:
72 | list_of_dicts.append(dict_)
73 | if len(list_of_dicts) >= chunk_size:
74 | yield map_list_of_dicts_into_df(list_of_dicts)
75 | list_of_dicts = []
76 | if list_of_dicts:
77 | yield map_list_of_dicts_into_df(list_of_dicts)
78 |
79 |
80 |
81 | ep = es_pandas_edit(es_host)
82 | if ep.ic.exists("train_part"):
83 | ep.ic.delete(index = "train_part")
84 |
85 |
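    | # Recreate the "train_part" index: infer a template from the dataframe, force question/answer to be
    | # full-text fields, then bulk-upload the dataframe in chunks and verify the row count on read-back.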
86 | ep.init_es_tmpl(train_part.head(1000), "train_part_doc_type", delete=True)
87 | valid_part_tmp = ep.es.indices.get_template("train_part_doc_type")
88 | es_index = valid_part_tmp["train_part_doc_type"]
89 | es_index["mappings"]["properties"]["question"] = {
90 | "type": "text",
91 | }
92 | es_index["mappings"]["properties"]["answer"] = {
93 | "type": "text",
94 | }
95 | es_index = {"mappings": es_index["mappings"]}
96 | ep.es.indices.create(index='train_part', body=es_index, ignore=[400])
97 |
98 |
99 | chunk_size = 10000
100 | range_list = list(range(0, train_part.shape[0], chunk_size))
101 | if train_part.shape[0] not in range_list:
102 | range_list.append(train_part.shape[0])
103 | assert "".join(map(str ,range_list)).startswith("0") and "".join(map(str ,range_list)).endswith("{}".format(train_part.shape[0]))
104 |
105 | for i in range(len(range_list) - 1):
106 | part_tiny = train_part.iloc[range_list[i]:range_list[i+1]]
107 | ep.to_es(part_tiny, "train_part")
108 |
109 | assert reduce(lambda a, b: a + b, map(lambda df: df.shape[0] ,ep.to_pandas_iter("train_part", chunk_size = chunk_size))) == train_part.shape[0]
110 |
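    | # Repeat the same indexing procedure for the "valid_part" split.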
111 | if ep.ic.exists("valid_part"):
112 | ep.ic.delete(index = "valid_part")
113 |
114 |
115 | ep.init_es_tmpl(train_part.head(1000), "valid_part_doc_type", delete=True)
116 | valid_part_tmp = ep.es.indices.get_template("valid_part_doc_type")
117 | es_index = valid_part_tmp["valid_part_doc_type"]
118 | es_index["mappings"]["properties"]["question"] = {
119 | "type": "text",
120 | }
121 | es_index["mappings"]["properties"]["answer"] = {
122 | "type": "text",
123 | }
124 | es_index = {"mappings": es_index["mappings"]}
125 | ep.es.indices.create(index='valid_part', body=es_index, ignore=[400])
126 |
127 |
128 | chunk_size = 10000
129 | range_list = list(range(0, valid_part.shape[0], chunk_size))
130 | if valid_part.shape[0] not in range_list:
131 | range_list.append(valid_part.shape[0])
132 | assert "".join(map(str ,range_list)).startswith("0") and "".join(map(str ,range_list)).endswith("{}".format(valid_part.shape[0]))
133 |
134 | for i in range(len(range_list) - 1):
135 | part_tiny = valid_part.iloc[range_list[i]:range_list[i+1]]
136 | ep.to_es(part_tiny, "valid_part")
137 |
138 | assert reduce(lambda a, b: a + b, map(lambda df: df.shape[0] ,ep.to_pandas_iter("valid_part", chunk_size = chunk_size))) == valid_part.shape[0]
139 |
140 |
141 | class SentenceBERTNegativeSampler():
142 | """
143 | Sample candidates from a list of candidates using dense embeddings from sentenceBERT.
144 |
145 | Args:
146 | candidates: list of str containing the candidates
147 | num_candidates_samples: int containing the number of negative samples for each query.
148 | embeddings_file: str containing the path to cache the embeddings.
149 | sample_data: int indicating amount of candidates in the index (-1 if all)
150 | pre_trained_model: str containing the pre-trained sentence embedding model,
151 | e.g. bert-base-nli-stsb-mean-tokens.
152 | """
153 | def __init__(self, candidates, num_candidates_samples, embeddings_file, sample_data,
154 | pre_trained_model='bert-base-nli-stsb-mean-tokens', seed=42):
155 | random.seed(seed)
156 | self.candidates = candidates
157 | self.num_candidates_samples = num_candidates_samples
158 | self.pre_trained_model = pre_trained_model
159 |
160 | #self.model = SentenceTransformer(self.pre_trained_model)
161 | self.model = SentenceTransformer(self.pre_trained_model, device = "cpu")
162 | #extract the name of the folder with the pre-trained sentence embedding
163 | if os.path.isdir(self.pre_trained_model):
164 | self.pre_trained_model = self.pre_trained_model.split("/")[-1]
165 |
166 | self.name = "SentenceBERTNS_"+self.pre_trained_model
167 | self.sample_data = sample_data
168 | self.embeddings_file = embeddings_file
169 |
170 | self._calculate_sentence_embeddings()
171 | self._build_faiss_index()
172 |
173 | def _calculate_sentence_embeddings(self):
174 | """
175 | Calculates sentenceBERT embeddings for all candidates.
176 | """
177 | embeds_file_path = "{}_n_sample_{}_pre_trained_model_{}".format(self.embeddings_file,
178 | self.sample_data,
179 | self.pre_trained_model)
180 | if not os.path.isfile(embeds_file_path):
181 | logging.info("Calculating embeddings for the candidates.")
182 | self.candidate_embeddings = self.model.encode(self.candidates, show_progress_bar=True)
183 | with open(embeds_file_path, 'wb') as f:
184 | pickle.dump(self.candidate_embeddings, f)
185 | else:
186 | with open(embeds_file_path, 'rb') as f:
187 | self.candidate_embeddings = pickle.load(f)
188 |
189 | def _build_faiss_index(self):
190 | """
191 | Builds the faiss indexes containing all sentence embeddings of the candidates.
192 | """
193 | self.index = faiss.IndexFlatL2(self.candidate_embeddings[0].shape[0]) # build the index
194 | self.index.add(np.array(self.candidate_embeddings))
195 | logging.info("There is a total of {} candidates.".format(len(self.candidates)))
196 | logging.info("There is a total of {} candidate embeddings.".format(len(self.candidate_embeddings)))
197 | logging.info("Faiss index has a total of {} candidates".format(self.index.ntotal))
198 |
199 | def sample(self, query_str, relevant_docs):
200 |         """
201 |         Samples candidates that are closest to the query in sentenceBERT embedding space
202 |         (L2 distance over the faiss IndexFlatL2 index built in _build_faiss_index).
203 |         If a sampled candidate is one of the relevant docs, it is dropped and replaced
204 |         by a random re-sample, so only negatives are returned.
205 | 
206 |         Args:
207 |             query_str: query string used for the dense similarity matching.
208 |             relevant_docs: list of relevant document strings that must not be sampled as negatives.
209 | 
210 |         Returns:
211 |             A triplet: the list of negative samples,
212 |             whether a relevant doc was retrieved among the nearest candidates, and
213 |             if so, its rank in that list.
214 |         """
215 | query_embedding = self.model.encode([query_str], show_progress_bar=False)
216 |
217 | distances, idxs = self.index.search(np.array(query_embedding), self.num_candidates_samples)
218 | sampled_initial = [self.candidates[idx] for idx in idxs[0]]
219 |
220 | was_relevant_sampled = False
221 | relevant_doc_rank = -1
222 | sampled = []
223 | for i, d in enumerate(sampled_initial):
224 | if d in relevant_docs:
225 | was_relevant_sampled = True
226 | relevant_doc_rank = i
227 | else:
228 | sampled.append(d)
229 |
230 | while len(sampled) != self.num_candidates_samples:
231 | sampled = sampled + [d for d in random.sample(self.candidates, self.num_candidates_samples-len(sampled))
232 | if d not in relevant_docs]
233 | return sampled, was_relevant_sampled, relevant_doc_rank
234 |
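
The sampler above ranks candidates by L2 distance via faiss.IndexFlatL2. If cosine similarity is preferred instead, the embeddings can be L2-normalized and indexed with an inner-product index; a minimal sketch (not part of the original class, names are illustrative):

def build_cosine_index(candidate_embeddings):
    # after L2 normalization, inner product equals cosine similarity
    emb = np.array(candidate_embeddings).astype("float32")
    faiss.normalize_L2(emb)
    index = faiss.IndexFlatIP(emb.shape[1])
    index.add(emb)
    return index

Query embeddings would need the same normalization before calling index.search.
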
235 |
236 |
237 |
238 |
239 | #chunk_size = 10000
240 | #train_part = pd.concat(list(ep.to_pandas_iter("train_part", chunk_size = chunk_size)), axis = 0)
241 | candidates = train_part["answer"].tolist()
242 |
243 |
244 | num_candidates_samples = 30
245 | embeddings_file = os.path.join(os.path.abspath(""), "train_sbert_emb_cache")
246 | sample_data = -1
247 | pre_trained_model = os.path.join(os.path.abspath(""), "bi_encoder_save")
248 | sbert_sampler = SentenceBERTNegativeSampler(candidates, num_candidates_samples, embeddings_file, sample_data,
249 | pre_trained_model)
250 |
251 |
252 | def part_gen_constructor(sampler, part_df):
253 | #question_neg_dict = {}
254 | for question, df in part_df.groupby("question"):
255 | pos_answer_list = df["answer"].tolist()
256 |         negs = sampler.sample(question, pos_answer_list)
257 | #negs = sbert_sampler.sample(question, [])
258 | #neg_mg_df = pd.merge(train_part_tiny, pd.DataFrame(np.asarray(negs[0]).reshape([-1, 1]), columns = ["answer"]), on = "answer", how = "inner")
259 | #question_neg_dict[question] = neg_mg_df
260 | for pos_answer in pos_answer_list:
261 | yield InputExample(texts=[question, pos_answer], label=1)
262 | for neg_answer in negs[0]:
263 | yield InputExample(texts=[question, neg_answer], label=0)
264 |
265 |
266 | def json_save(input_collection, path):
267 | assert path.endswith(".json")
268 | assert type(input_collection) in [type({}), type(set([]))]
269 | with open(path, "w", encoding = "utf-8") as f:
270 |         if type(input_collection) == type({}):
271 |             # dicts are already json serializable, dump them as-is below
272 |             pass
273 |         else:
274 |             input_collection = {path.split("/")[-1].replace(".json", ""): list(input_collection)}
275 |         json.dump(input_collection, f, ensure_ascii = False)
276 | print("save to {}".format(path))
277 |
278 |
279 | def produce_question_answer_sample_in_file_format(part_gen, chunck_size = 1000, save_times = 1, sub_dir = None):
280 |     question_index_dict = {}
281 |     answer_index_dict = {}
282 |     pos_tuple_set = set([])
283 |     neg_tuple_set = set([])
284 |     have_save = 0
285 |     def save_collections():  # dump the four lookup collections as json, next to the script or into sub_dir
286 |         if sub_dir is None:
287 |             target_dir = os.path.abspath("")
288 |         else:
289 |             assert type(sub_dir) == type("") and "/" not in sub_dir
290 |             target_dir = os.path.join(os.path.abspath(""), sub_dir)
291 |             if not os.path.exists(target_dir):
292 |                 os.mkdir(target_dir)
293 |         for name, collection in [("question_index_dict", question_index_dict), ("answer_index_dict", answer_index_dict), ("pos_tuple_set", pos_tuple_set), ("neg_tuple_set", neg_tuple_set)]:
294 |             json_save(collection, os.path.join(target_dir, "{}.json".format(name)))
295 |     for idx, item_ in enumerate(part_gen):  # the loop ends cleanly once the generator is exhausted
296 |         question, answer = item_.texts
297 |         if question not in question_index_dict:
298 |             question_index_dict[question] = len(question_index_dict)
299 |         if answer not in answer_index_dict:
300 |             answer_index_dict[answer] = len(answer_index_dict)
301 |         label = item_.label
302 |         assert label in [0, 1]
303 |         if label == 1:
304 |             pos_tuple_set.add((question_index_dict[question], answer_index_dict[answer]))
305 |         else:
306 |             neg_tuple_set.add((question_index_dict[question], answer_index_dict[answer]))
307 |         if (idx + 1) % chunck_size == 0:
308 |             save_collections()
309 |             have_save += 1
310 |             print("saved checkpoint at step {}".format(idx + 1))
311 |             if have_save >= save_times:
312 |                 return
313 |     save_collections()  # flush whatever remains once the generator is exhausted
314 |
315 |
316 | train_part_gen = part_gen_constructor(sbert_sampler, train_part)
317 | produce_question_answer_sample_in_file_format(train_part_gen, chunck_size = 10000, save_times = 10000,
318 | sub_dir = "train_file_faiss_10")
319 |
320 |
321 |
322 |
323 |
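
A minimal sketch (not part of the original script) of how the json artifacts written above, e.g. under train_file_faiss_10/, could be reloaded to rebuild InputExample pairs for cross-encoder training. It relies on the json / os / InputExample imports already at the top of this script, and the file names follow the json_save convention used here:

def load_examples_from_dir(sub_dir):
    base = os.path.join(os.path.abspath(""), sub_dir)
    with open(os.path.join(base, "question_index_dict.json"), encoding = "utf-8") as f:
        idx_to_question = {v: k for k, v in json.load(f).items()}  # invert text -> index
    with open(os.path.join(base, "answer_index_dict.json"), encoding = "utf-8") as f:
        idx_to_answer = {v: k for k, v in json.load(f).items()}
    with open(os.path.join(base, "pos_tuple_set.json"), encoding = "utf-8") as f:
        pos_pairs = json.load(f)["pos_tuple_set"]
    with open(os.path.join(base, "neg_tuple_set.json"), encoding = "utf-8") as f:
        neg_pairs = json.load(f)["neg_tuple_set"]
    examples = [InputExample(texts=[idx_to_question[q], idx_to_answer[a]], label=1) for q, a in pos_pairs]
    examples += [InputExample(texts=[idx_to_question[q], idx_to_answer[a]], label=0) for q, a in neg_pairs]
    return examples

# train_examples = load_examples_from_dir("train_file_faiss_10")
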
--------------------------------------------------------------------------------
/script/cross_encoder/try_sbert_neg_sampler_valid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import json
4 | import logging
5 | import os
6 | import pickle
7 | import random
8 | import time
9 | import traceback
10 | from functools import reduce
11 |
12 | import faiss
13 | import numpy as np
14 | import pandas as pd
15 | import scipy.spatial
16 | import torch
17 | from elasticsearch import Elasticsearch, helpers
18 | from es_pandas import es_pandas
19 | from IPython import embed
20 | from sentence_transformers import InputExample, SentenceTransformer, util
21 |
22 | es_host = 'localhost:9200'
23 |
24 | class es_pandas_edit(es_pandas):
25 | @staticmethod
26 | def serialize(row, columns, use_pandas_json, iso_dates):
27 | if use_pandas_json:
28 | return json.dumps(dict(zip(columns, row)), iso_dates=iso_dates)
29 | return dict(zip(columns, [None if (all(pd.isna(r)) if (hasattr(r, "__len__") and type(r) != type("")) else pd.isna(r)) else r for r in row]))
30 | def to_pandas_iter(self, index, query_rule=None, heads=[], dtype={}, infer_dtype=False, show_progress=True,
31 | chunk_size = None, **kwargs):
32 | if query_rule is None:
33 | query_rule = {'query': {'match_all': {}}}
34 | count = self.es.count(index=index, body=query_rule)['count']
35 | if count < 1:
36 | raise Exception('Empty for %s' % index)
37 | query_rule['_source'] = heads
38 | anl = helpers.scan(self.es, query=query_rule, index=index, **kwargs)
39 | source_iter = self.get_source(anl, show_progress = show_progress, count = count)
40 | print(source_iter)
41 | if chunk_size is None:
42 | df = pd.DataFrame(list(source_iter)).set_index('_id')
43 | if infer_dtype:
44 | dtype = self.infer_dtype(index, df.columns.values)
45 | if len(dtype):
46 | df = df.astype(dtype)
47 | yield df
48 | return
49 | assert type(chunk_size) == type(0)
50 | def map_list_of_dicts_into_df(list_of_dicts, set_index = "_id"):
51 | from collections import defaultdict
52 | req = defaultdict(list)
53 | for dict_ in list_of_dicts:
54 | for k, v in dict_.items():
55 | req[k].append(v)
56 | req = pd.DataFrame.from_dict(req)
57 | if set_index:
58 | assert set_index in req.columns.tolist()
59 | t_df = req.set_index(set_index)
60 | if infer_dtype:
61 | dtype = self.infer_dtype(index, t_df.columns.values)
62 | if len(dtype):
63 | t_df = t_df.astype(dtype)
64 | return t_df
65 | list_of_dicts = []
66 | for dict_ in source_iter:
67 | list_of_dicts.append(dict_)
68 | if len(list_of_dicts) >= chunk_size:
69 | yield map_list_of_dicts_into_df(list_of_dicts)
70 | list_of_dicts = []
71 | if list_of_dicts:
72 | yield map_list_of_dicts_into_df(list_of_dicts)
73 |
74 |
75 | class SentenceBERTNegativeSampler():
76 | """
77 | Sample candidates from a list of candidates using dense embeddings from sentenceBERT.
78 |
79 | Args:
80 | candidates: list of str containing the candidates
81 | num_candidates_samples: int containing the number of negative samples for each query.
82 | embeddings_file: str containing the path to cache the embeddings.
83 | sample_data: int indicating amount of candidates in the index (-1 if all)
84 | pre_trained_model: str containing the pre-trained sentence embedding model,
85 | e.g. bert-base-nli-stsb-mean-tokens.
86 | """
87 | def __init__(self, candidates, num_candidates_samples, embeddings_file, sample_data,
88 | pre_trained_model='bert-base-nli-stsb-mean-tokens', seed=42):
89 | random.seed(seed)
90 | self.candidates = candidates
91 | self.num_candidates_samples = num_candidates_samples
92 | self.pre_trained_model = pre_trained_model
93 |
94 | self.model = SentenceTransformer(self.pre_trained_model)
95 | #extract the name of the folder with the pre-trained sentence embedding
96 | if os.path.isdir(self.pre_trained_model):
97 | self.pre_trained_model = self.pre_trained_model.split("/")[-1]
98 |
99 | self.name = "SentenceBERTNS_"+self.pre_trained_model
100 | self.sample_data = sample_data
101 | self.embeddings_file = embeddings_file
102 |
103 | self._calculate_sentence_embeddings()
104 | self._build_faiss_index()
105 |
106 | def _calculate_sentence_embeddings(self):
107 | """
108 | Calculates sentenceBERT embeddings for all candidates.
109 | """
110 | embeds_file_path = "{}_n_sample_{}_pre_trained_model_{}".format(self.embeddings_file,
111 | self.sample_data,
112 | self.pre_trained_model)
113 | if not os.path.isfile(embeds_file_path):
114 | logging.info("Calculating embeddings for the candidates.")
115 | self.candidate_embeddings = self.model.encode(self.candidates, show_progress_bar=True)
116 | with open(embeds_file_path, 'wb') as f:
117 | pickle.dump(self.candidate_embeddings, f)
118 | else:
119 | with open(embeds_file_path, 'rb') as f:
120 | self.candidate_embeddings = pickle.load(f)
121 |
122 | def _build_faiss_index(self):
123 | """
124 | Builds the faiss indexes containing all sentence embeddings of the candidates.
125 | """
126 | self.index = faiss.IndexFlatL2(self.candidate_embeddings[0].shape[0]) # build the index
127 | self.index.add(np.array(self.candidate_embeddings))
128 | logging.info("There is a total of {} candidates.".format(len(self.candidates)))
129 | logging.info("There is a total of {} candidate embeddings.".format(len(self.candidate_embeddings)))
130 | logging.info("Faiss index has a total of {} candidates".format(self.index.ntotal))
131 |
132 | def sample(self, query_str, relevant_docs):
133 |         """
134 |         Samples candidates that are closest to the query in sentenceBERT embedding space
135 |         (L2 distance over the faiss IndexFlatL2 index built in _build_faiss_index).
136 |         If a sampled candidate is one of the relevant docs, it is dropped and replaced
137 |         by a random re-sample, so only negatives are returned.
138 | 
139 |         Args:
140 |             query_str: query string used for the dense similarity matching.
141 |             relevant_docs: list of relevant document strings that must not be sampled as negatives.
142 | 
143 |         Returns:
144 |             A triplet: the list of negative samples,
145 |             whether a relevant doc was retrieved among the nearest candidates, and
146 |             if so, its rank in that list.
147 |         """
148 | query_embedding = self.model.encode([query_str], show_progress_bar=False)
149 |
150 | distances, idxs = self.index.search(np.array(query_embedding), self.num_candidates_samples)
151 | sampled_initial = [self.candidates[idx] for idx in idxs[0]]
152 |
153 | was_relevant_sampled = False
154 | relevant_doc_rank = -1
155 | sampled = []
156 | for i, d in enumerate(sampled_initial):
157 | if d in relevant_docs:
158 | was_relevant_sampled = True
159 | relevant_doc_rank = i
160 | else:
161 | sampled.append(d)
162 |
163 | while len(sampled) != self.num_candidates_samples:
164 | sampled = sampled + [d for d in random.sample(self.candidates, self.num_candidates_samples-len(sampled))
165 | if d not in relevant_docs]
166 | return sampled, was_relevant_sampled, relevant_doc_rank
167 |
168 |
169 | ep = es_pandas_edit(es_host)
170 | chunk_size = 10000
171 | valid_part = pd.concat(list(ep.to_pandas_iter("valid_part", chunk_size = chunk_size)), axis = 0)
172 |
173 | candidates = valid_part["answer"].tolist()  # candidate pool: validation answers (mirrors the train script)
174 | num_candidates_samples = 4
175 | embeddings_file = os.path.join(os.path.abspath(""), "valid_sbert_emb_cache")
176 | sample_data = -1
177 | pre_trained_model = os.path.join(os.path.abspath(""), "bi_encoder_save")
178 | sbert_sampler = SentenceBERTNegativeSampler(candidates, num_candidates_samples, embeddings_file, sample_data,
179 | pre_trained_model)
180 |
181 | def part_gen_constructor(sampler, part_df):
182 | #question_neg_dict = {}
183 | for question, df in part_df.groupby("question"):
184 | pos_answer_list = df["answer"].tolist()
185 |         negs = sampler.sample(question, pos_answer_list)
186 | #negs = sbert_sampler.sample(question, [])
187 | #neg_mg_df = pd.merge(train_part_tiny, pd.DataFrame(np.asarray(negs[0]).reshape([-1, 1]), columns = ["answer"]), on = "answer", how = "inner")
188 | #question_neg_dict[question] = neg_mg_df
189 | for pos_answer in pos_answer_list:
190 | yield InputExample(texts=[question, pos_answer], label=1)
191 | for neg_answer in negs[0]:
192 | yield InputExample(texts=[question, neg_answer], label=0)
193 |
194 | def json_save(input_collection, path):
195 | assert path.endswith(".json")
196 | assert type(input_collection) in [type({}), type(set([]))]
197 | with open(path, "w", encoding = "utf-8") as f:
198 |         if type(input_collection) == type({}):
199 |             # dicts are already json serializable, dump them as-is below
200 |             pass
201 |         else:
202 |             input_collection = {path.split("/")[-1].replace(".json", ""): list(input_collection)}
203 |         json.dump(input_collection, f, ensure_ascii = False)
204 | print("save to {}".format(path))
205 |
206 |
207 |
208 | def produce_question_answer_sample_in_file_format(part_gen, chunck_size = 1000, save_times = 1, sub_dir = None):
209 |     question_index_dict = {}
210 |     answer_index_dict = {}
211 |     pos_tuple_set = set([])
212 |     neg_tuple_set = set([])
213 |     have_save = 0
214 |     def save_collections():  # dump the four lookup collections as json, next to the script or into sub_dir
215 |         if sub_dir is None:
216 |             target_dir = os.path.abspath("")
217 |         else:
218 |             assert type(sub_dir) == type("") and "/" not in sub_dir
219 |             target_dir = os.path.join(os.path.abspath(""), sub_dir)
220 |             if not os.path.exists(target_dir):
221 |                 os.mkdir(target_dir)
222 |         for name, collection in [("question_index_dict", question_index_dict), ("answer_index_dict", answer_index_dict), ("pos_tuple_set", pos_tuple_set), ("neg_tuple_set", neg_tuple_set)]:
223 |             json_save(collection, os.path.join(target_dir, "{}.json".format(name)))
224 |     for idx, item_ in enumerate(part_gen):  # the loop ends cleanly once the generator is exhausted
225 |         question, answer = item_.texts
226 |         if question not in question_index_dict:
227 |             question_index_dict[question] = len(question_index_dict)
228 |         if answer not in answer_index_dict:
229 |             answer_index_dict[answer] = len(answer_index_dict)
230 |         label = item_.label
231 |         assert label in [0, 1]
232 |         if label == 1:
233 |             pos_tuple_set.add((question_index_dict[question], answer_index_dict[answer]))
234 |         else:
235 |             neg_tuple_set.add((question_index_dict[question], answer_index_dict[answer]))
236 |         if (idx + 1) % chunck_size == 0:
237 |             save_collections()
238 |             have_save += 1
239 |             print("saved checkpoint at step {}".format(idx + 1))
240 |             if have_save >= save_times:
241 |                 return
242 |     save_collections()  # flush whatever remains once the generator is exhausted
243 |
244 |
245 | valid_part_gen = part_gen_constructor(sbert_sampler, valid_part)
246 |
247 |
248 |
249 | produce_question_answer_sample_in_file_format(valid_part_gen, chunck_size = 3000, save_times = 10000,
250 | sub_dir = "valid_file_faiss")
251 |
252 |
253 |
254 |
255 |
256 |
257 |
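
A minimal sketch (an assumption, not part of the original script) of how the same generated pairs could instead be grouped into the {"query", "positive", "negative"} samples that sentence-transformers' CERerankingEvaluator expects; it should be built from a fresh generator, since valid_part_gen has already been consumed above:

from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator

def gen_to_dev_samples(gen, max_pairs = 20000):
    dev_samples = {}
    for i, example in enumerate(gen):
        if i >= max_pairs:
            break
        query, answer = example.texts
        sample = dev_samples.setdefault(query, {"query": query, "positive": set(), "negative": set()})
        sample["positive" if example.label == 1 else "negative"].add(answer)
    return dev_samples

# dev_samples = gen_to_dev_samples(part_gen_constructor(sbert_sampler, valid_part))
# evaluator = CERerankingEvaluator(dev_samples, name = "valid-rerank")
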
--------------------------------------------------------------------------------
/script/cross_encoder/valid_cross_encoder_on_bi_encoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import gzip
4 | import json
5 | import logging
6 | import os
7 | import tarfile
8 | import time
9 | from datetime import datetime
10 | from functools import partial, reduce
11 | from glob import glob
12 | from typing import Callable, Dict, List, Type
13 |
14 | import numpy as np
15 | import pandas as pd
16 | import seaborn as sns
17 | import torch
18 | import tqdm
19 | from elasticsearch import Elasticsearch, helpers
20 | from es_pandas import es_pandas
21 | from sentence_transformers import (InputExample, LoggingHandler,
22 | SentenceTransformer, util)
23 | from sentence_transformers.cross_encoder import CrossEncoder
24 | from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
25 | from sklearn.metrics import balanced_accuracy_score
26 | from torch.utils.data import DataLoader
27 |
28 | pd.set_option("display.max_rows", 200)
29 | es_host = 'localhost:9200'
30 |
31 |
32 | bi_model_path = os.path.join(os.path.dirname(__file__), os.path.pardir, "bi_encoder_save/")
33 | bi_model = SentenceTransformer(bi_model_path, device = "cpu")
34 |
35 |
36 | cross_model_path = "output/training_ms-marco_cross-encoder-xlm-roberta-base-2021-01-17_14-43-23_map-train-eval"
37 | cross_model = CrossEncoder(cross_model_path, num_labels=1, max_length=512, device = "cpu")
38 |
39 |
40 | class es_pandas_edit(es_pandas):
41 | @staticmethod
42 | def serialize(row, columns, use_pandas_json, iso_dates):
43 | if use_pandas_json:
44 | return json.dumps(dict(zip(columns, row)), iso_dates=iso_dates)
45 | return dict(zip(columns, [None if (all(pd.isna(r)) if (hasattr(r, "__len__") and type(r) != type("")) else pd.isna(r)) else r for r in row]))
46 | def to_pandas_iter(self, index, query_rule=None, heads=[], dtype={}, infer_dtype=False, show_progress=True,
47 | chunk_size = None, **kwargs):
48 | if query_rule is None:
49 | query_rule = {'query': {'match_all': {}}}
50 | count = self.es.count(index=index, body=query_rule)['count']
51 | if count < 1:
52 | raise Exception('Empty for %s' % index)
53 | query_rule['_source'] = heads
54 | anl = helpers.scan(self.es, query=query_rule, index=index, **kwargs)
55 | source_iter = self.get_source(anl, show_progress = show_progress, count = count)
56 | print(source_iter)
57 | if chunk_size is None:
58 | df = pd.DataFrame(list(source_iter)).set_index('_id')
59 | if infer_dtype:
60 | dtype = self.infer_dtype(index, df.columns.values)
61 | if len(dtype):
62 | df = df.astype(dtype)
63 | yield df
64 | return
65 | assert type(chunk_size) == type(0)
66 | def map_list_of_dicts_into_df(list_of_dicts, set_index = "_id"):
67 | from collections import defaultdict
68 | req = defaultdict(list)
69 | for dict_ in list_of_dicts:
70 | for k, v in dict_.items():
71 | req[k].append(v)
72 | req = pd.DataFrame.from_dict(req)
73 | if set_index:
74 | assert set_index in req.columns.tolist()
75 | t_df = req.set_index(set_index)
76 | if infer_dtype:
77 | dtype = self.infer_dtype(index, t_df.columns.values)
78 | if len(dtype):
79 | t_df = t_df.astype(dtype)
80 | return t_df
81 | list_of_dicts = []
82 | for dict_ in source_iter:
83 | list_of_dicts.append(dict_)
84 | if len(list_of_dicts) >= chunk_size:
85 | yield map_list_of_dicts_into_df(list_of_dicts)
86 | list_of_dicts = []
87 | if list_of_dicts:
88 | yield map_list_of_dicts_into_df(list_of_dicts)
89 |
90 |
91 | ep = es_pandas_edit(es_host)
92 | ep.ic.get_alias("*")
93 |
94 | chunk_size = 1000
95 | valid_part_from_es_iter = ep.to_pandas_iter(index = "valid_part", chunk_size = chunk_size)
96 |
97 |
98 | valid_part_tiny = None
99 | for ele in valid_part_from_es_iter:
100 | valid_part_tiny = ele
101 | break
102 | del valid_part_from_es_iter
103 |
104 |
105 | if ep.ic.exists("valid_part_tiny"):
106 | ep.ic.delete(index = "valid_part_tiny")
107 |
108 |
109 | ep.init_es_tmpl(valid_part_tiny, "valid_part_tiny_doc_type", delete=True)
110 | valid_part_tmp = ep.es.indices.get_template("valid_part_tiny_doc_type")
111 |
112 |
113 | es_index = valid_part_tmp["valid_part_tiny_doc_type"]
114 | es_index["mappings"]["properties"]["question_emb"] = {
115 | "type": "dense_vector",
116 | "dims": 768
117 | }
118 | es_index["mappings"]["properties"]["answer_emb"] = {
119 | "type": "dense_vector",
120 | "dims": 768
121 | }
122 | es_index["mappings"]["properties"]["question"] = {
123 | "type": "text",
124 | }
125 | es_index["mappings"]["properties"]["answer"] = {
126 | "type": "text",
127 | }
128 | es_index = {"mappings": es_index["mappings"]}
129 |
130 |
131 | ep.es.indices.create(index='valid_part_tiny', body=es_index, ignore=[400])
132 | question_embeddings = bi_model.encode(valid_part_tiny["question"].tolist(), convert_to_tensor=True, show_progress_bar=True)
133 | answer_embeddings = bi_model.encode(valid_part_tiny["answer"].tolist(), convert_to_tensor=True, show_progress_bar=True)
134 |
135 | valid_part_tiny["question_emb"] = question_embeddings.cpu().numpy().tolist()
136 | valid_part_tiny["answer_emb"] = answer_embeddings.cpu().numpy().tolist()
137 |
138 | ep.to_es(valid_part_tiny, "valid_part_tiny")
139 |
140 | chunk_size = 1000
141 | valid_part_tiny = list(ep.to_pandas_iter(index = "valid_part_tiny", chunk_size = None))[0]
142 |
143 |
144 | def search_by_embedding_in_es(index = "valid_part" ,embedding = np.asarray(valid_part_tiny["question_emb"].iloc[0]), on_column = "answer_emb"):
145 | vector_search_one = ep.es.search(index=index, body={
146 | "query": {
147 | "script_score": {
148 | "query": {
149 | "match_all": {}
150 | },
151 | "script": {
152 | "source": "cosineSimilarity(params.queryVector, doc['{}']) + 1.0".format(on_column),
153 | "params": {
154 | "queryVector": embedding
155 | }
156 | }
157 | }
158 | }
159 | }, ignore = [400])
160 | req = list(map(lambda x: (x["_source"]["question"], x["_source"]["answer"], x["_score"]) ,vector_search_one["hits"]["hits"]))
161 | req_df = pd.DataFrame(req, columns = ["question", "answer", "score"])
162 | return req_df
163 |
164 |
165 | def search_by_text_in_es(index = "valid_part" ,text = valid_part_tiny["question"].iloc[0], on_column = "answer",
166 | analyzer = "smartcn"):
167 | if analyzer is not None:
168 |         bm25 = ep.es.search(index = index,
169 | body={"query":
170 | {
171 | "match": {on_column:{"query" :text, "analyzer": analyzer} },
172 |
173 | }
174 | },
175 | )
176 | else:
177 | bm25 = ep.es.search(index=index, body={"query": {"match": {on_column: text}}})
178 | req = list(map(lambda x: (x["_source"]["question"], x["_source"]["answer"], x["_score"]) ,bm25["hits"]["hits"]))
179 | req_df = pd.DataFrame(req, columns = ["question", "answer", "score"])
180 | return req_df
181 |
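
A quick side-by-side look (not part of the original script): BM25 text match versus dense cosine retrieval for the same validation question, both run against the small valid_part_tiny index created above. Passing analyzer=None keeps the sketch independent of the smartcn plugin:

q0 = valid_part_tiny["question"].iloc[0]
bm25_hits = search_by_text_in_es(index = "valid_part_tiny", text = q0, on_column = "answer", analyzer = None)
dense_hits = search_by_embedding_in_es(index = "valid_part_tiny",
                                       embedding = np.asarray(valid_part_tiny["question_emb"].iloc[0]),
                                       on_column = "answer_emb")
print(bm25_hits.head())
print(dense_hits.head())
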
182 |
183 | def valid_two_model(cross_model, ep, index, question, question_embedding, on_column = "answer_emb", size = 10):
184 | def search_by_embedding(ep ,index = "valid_part" ,embedding = np.asarray(valid_part_tiny["question_emb"].iloc[0]), on_column = "answer_emb"):
185 | vector_search_one = ep.es.search(index=index, body={
186 | "size": size,
187 | "query": {
188 | "script_score": {
189 | "query": {
190 | "match_all": {}
191 | },
192 | "script": {
193 | "source": "cosineSimilarity(params.queryVector, doc['{}']) + 1.0".format(on_column),
194 | "params": {
195 | "queryVector": embedding
196 | }
197 | }
198 | }
199 | }
200 | }, ignore = [400])
201 | req = list(map(lambda x: (x["_source"]["question"], x["_source"]["answer"], x["_score"]) ,vector_search_one["hits"]["hits"]))
202 | req_df = pd.DataFrame(req, columns = ["question", "answer", "score"])
203 | return req_df
204 | search_by_emb = search_by_embedding(ep ,index = index, embedding = question_embedding, on_column = on_column)
205 | print("question : {}".format(question))
206 | preds = cross_model.predict(search_by_emb.apply(lambda r: [question, r["answer"]], axis = 1).tolist())
207 | search_by_emb["cross_score"] = preds.tolist()
208 | return search_by_emb
209 | def produce_df(question, size = 10):
210 | question, question_embedding = valid_part_tiny[valid_part_tiny["question"] == question].iloc[0][["question", "question_emb"]]
211 | valid_df = valid_two_model(cross_model, ep, index = "valid_part_tiny", question = question, question_embedding = question_embedding, size = size)
212 | return valid_df
213 |
214 |
215 | class ScoreCalculator(object):
216 | def __init__(self,
217 | queries_ids,
218 | relevant_docs,
219 | mrr_at_k: List[int] = [10],
220 | ndcg_at_k: List[int] = [10],
221 | accuracy_at_k: List[int] = [1, 3, 5, 10],
222 | precision_recall_at_k: List[int] = [1, 3, 5, 10],
223 | map_at_k: List[int] = [100],
224 | ):
225 |         "queries_ids: list of query ids; relevant_docs: dict mapping each query id to a set or list of its relevant docs"
226 | self.queries_ids = queries_ids
227 | self.relevant_docs = relevant_docs
228 |
229 | self.mrr_at_k = mrr_at_k
230 | self.ndcg_at_k = ndcg_at_k
231 | self.accuracy_at_k = accuracy_at_k
232 | self.precision_recall_at_k = precision_recall_at_k
233 | self.map_at_k = map_at_k
234 | def compute_metrics(self, queries_result_list: List[object]):
235 | # Init score computation values
236 | num_hits_at_k = {k: 0 for k in self.accuracy_at_k}
237 | precisions_at_k = {k: [] for k in self.precision_recall_at_k}
238 | recall_at_k = {k: [] for k in self.precision_recall_at_k}
239 | MRR = {k: 0 for k in self.mrr_at_k}
240 | ndcg = {k: [] for k in self.ndcg_at_k}
241 | AveP_at_k = {k: [] for k in self.map_at_k}
242 |
243 | # Compute scores on results
244 | #### elements with hits one hit is a dict : {"corpus_id": corpus_text, "score": score}
245 | #### corpus_id replace by corpus text
246 | for query_itr in range(len(queries_result_list)):
247 | query_id = self.queries_ids[query_itr]
248 |
249 | # Sort scores
250 | top_hits = sorted(queries_result_list[query_itr], key=lambda x: x['score'], reverse=True)
251 | query_relevant_docs = self.relevant_docs[query_id]
252 |
253 |             # Accuracy@k - count the result as correct if at least one relevant doc is among the top-k documents
254 | for k_val in self.accuracy_at_k:
255 | for hit in top_hits[0:k_val]:
256 | if hit['corpus_id'] in query_relevant_docs:
257 | num_hits_at_k[k_val] += 1
258 | break
259 |
260 | # Precision and Recall@k
261 | for k_val in self.precision_recall_at_k:
262 | num_correct = 0
263 | for hit in top_hits[0:k_val]:
264 | if hit['corpus_id'] in query_relevant_docs:
265 | num_correct += 1
266 |
267 | precisions_at_k[k_val].append(num_correct / k_val)
268 | recall_at_k[k_val].append(num_correct / len(query_relevant_docs))
269 |
270 | # MRR@k
271 | for k_val in self.mrr_at_k:
272 | for rank, hit in enumerate(top_hits[0:k_val]):
273 | if hit['corpus_id'] in query_relevant_docs:
274 | MRR[k_val] += 1.0 / (rank + 1)
275 |                         break  # only the first relevant hit counts toward reciprocal rank
276 |
277 | # NDCG@k
278 | for k_val in self.ndcg_at_k:
279 | predicted_relevance = [1 if top_hit['corpus_id'] in query_relevant_docs else 0 for top_hit in top_hits[0:k_val]]
280 | true_relevances = [1] * len(query_relevant_docs)
281 |
282 | ndcg_value = self.compute_dcg_at_k(predicted_relevance, k_val) / self.compute_dcg_at_k(true_relevances, k_val)
283 | ndcg[k_val].append(ndcg_value)
284 |
285 | # MAP@k
286 | for k_val in self.map_at_k:
287 | num_correct = 0
288 | sum_precisions = 0
289 |
290 | for rank, hit in enumerate(top_hits[0:k_val]):
291 | if hit['corpus_id'] in query_relevant_docs:
292 | num_correct += 1
293 | sum_precisions += num_correct / (rank + 1)
294 |
295 | avg_precision = sum_precisions / min(k_val, len(query_relevant_docs))
296 | AveP_at_k[k_val].append(avg_precision)
297 |
298 | # Compute averages
299 | for k in num_hits_at_k:
300 | #num_hits_at_k[k] /= len(self.queries)
301 | num_hits_at_k[k] /= len(queries_result_list)
302 |
303 | for k in precisions_at_k:
304 | precisions_at_k[k] = np.mean(precisions_at_k[k])
305 |
306 | for k in recall_at_k:
307 | recall_at_k[k] = np.mean(recall_at_k[k])
308 |
309 | for k in ndcg:
310 | ndcg[k] = np.mean(ndcg[k])
311 |
312 | for k in MRR:
313 | #MRR[k] /= len(self.queries)
314 | MRR[k] /= len(queries_result_list)
315 |
316 | for k in AveP_at_k:
317 | AveP_at_k[k] = np.mean(AveP_at_k[k])
318 | return {'accuracy@k': num_hits_at_k, 'precision@k': precisions_at_k, 'recall@k': recall_at_k, 'ndcg@k': ndcg, 'mrr@k': MRR, 'map@k': AveP_at_k}
319 | @staticmethod
320 | def compute_dcg_at_k(relevances, k):
321 | dcg = 0
322 | for i in range(min(len(relevances), k)):
323 | dcg += relevances[i] / np.log2(i + 2) #+2 as we start our idx at 0
324 | return dcg
325 |
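
A toy check (not part of the original script) showing the input format ScoreCalculator expects: one query, its relevant answers, and one ranked hit list whose corpus_id fields hold the answer texts:

toy_calc = ScoreCalculator(queries_ids = ["q1"],
                           relevant_docs = {"q1": ["a_pos_1", "a_pos_2"]},
                           accuracy_at_k = [1, 3], precision_recall_at_k = [1, 3])
toy_hits = [[{"corpus_id": "a_pos_1", "score": 0.9},
             {"corpus_id": "a_neg_1", "score": 0.5},
             {"corpus_id": "a_pos_2", "score": 0.1}]]
print(toy_calc.compute_metrics(toy_hits))
# expected for this ranking: accuracy@1 == 1.0, precision@3 == 2/3, recall@3 == 1.0
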
326 |
327 | def map_dev_samples_to_score_calculator_format(dev_samples):
328 | if isinstance(dev_samples, dict):
329 | dev_samples = list(dev_samples.values())
330 | queries_ids = list(map(lambda x: x["query"] ,dev_samples))
331 | relevant_docs = dict(map(lambda idx: (dev_samples[idx]["query"], dev_samples[idx]["positive"]), range(len(dev_samples))))
332 | return ScoreCalculator(queries_ids, relevant_docs)
333 |
334 | def map_valid_df_to_score_calculator_format(query ,valid_df):
335 | queries_ids = [query]
336 | relevant_docs = {query: valid_df[valid_df["question"] == query]["answer"].tolist()}
337 | return ScoreCalculator(queries_ids, relevant_docs)
338 |
339 |
340 | def df_to_mrr_score(df, query, score_col, mrr_at_k = 10):
341 | #model_input = [[query, doc] for doc in docs]
342 | #pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
343 | is_relevant = list(map(lambda t2: True if t2[1]["question"] == query else False, df.iterrows()))
344 | pred_scores = df[score_col].values
345 | pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
346 | mrr_score = 0
347 | for rank, index in enumerate(pred_scores_argsort[0:mrr_at_k]):
348 | if is_relevant[index]:
349 | mrr_score = 1 / (rank+1)
350 | #mrr_score += 1 / (rank+1)
351 | break
352 | return mrr_score
353 |
354 |
355 | question_list = valid_part_tiny["question"].value_counts().index.tolist()
356 | valid_df = produce_df(question_list[10], size = 100)
357 |
358 | def produce_score_dict(query ,valid_df, column = "score"):
359 | queries_result_list = valid_df[["answer", column]].apply(lambda x: {"corpus_id": x["answer"], "score": x[column]}, axis = 1).tolist()
360 | score_dict = map_valid_df_to_score_calculator_format(query, valid_df).compute_metrics([queries_result_list])
361 | return score_dict
362 |
363 | produce_score_dict(question_list[10] ,valid_df, "score")
364 | produce_score_dict(question_list[10] ,valid_df, "cross_score")
365 | produce_score_dict(question_list[10] ,valid_df.head(20), "score")
366 | produce_score_dict(question_list[10] ,valid_df.head(20), "cross_score")
367 |
368 | valid_df.head(20)
369 | valid_df.head(20).sort_values(by = "cross_score", ascending = False)
370 | valid_df.sort_values(by = "cross_score", ascending = False).head(10)
371 |
372 | sns.distplot(valid_df["cross_score"])
373 |
374 |
375 |
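
To go beyond a single query, a minimal sketch (not part of the original script) that averages MRR@10 over a handful of validation questions, comparing the raw bi-encoder ranking ("score") with the cross-encoder re-ranking ("cross_score"):

mrr_records = []
for q in question_list[:50]:
    df = produce_df(q, size = 100)
    mrr_records.append({"question": q,
                        "bi_mrr": df_to_mrr_score(df, q, "score", mrr_at_k = 10),
                        "cross_mrr": df_to_mrr_score(df, q, "cross_score", mrr_at_k = 10)})
mrr_df = pd.DataFrame(mrr_records)
print(mrr_df[["bi_mrr", "cross_mrr"]].mean())
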
--------------------------------------------------------------------------------