├── LICENSE ├── README.md ├── construct_tree.py ├── dataset.py ├── figure ├── overflow.pdf ├── overflow.png └── tree.pdf ├── model.py ├── model_search_cos.py ├── reorganize_clusters_tree.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Haitao Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 8 | 9 | # Constructing Tree-based Index for Efficient and Effective Dense Retrieval 10 | 11 | The official repo for our SIGIR'23 Full paper: [Constructing Tree-based Index for Efficient and Effective Dense Retrieval](https://arxiv.org/abs/2304.11943) 12 | 13 | ## Introduction 14 | 15 | To balance the effectiveness and efficiency of the tree-based indexes, we propose **JTR**, which stands for **J**oint optimization of **TR**ee-based index and query encoding. To jointly optimize index structure and query encoder in an end-to-end manner, JTR drops the original ``encoding-indexing" training paradigm and designs a unified contrastive learning loss. However, training tree-based indexes using contrastive learning loss is non-trivial due to the problem of differentiability. To overcome this obstacle, the tree-based index is divided into two parts: cluster node embeddings and cluster assignment. For differentiable cluster node embeddings, which are small but very critical, we design tree-based negative sampling to optimize them. For cluster assignment, an overlapped cluster method is applied to iteratively optimize it. 16 | 17 | ![image](./figure/overflow.png) 18 | 19 | ## Preprocess 20 | 21 | JTR initializes the document embeddings with STAR, refer to [DRhard](https://github.com/jingtaozhan/DRhard) for details. 22 | 23 | 24 | Run the following codes in DRhard to preprocess document. 25 | 26 | `` 27 | python preprocess.py --data_type 0; python preprocess.py --data_type 1 28 | `` 29 | 30 | 31 | ## Tree Initialization 32 | 33 | After getting the text embeddings, we can initialize the tree using recursive k-means. 
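To make the recursion concrete, here is a minimal, illustrative sketch of recursive k-means initialization (the real implementation is the `TreeInitialize` class in `construct_tree.py`; the embedding file name and 768-dim layout follow that script, while the `branch`/`leaf_size` values and the dict-based node structure below are placeholders, not the repository's defaults):

```python
import numpy as np
from sklearn.cluster import KMeans

# STAR passage embeddings produced by DRhard, stored as a flat float32 memmap (768-dim).
embeddings = np.memmap("passages.memmap", dtype=np.float32, mode="r").reshape(-1, 768)
pids = np.arange(embeddings.shape[0])

def build(node_embeddings, node_pids, branch=10, leaf_size=5000):
    """Recursively split pids with k-means until a cluster is small enough to become a leaf."""
    if len(node_pids) < leaf_size:
        return {"pids": node_pids.tolist(), "children": []}   # leaf cluster
    kmeans = KMeans(n_clusters=branch).fit(node_embeddings)
    children = []
    for c in range(branch):
        idx = np.where(kmeans.labels_ == c)[0]
        child = build(node_embeddings[idx], node_pids[idx], branch, leaf_size)
        child["center"] = kmeans.cluster_centers_[c]          # cluster node embedding
        children.append(child)
    return {"pids": [], "children": children}

tree = build(embeddings, pids)
```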
34 | 35 | Run the following codes: 36 | 37 | `` 38 | python construct_tree.py 39 | `` 40 | 41 | We will get the following files: 42 | 43 | tree.pkl: Tree structure 44 | 45 | node_dict.pkl: Map of node id to node 46 | 47 | node_list: Nodes per level 48 | 49 | pid_labelid.memmap: Mapping of document ids to clustering nodes 50 | 51 | leaf_dict.pkl: Leaf Nodes 52 | 53 | 54 | ## Train 55 | Run the following codes: 56 | `` 57 | python train.py --task train 58 | `` 59 | 60 | The training process trains both the query encoder and the clustering node embeddings. Therefore, we need to save both the node embeddings and the query encoder. 61 | 62 | ## Inference 63 | 64 | Run the following codes: 65 | 66 | `` 67 | python train.py --task dev 68 | `` 69 | 70 | The inference process can construct the matrix M for Reorganize Cluster. 71 | 72 | ## Reorganize Cluster 73 | 74 | Run the following codes: 75 | 76 | `` 77 | python reorganize_clusters_tree.py 78 | `` 79 | 80 | The re-clustering requires M and Y matrices. Y matrix is constructed by running other retrieval models. M matrix is constructed by inference on the tree index. 81 | 82 | 83 | ## Other 84 | 85 | This work was done when I was a beginner and the code was embarrassing. If somebody can further organize and optimize the code or integrate it into Faiss with C. I would appreciate it. 86 | 87 | ## Citations 88 | 89 | If you find our work useful, please do not save your star and cite our work: 90 | 91 | ``` 92 | @misc{JTR, 93 | title={Constructing Tree-based Index for Efficient and Effective Dense Retrieval}, 94 | author={Haitao Li and Qingyao Ai and Jingtao Zhan and Jiaxin Mao and Yiqun Liu and Zheng Liu and Zhao Cao}, 95 | year={2023}, 96 | eprint={2304.11943}, 97 | archivePrefix={arXiv}, 98 | primaryClass={cs.IR} 99 | } 100 | ``` -------------------------------------------------------------------------------- /construct_tree.py: -------------------------------------------------------------------------------- 1 | from http.client import PROXY_AUTHENTICATION_REQUIRED 2 | import numpy as np 3 | import time 4 | import os 5 | import pandas as pd 6 | from sklearn.cluster import KMeans 7 | import joblib 8 | import pickle as pkl 9 | os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3,4,5' 10 | 11 | 12 | 13 | def save_object(obj, path): 14 | with open(path,'wb') as f: 15 | pkl.dump(obj,f) 16 | 17 | def load_object(path): 18 | with open(path, 'rb') as f: 19 | obj = pkl.load(f) 20 | return obj 21 | 22 | def _node_list(root): 23 | def node_val(node): 24 | if(node.isleaf == False): 25 | return node.val 26 | else: 27 | return node.val 28 | 29 | node_queue = [root] 30 | arr_arr_node = [] 31 | arr_arr_node.append([node_val(node_queue[0])]) 32 | while node_queue: 33 | tmp = [] 34 | tmp_val = [] 35 | for node in node_queue: 36 | for child in node.children: 37 | tmp.append(child) 38 | tmp_val.append(node_val(child)) 39 | if len(tmp_val) > 0: 40 | arr_arr_node.append(tmp_val) 41 | node_queue = tmp 42 | return arr_arr_node 43 | 44 | class TreeNode(object): 45 | """define the tree node structure.""" 46 | def __init__(self, x ,item_embedding = None, layer = None): 47 | self.val = x 48 | self.embedding = item_embedding 49 | self.parent = None 50 | self.children = [] 51 | self.isleaf = False 52 | self.pids = [] 53 | self.layer = layer 54 | 55 | def getval(self): 56 | return self.val 57 | def getchildren(self): 58 | return self.children 59 | def add(self, node): 60 | ##if full 61 | if len(self.children) == 10: 62 | return False 63 | else: 64 | self.children.append(node) 65 | 66 | 
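# Illustrative example (kept as a comment so the script's behaviour is unchanged):
# a tiny hand-built two-level tree showing how TreeNode and _node_list fit together.
# The node values here are made up; the real tree is produced by TreeInitialize below.
#
#   root = TreeNode('0', layer=0)
#   for i in range(3):
#       child = TreeNode('0' + str(i), layer=1)
#       child.parent = root
#       child.isleaf = True
#       root.add(child)
#   _node_list(root)   # -> [['0'], ['00', '01', '02']]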
67 | class TreeInitialize(object): 68 | """"Build the random binary tree.""" 69 | def __init__(self, pid_embeddings, pids, blance_factor, leaf_factor): 70 | self.embeddings = pid_embeddings 71 | self.pids = pids 72 | self.root = None 73 | self.blance_factor = blance_factor 74 | self.leaf_factor = leaf_factor 75 | self.leaf_dict = {} 76 | self.node_dict = {} 77 | self.node_size = 0 78 | 79 | def _k_means_clustering(self, pid_embeddings): 80 | if len(pid_embeddings)>1000000: 81 | idxs = np.arange(pid_embeddings.shape[0]) 82 | np.random.shuffle(idxs) 83 | idxs = idxs[0:1000000] 84 | train_embeddings = pid_embeddings[idxs] 85 | else: 86 | train_embeddings = pid_embeddings 87 | train_embeddings = pid_embeddings 88 | kmeans = KMeans(n_clusters=self.blance_factor, max_iter=3000, n_init=100).fit(train_embeddings) 89 | return kmeans 90 | 91 | def _build_ten_tree(self, root, pid_embeddings, pids, layer): 92 | if len(pids) < self.leaf_factor: 93 | root.isleaf = True 94 | root.pids = pids 95 | self.leaf_dict[root.val] = root 96 | return root 97 | 98 | kmeans = self._k_means_clustering(pid_embeddings) 99 | clusters_embeddings = kmeans.cluster_centers_ 100 | labels = kmeans.labels_ 101 | for i in range(self.blance_factor): ## self.blance_factor < 10 102 | val = root.val + str(i) 103 | node = TreeNode(x = val, item_embedding=clusters_embeddings[i],layer=layer+1) 104 | node.parent = root 105 | index = np.where(labels == i)[0] 106 | pid_embedding = pid_embeddings[index] 107 | pid = pids[index] 108 | node = self._build_ten_tree(node, pid_embedding, pid, layer+1) 109 | root.add(node) 110 | return root 111 | 112 | def clustering_tree(self): 113 | root = TreeNode('0') 114 | self.root = self._build_ten_tree(root, self.embeddings, self.pids, layer = 0) 115 | return self.root 116 | 117 | 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | type = "passage" 123 | max_pid = 1000 124 | pass_embedding_dir = f'passages.memmap' 125 | 126 | 127 | ## build tree 128 | output_path = f"../tree/{type}/cluster_tree" 129 | tree_path = f"{output_path}/tree.pkl" 130 | dict_label = {} 131 | pid_embeddings_all = np.memmap(pass_embedding_dir,dtype=np.float32,mode="r").reshape(-1,768) 132 | pids_all = [x for x in range(pid_embeddings_all.shape[0])] 133 | pids_all = np.array(pids_all) 134 | tree = TreeInitialize(pid_embeddings_all, pids_all) 135 | _ = tree.clustering_tree() 136 | save_object(tree,tree_path) 137 | 138 | 139 | ## save node_dict 140 | tree = load_object(tree_path) 141 | node_dict = {} 142 | node_queue = [tree.root] 143 | val = [] 144 | while node_queue: 145 | current_node = node_queue.pop(0) 146 | node_dict[current_node.val] = current_node 147 | for child in current_node.children: 148 | node_queue.append(child) 149 | print("node dict length") 150 | print(len(node_dict)) 151 | print("leaf dict length") 152 | print(len(tree.leaf_dict)) 153 | save_object(node_dict,f"{output_path}/node_dict.pkl") 154 | 155 | ## save node_list 156 | tree = load_object(tree_path) 157 | root = tree.root 158 | node_list = _node_list(root) 159 | save_object(node_list,f"{output_path}/node_list.pkl") 160 | 161 | 162 | ## pid2cluster 163 | for leaf in tree.leaf_dict: 164 | node = tree.leaf_dict[leaf] 165 | pids = node.pids 166 | for pid in pids: 167 | dict_label[pid] = str(node.val) 168 | df = pd.DataFrame.from_dict(dict_label, orient='index',columns=['labels']) 169 | df = df.reset_index().rename(columns = {'index':'pid'}) 170 | df.to_csv(f"{output_path}/pid_labelid.memmap",header=False, index=False) 171 | 172 | print('end') 173 | tree = 
load_object('tree.pkl') 174 | print(len(tree.leaf_dict)) 175 | save_object(tree.leaf_dict,'leaf_dict.pkl') 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path += ["./"] 3 | import os 4 | import math 5 | import json 6 | import torch 7 | import pickle 8 | import random 9 | import logging 10 | import numpy as np 11 | from tqdm import tqdm 12 | from torch import nn 13 | from collections import defaultdict 14 | from torch.utils.data import Dataset 15 | from typing import List 16 | import transformers 17 | if int(transformers.__version__[0]) <=3: 18 | from transformers.modeling_roberta import RobertaPreTrainedModel 19 | else: 20 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel 21 | from transformers import RobertaModel 22 | import torch.nn.functional as F 23 | from torch.cuda.amp import autocast 24 | 25 | 26 | class SequenceDataset(Dataset): 27 | def __init__(self, ids_cache, max_seq_length): 28 | self.ids_cache = ids_cache 29 | self.max_seq_length = max_seq_length 30 | 31 | def __len__(self): 32 | return len(self.ids_cache) 33 | 34 | def __getitem__(self, item): 35 | input_ids = self.ids_cache[item].tolist() 36 | seq_length = min(self.max_seq_length-1, len(input_ids)-1) 37 | input_ids = [input_ids[0]] + input_ids[1:seq_length] + [input_ids[-1]] 38 | attention_mask = [1]*len(input_ids) 39 | 40 | ret_val = { 41 | "input_ids": input_ids, 42 | "attention_mask": attention_mask, 43 | "id": item, 44 | } 45 | return ret_val 46 | 47 | 48 | class TrainQueryDataset(SequenceDataset): 49 | def __init__(self, queryids_cache, 50 | rel_file, max_query_length): 51 | SequenceDataset.__init__(self, queryids_cache, max_query_length) 52 | self.reldict = load_rel(rel_file) 53 | 54 | def __getitem__(self, item): 55 | ret_val = super().__getitem__(item) 56 | ret_val['rel_ids'] = self.reldict[item] 57 | return ret_val 58 | 59 | class TextTokenIdsCache: 60 | def __init__(self, data_dir, prefix): 61 | meta = json.load(open(f"{data_dir}/{prefix}_meta")) 62 | self.total_number = meta['total_number'] 63 | self.max_seq_len = meta['embedding_size'] 64 | try: 65 | self.ids_arr = np.memmap(f"{data_dir}/{prefix}.memmap", 66 | shape=(self.total_number, self.max_seq_len), 67 | dtype=np.dtype(meta['type']), mode="r") 68 | self.lengths_arr = np.load(f"{data_dir}/{prefix}_length.npy") 69 | except FileNotFoundError: 70 | self.ids_arr = np.memmap(f"{data_dir}/memmap/{prefix}.memmap", 71 | shape=(self.total_number, self.max_seq_len), 72 | dtype=np.dtype(meta['type']), mode="r") 73 | self.lengths_arr = np.load(f"{data_dir}/memmap/{prefix}_length.npy") 74 | assert len(self.lengths_arr) == self.total_number 75 | 76 | def __len__(self): 77 | return self.total_number 78 | 79 | def __getitem__(self, item): 80 | return self.ids_arr[item, :self.lengths_arr[item]] 81 | 82 | 83 | 84 | 85 | 86 | class SubsetSeqDataset: 87 | def __init__(self, subset: List[int], ids_cache, max_seq_length): 88 | self.subset = sorted(list(subset)) 89 | self.alldataset = SequenceDataset(ids_cache, max_seq_length) 90 | 91 | def __len__(self): 92 | return len(self.subset) 93 | 94 | def __getitem__(self, item): 95 | return self.alldataset[self.subset[item]] 96 | 97 | 98 | def load_rel(rel_path): 99 | reldict = defaultdict(list) 100 | for line in tqdm(open(rel_path), desc=os.path.split(rel_path)[1]): 101 | qid, _, pid, _ = 
line.split() 102 | qid, pid = int(qid), int(pid) 103 | reldict[qid].append((pid)) 104 | return dict(reldict) 105 | 106 | 107 | def load_rank(rank_path): 108 | rankdict = defaultdict(list) 109 | for line in tqdm(open(rank_path), desc=os.path.split(rank_path)[1]): 110 | qid, pid, _ = line.split() 111 | qid, pid = int(qid), int(pid) 112 | rankdict[qid].append(pid) 113 | return dict(rankdict) 114 | 115 | 116 | def pack_tensor_2D(lstlst, default, dtype, length=None): 117 | batch_size = len(lstlst) 118 | length = length if length is not None else max(len(l) for l in lstlst) 119 | tensor = default * torch.ones((batch_size, length), dtype=dtype) 120 | for i, l in enumerate(lstlst): 121 | tensor[i, :len(l)] = torch.tensor(l, dtype=dtype) 122 | return tensor 123 | 124 | 125 | def get_collate_function(max_seq_length): 126 | cnt = 0 127 | def collate_function(batch): 128 | nonlocal cnt 129 | length = None 130 | if cnt < 10: 131 | length = max_seq_length 132 | cnt += 1 133 | 134 | input_ids = [x["input_ids"] for x in batch] 135 | attention_mask = [x["attention_mask"] for x in batch] 136 | data = { 137 | "input_ids": pack_tensor_2D(input_ids, default=1, 138 | dtype=torch.int64, length=length), 139 | "attention_mask": pack_tensor_2D(attention_mask, default=0, 140 | dtype=torch.int64, length=length), 141 | } 142 | ids = [x['id'] for x in batch] 143 | return data, ids 144 | return collate_function 145 | 146 | 147 | 148 | class TrainInbatchDataset(Dataset): 149 | def __init__(self, rel_file, queryids_cache, docids_cache, 150 | max_query_length, max_doc_length): 151 | self.query_dataset = SequenceDataset(queryids_cache, max_query_length) 152 | self.doc_dataset = SequenceDataset(docids_cache, max_doc_length) 153 | self.reldict = load_rel(rel_file) 154 | self.qids = sorted(list(self.reldict.keys())) 155 | 156 | def __len__(self): 157 | return len(self.qids) 158 | 159 | def __getitem__(self, item): 160 | qid = self.qids[item] 161 | pid = random.choice(self.reldict[qid]) 162 | query_data = self.query_dataset[qid] 163 | passage_data = self.doc_dataset[pid] 164 | return query_data, passage_data 165 | 166 | 167 | class TrainInbatchWithHardDataset(TrainInbatchDataset): 168 | def __init__(self, rel_file, rank_file, queryids_cache, 169 | docids_cache, hard_num, 170 | max_query_length, max_doc_length): 171 | TrainInbatchDataset.__init__(self, 172 | rel_file, queryids_cache, docids_cache, 173 | max_query_length, max_doc_length) 174 | self.rankdict = json.load(open(rank_file)) 175 | assert hard_num > 0 176 | self.hard_num = hard_num 177 | 178 | def __len__(self): 179 | return len(self.qids) 180 | 181 | def __getitem__(self, item): 182 | qid = self.qids[item] 183 | pid = random.choice(self.reldict[qid]) 184 | query_data = self.query_dataset[qid] 185 | passage_data = self.doc_dataset[pid] 186 | hardpids = random.sample(self.rankdict[str(qid)], self.hard_num) 187 | hard_passage_data = [self.doc_dataset[hardpid] for hardpid in hardpids] 188 | return query_data, passage_data, hard_passage_data 189 | 190 | 191 | class TrainInbatchWithRandDataset(TrainInbatchDataset): 192 | def __init__(self, rel_file, queryids_cache, 193 | docids_cache, rand_num, 194 | max_query_length, max_doc_length): 195 | TrainInbatchDataset.__init__(self, 196 | rel_file, queryids_cache, docids_cache, 197 | max_query_length, max_doc_length) 198 | assert rand_num > 0 199 | self.rand_num = rand_num 200 | 201 | def __getitem__(self, item): 202 | qid = self.qids[item] 203 | pid = random.choice(self.reldict[qid]) 204 | query_data = self.query_dataset[qid] 205 | 
passage_data = self.doc_dataset[pid] 206 | randpids = random.sample(range(len(self.doc_dataset)), self.rand_num) 207 | rand_passage_data = [self.doc_dataset[randpid] for randpid in randpids] 208 | return query_data, passage_data, rand_passage_data 209 | 210 | 211 | def single_get_collate_function(max_seq_length, padding=False): 212 | cnt = 0 213 | def collate_function(batch): 214 | nonlocal cnt 215 | length = None 216 | if cnt < 10 or padding: 217 | length = max_seq_length 218 | cnt += 1 219 | 220 | input_ids = [x["input_ids"] for x in batch] 221 | attention_mask = [x["attention_mask"] for x in batch] 222 | data = { 223 | "input_ids": pack_tensor_2D(input_ids, default=1, 224 | dtype=torch.int64, length=length), 225 | "attention_mask": pack_tensor_2D(attention_mask, default=0, 226 | dtype=torch.int64, length=length), 227 | } 228 | ids = [x['id'] for x in batch] 229 | return data, ids 230 | return collate_function 231 | 232 | 233 | def dual_get_collate_function(max_query_length, max_doc_length, rel_dict, padding=False): 234 | query_collate_func = single_get_collate_function(max_query_length, padding) 235 | doc_collate_func = single_get_collate_function(max_doc_length, padding) 236 | 237 | def collate_function(batch): 238 | query_data, query_ids = query_collate_func([x[0] for x in batch]) 239 | doc_data, doc_ids = doc_collate_func([x[1] for x in batch]) 240 | rel_pair_mask = [[1 if docid not in rel_dict[qid] else 0 241 | for docid in doc_ids] 242 | for qid in query_ids] 243 | input_data = { 244 | "input_query_ids":query_data['input_ids'], 245 | "query_attention_mask":query_data['attention_mask'], 246 | "input_doc_ids":doc_data['input_ids'], 247 | "doc_attention_mask":doc_data['attention_mask'], 248 | "rel_pair_mask":torch.FloatTensor(rel_pair_mask), 249 | } 250 | return input_data 251 | return collate_function 252 | 253 | 254 | def triple_get_collate_function(max_query_length, max_doc_length, rel_dict, padding=False): 255 | query_collate_func = single_get_collate_function(max_query_length, padding) 256 | doc_collate_func = single_get_collate_function(max_doc_length, padding) 257 | 258 | def collate_function(batch): 259 | query_data, query_ids = query_collate_func([x[0] for x in batch]) 260 | doc_data, doc_ids = doc_collate_func([x[1] for x in batch]) 261 | hard_doc_data, hard_doc_ids = doc_collate_func(sum([x[2] for x in batch], [])) 262 | rel_pair_mask = [[1 if docid not in rel_dict[qid] else 0 263 | for docid in doc_ids] 264 | for qid in query_ids] 265 | hard_pair_mask = [[1 if docid not in rel_dict[qid] else 0 266 | for docid in hard_doc_ids ] 267 | for qid in query_ids] 268 | query_num = len(query_data['input_ids']) 269 | hard_num_per_query = len(batch[0][2]) 270 | input_data = { 271 | "input_query_ids":query_data['input_ids'], 272 | "query_attention_mask":query_data['attention_mask'], 273 | "input_doc_ids":doc_data['input_ids'], 274 | "doc_attention_mask":doc_data['attention_mask'], 275 | "other_doc_ids":hard_doc_data['input_ids'].reshape(query_num, hard_num_per_query, -1), 276 | "other_doc_attention_mask":hard_doc_data['attention_mask'].reshape(query_num, hard_num_per_query, -1), 277 | "rel_pair_mask":torch.FloatTensor(rel_pair_mask), 278 | "hard_pair_mask":torch.FloatTensor(hard_pair_mask), 279 | } 280 | return input_data 281 | return collate_function 282 | 283 | 284 | class EmbeddingMixin: 285 | """ 286 | Mixin for common functions in most embedding models. Each model should define its own bert-like backbone and forward. 
287 | We inherit from RobertaModel to use from_pretrained 288 | """ 289 | def __init__(self, model_argobj): 290 | if model_argobj is None: 291 | self.use_mean = False 292 | else: 293 | self.use_mean = model_argobj.use_mean 294 | print("Using mean:", self.use_mean) 295 | 296 | def _init_weights(self, module): 297 | """ Initialize the weights """ 298 | if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)): #判断是否是一个已知类型 299 | # Slightly different from the TF version which uses truncated_normal for initialization 300 | # cf https://github.com/pytorch/pytorch/pull/5617 301 | module.weight.data.normal_(mean=0.0, std=0.02) 302 | 303 | def masked_mean(self, t, mask): 304 | s = torch.sum(t * mask.unsqueeze(-1).float(), axis=1) 305 | d = mask.sum(axis=1, keepdim=True).float() 306 | return s / d 307 | 308 | def masked_mean_or_first(self, emb_all, mask): 309 | # emb_all is a tuple from bert - sequence output, pooler 310 | assert isinstance(emb_all, tuple) 311 | if self.use_mean: 312 | return self.masked_mean(emb_all[0], mask) 313 | else: 314 | return emb_all[0][:, 0] 315 | 316 | def query_emb(self, input_ids, attention_mask): 317 | raise NotImplementedError("Please Implement this method") 318 | 319 | def body_emb(self, input_ids, attention_mask): 320 | raise NotImplementedError("Please Implement this method") 321 | 322 | 323 | class BaseModelDot(EmbeddingMixin): 324 | def _text_encode(self, input_ids, attention_mask): 325 | # TODO should raise NotImplementedError 326 | # temporarily do this 327 | return None 328 | 329 | def query_emb(self, input_ids, attention_mask): 330 | outputs1 = self._text_encode(input_ids=input_ids, 331 | attention_mask=attention_mask) 332 | full_emb = self.masked_mean_or_first(outputs1, attention_mask) 333 | query1 = self.norm(self.embeddingHead(full_emb)) 334 | return query1 335 | 336 | def body_emb(self, input_ids, attention_mask): 337 | return self.query_emb(input_ids, attention_mask) 338 | 339 | def forward(self, input_ids, attention_mask, is_query, *args): 340 | assert len(args) == 0 341 | if is_query: 342 | return self.query_emb(input_ids, attention_mask) 343 | else: 344 | return self.body_emb(input_ids, attention_mask) 345 | 346 | 347 | class RobertaDot(BaseModelDot, RobertaPreTrainedModel): 348 | def __init__(self, config, model_argobj=None): 349 | BaseModelDot.__init__(self, model_argobj) 350 | RobertaPreTrainedModel.__init__(self, config) 351 | if int(transformers.__version__[0]) ==4 : 352 | config.return_dict = False 353 | self.roberta = RobertaModel(config, add_pooling_layer=False) 354 | if hasattr(config, "output_embedding_size"): 355 | self.output_embedding_size = config.output_embedding_size 356 | else: 357 | self.output_embedding_size = config.hidden_size 358 | print("output_embedding_size", self.output_embedding_size) 359 | self.embeddingHead = nn.Linear(config.hidden_size, self.output_embedding_size) 360 | self.norm = nn.LayerNorm(self.output_embedding_size) 361 | self.apply(self._init_weights) 362 | 363 | def _text_encode(self, input_ids, attention_mask): 364 | outputs1 = self.roberta(input_ids=input_ids, 365 | attention_mask=attention_mask) 366 | return outputs1 367 | 368 | 369 | class RobertaDot_InBatch(RobertaDot): 370 | def forward(self, input_query_ids, query_attention_mask, 371 | input_doc_ids, doc_attention_mask, 372 | other_doc_ids=None, other_doc_attention_mask=None, 373 | rel_pair_mask=None, hard_pair_mask=None): 374 | return inbatch_train(self.query_emb, self.body_emb, 375 | input_query_ids, query_attention_mask, 376 | input_doc_ids, 
doc_attention_mask, 377 | other_doc_ids, other_doc_attention_mask, 378 | rel_pair_mask, hard_pair_mask) 379 | 380 | 381 | class RobertaDot_Rand(RobertaDot): 382 | def forward(self, input_query_ids, query_attention_mask, 383 | input_doc_ids, doc_attention_mask, 384 | other_doc_ids=None, other_doc_attention_mask=None, 385 | rel_pair_mask=None, hard_pair_mask=None): 386 | return randneg_train(self.query_emb, self.body_emb, 387 | input_query_ids, query_attention_mask, 388 | input_doc_ids, doc_attention_mask, 389 | other_doc_ids, other_doc_attention_mask, 390 | hard_pair_mask) 391 | 392 | 393 | def inbatch_train(query_encode_func, doc_encode_func, 394 | input_query_ids, query_attention_mask, 395 | input_doc_ids, doc_attention_mask, 396 | other_doc_ids=None, other_doc_attention_mask=None, 397 | rel_pair_mask=None, hard_pair_mask=None): 398 | 399 | query_embs = query_encode_func(input_query_ids, query_attention_mask) 400 | doc_embs = doc_encode_func(input_doc_ids, doc_attention_mask) 401 | 402 | batch_size = query_embs.shape[0] 403 | with autocast(enabled=False): 404 | batch_scores = torch.matmul(query_embs, doc_embs.T) 405 | # print("batch_scores", batch_scores) 406 | single_positive_scores = torch.diagonal(batch_scores, 0) 407 | # print("positive_scores", positive_scores) 408 | positive_scores = single_positive_scores.reshape(-1, 1).repeat(1, batch_size).reshape(-1) 409 | if rel_pair_mask is None: 410 | rel_pair_mask = 1 - torch.eye(batch_size, dtype=batch_scores.dtype, device=batch_scores.device) 411 | # print("mask", mask) 412 | batch_scores = batch_scores.reshape(-1) 413 | logit_matrix = torch.cat([positive_scores.unsqueeze(1), 414 | batch_scores.unsqueeze(1)], dim=1) 415 | # print(logit_matrix) 416 | lsm = F.log_softmax(logit_matrix, dim=1) 417 | loss = -1.0 * lsm[:, 0] * rel_pair_mask.reshape(-1) 418 | # print(loss) 419 | # print("\n") 420 | first_loss, first_num = loss.sum(), rel_pair_mask.sum() 421 | 422 | if other_doc_ids is None: 423 | return (first_loss/first_num,) 424 | 425 | # other_doc_ids: batch size, per query doc, length 426 | other_doc_num = other_doc_ids.shape[0] * other_doc_ids.shape[1] 427 | other_doc_ids = other_doc_ids.reshape(other_doc_num, -1) 428 | other_doc_attention_mask = other_doc_attention_mask.reshape(other_doc_num, -1) 429 | other_doc_embs = doc_encode_func(other_doc_ids, other_doc_attention_mask) 430 | 431 | with autocast(enabled=False): 432 | other_batch_scores = torch.matmul(query_embs, other_doc_embs.T) 433 | other_batch_scores = other_batch_scores.reshape(-1) 434 | positive_scores = single_positive_scores.reshape(-1, 1).repeat(1, other_doc_num).reshape(-1) 435 | other_logit_matrix = torch.cat([positive_scores.unsqueeze(1), 436 | other_batch_scores.unsqueeze(1)], dim=1) 437 | # print(logit_matrix) 438 | other_lsm = F.log_softmax(other_logit_matrix, dim=1) 439 | other_loss = -1.0 * other_lsm[:, 0] 440 | # print(loss) 441 | # print("\n") 442 | if hard_pair_mask is not None: 443 | hard_pair_mask = hard_pair_mask.reshape(-1) 444 | other_loss = other_loss * hard_pair_mask 445 | second_loss, second_num = other_loss.sum(), hard_pair_mask.sum() 446 | else: 447 | second_loss, second_num = other_loss.sum(), len(other_loss) 448 | 449 | return ((first_loss+second_loss)/(first_num+second_num),) 450 | 451 | 452 | def randneg_train(query_encode_func, doc_encode_func, 453 | input_query_ids, query_attention_mask, 454 | input_doc_ids, doc_attention_mask, 455 | other_doc_ids=None, other_doc_attention_mask=None, 456 | hard_pair_mask=None): 457 | 458 | query_embs = 
query_encode_func(input_query_ids, query_attention_mask) 459 | doc_embs = doc_encode_func(input_doc_ids, doc_attention_mask) 460 | 461 | with autocast(enabled=False): 462 | batch_scores = torch.matmul(query_embs, doc_embs.T) 463 | single_positive_scores = torch.diagonal(batch_scores, 0) 464 | # other_doc_ids: batch size, per query doc, length 465 | other_doc_num = other_doc_ids.shape[0] * other_doc_ids.shape[1] 466 | other_doc_ids = other_doc_ids.reshape(other_doc_num, -1) 467 | other_doc_attention_mask = other_doc_attention_mask.reshape(other_doc_num, -1) 468 | other_doc_embs = doc_encode_func(other_doc_ids, other_doc_attention_mask) 469 | 470 | with autocast(enabled=False): 471 | other_batch_scores = torch.matmul(query_embs, other_doc_embs.T) 472 | other_batch_scores = other_batch_scores.reshape(-1) 473 | positive_scores = single_positive_scores.reshape(-1, 1).repeat(1, other_doc_num).reshape(-1) 474 | other_logit_matrix = torch.cat([positive_scores.unsqueeze(1), 475 | other_batch_scores.unsqueeze(1)], dim=1) 476 | # print(logit_matrix) 477 | other_lsm = F.log_softmax(other_logit_matrix, dim=1) 478 | other_loss = -1.0 * other_lsm[:, 0] 479 | if hard_pair_mask is not None: 480 | hard_pair_mask = hard_pair_mask.reshape(-1) 481 | other_loss = other_loss * hard_pair_mask 482 | second_loss, second_num = other_loss.sum(), hard_pair_mask.sum() 483 | else: 484 | second_loss, second_num = other_loss.sum(), len(other_loss) 485 | return (second_loss/second_num,) -------------------------------------------------------------------------------- /figure/overflow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSHaitao/JTR/da4074a664457595f84733bc0f664752493e02e8/figure/overflow.pdf -------------------------------------------------------------------------------- /figure/overflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSHaitao/JTR/da4074a664457595f84733bc0f664752493e02e8/figure/overflow.png -------------------------------------------------------------------------------- /figure/tree.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSHaitao/JTR/da4074a664457595f84733bc0f664752493e02e8/figure/tree.pdf -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from asyncio.format_helpers import _format_callback_source 2 | import enum 3 | import sys 4 | import random 5 | from grpc import compute_engine_channel_credentials 6 | 7 | from torch._C import dtype 8 | sys.path += ['./'] 9 | import torch 10 | from torch import nn 11 | import numpy as np 12 | import transformers 13 | if int(transformers.__version__[0]) <=3: 14 | from transformers.modeling_roberta import RobertaPreTrainedModel 15 | else: 16 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel 17 | from transformers import RobertaModel 18 | import torch.nn.functional as F 19 | from torch.cuda.amp import autocast 20 | 21 | 22 | class embedding_model(nn.Module): 23 | def __init__(self,node_embeddings, label_offest, device): 24 | super().__init__() 25 | self.embedding = nn.Embedding.from_pretrained(node_embeddings,freeze = False).to(device) 26 | self.label_offest = label_offest 27 | self.device = device 28 | self.loss_fct = nn.CrossEntropyLoss().to(device) 29 | 30 | def forward(self, 
query_embeddings, all_rel_label, node_list, node_dict,label_offest,layer,all_node): 31 | batch=query_embeddings.shape[0] 32 | scores = None 33 | for i in range(batch): 34 | features = None 35 | query_embedding = query_embeddings[i].reshape(-1,768) 36 | rel_label = all_rel_label[i] 37 | offest = label_offest[rel_label] 38 | node_embedding = self.embedding(offest).reshape(-1,768) 39 | child_embeddings = node_embedding 40 | 41 | negetive = [] 42 | node_father = node_dict[rel_label].parent 43 | for node_bro in node_father.children: 44 | negetive.append(node_bro.val) 45 | num = len(rel_label) 46 | tmp_node = node_list[num-1] 47 | if rel_label in negetive: 48 | negetive.remove(rel_label) 49 | index_list = random.sample(range(len(tmp_node)), all_node-1) 50 | 51 | for i in index_list: 52 | negetive.append(tmp_node[i]) 53 | if rel_label in negetive: 54 | negetive.remove(rel_label) 55 | 56 | if len(negetive)>all_node-1: 57 | negetive = negetive[:all_node-1] 58 | 59 | 60 | for index in negetive: 61 | offest = label_offest[index] 62 | child_embedding = self.embedding(offest).reshape(-1,768) 63 | 64 | if child_embeddings is None: 65 | child_embeddings = child_embedding 66 | else: 67 | child_embeddings = torch.cat([child_embeddings, child_embedding], dim=0) 68 | 69 | features = torch.matmul(query_embedding, child_embeddings.T) 70 | 71 | score = features.reshape(1,-1) 72 | if(score.shape[1]!= all_node): 73 | num_len = all_node-score.shape[1] 74 | other = torch.zeros(1,num_len).to(self.device) 75 | score = torch.cat((score,other),1) 76 | 77 | 78 | if scores is None: 79 | scores = score 80 | else: 81 | scores = torch.cat([scores, score], dim=0) 82 | 83 | 84 | return scores 85 | 86 | def get_embedding(self, node_dict, node_list, layer): 87 | tmp_node = node_list[layer] 88 | for index in tmp_node: 89 | offest = self.label_offest[index] 90 | node_embedding = np.array(self.embedding(offest).cpu().detach()) 91 | node_dict[index].embedding = node_embedding 92 | return node_dict 93 | 94 | def predict(self, query_embeddings, node): 95 | query_embedding = query_embeddings.reshape(-1,768) 96 | offest = self.label_offest[node.val] 97 | node_embedding = self.embedding(offest) 98 | node_embeddings = torch.tensor(node_embedding).reshape(-1,768).to(self.device) 99 | score = torch.matmul(query_embedding, node_embeddings.T) 100 | return score 101 | 102 | 103 | 104 | class EmbeddingMixin: 105 | """ 106 | Mixin for common functions in most embedding models. Each model should define its own bert-like backbone and forward. 
107 | We inherit from RobertaModel to use from_pretrained 108 | """ 109 | def __init__(self, model_argobj): 110 | if model_argobj is None: 111 | self.use_mean = False 112 | else: 113 | self.use_mean = model_argobj.use_mean 114 | print("Using mean:", self.use_mean) 115 | 116 | def _init_weights(self, module): 117 | """ Initialize the weights """ 118 | if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)): 119 | # Slightly different from the TF version which uses truncated_normal for initialization 120 | # cf https://github.com/pytorch/pytorch/pull/5617 121 | module.weight.data.normal_(mean=0.0, std=0.02) 122 | 123 | def masked_mean(self, t, mask): 124 | s = torch.sum(t * mask.unsqueeze(-1).float(), axis=1) 125 | d = mask.sum(axis=1, keepdim=True).float() 126 | return s / d 127 | 128 | def masked_mean_or_first(self, emb_all, mask): 129 | # emb_all is a tuple from bert - sequence output, pooler 130 | assert isinstance(emb_all, tuple) 131 | if self.use_mean: 132 | return self.masked_mean(emb_all[0], mask) 133 | else: 134 | return emb_all[0][:, 0] 135 | 136 | def query_emb(self, input_ids, attention_mask): 137 | raise NotImplementedError("Please Implement this method") 138 | 139 | def body_emb(self, input_ids, attention_mask): 140 | raise NotImplementedError("Please Implement this method") 141 | 142 | 143 | class BaseModelDot(EmbeddingMixin): 144 | def _text_encode(self, input_ids, attention_mask): 145 | # TODO should raise NotImplementedError 146 | # temporarily do this 147 | return None 148 | 149 | def query_emb(self, input_ids, attention_mask): 150 | outputs1 = self._text_encode(input_ids=input_ids, 151 | attention_mask=attention_mask) 152 | full_emb = self.masked_mean_or_first(outputs1, attention_mask) 153 | query1 = self.norm(self.embeddingHead(full_emb)) 154 | return query1 155 | 156 | def body_emb(self, input_ids, attention_mask): 157 | return self.query_emb(input_ids, attention_mask) 158 | 159 | def forward(self, input_ids, attention_mask, is_query, *args): 160 | assert len(args) == 0 161 | if is_query: 162 | return self.query_emb(input_ids, attention_mask) 163 | else: 164 | return self.body_emb(input_ids, attention_mask) 165 | 166 | 167 | class RobertaDot(BaseModelDot, RobertaPreTrainedModel): 168 | def __init__(self, config, model_argobj=None): 169 | BaseModelDot.__init__(self, model_argobj) 170 | RobertaPreTrainedModel.__init__(self, config) 171 | if int(transformers.__version__[0]) ==4 : 172 | config.return_dict = False 173 | self.roberta = RobertaModel(config, add_pooling_layer=False) 174 | if hasattr(config, "output_embedding_size"): 175 | self.output_embedding_size = config.output_embedding_size 176 | else: 177 | self.output_embedding_size = config.hidden_size 178 | print("output_embedding_size", self.output_embedding_size) 179 | self.embeddingHead = nn.Linear(config.hidden_size, self.output_embedding_size) 180 | self.norm = nn.LayerNorm(self.output_embedding_size) 181 | self.apply(self._init_weights) 182 | 183 | def _text_encode(self, input_ids, attention_mask): 184 | outputs1 = self.roberta(input_ids=input_ids, 185 | attention_mask=attention_mask) 186 | return outputs1 187 | 188 | 189 | 190 | class RobertaDot_4(BaseModelDot, RobertaPreTrainedModel): 191 | def __init__(self, config, model_argobj=None): 192 | BaseModelDot.__init__(self, model_argobj) 193 | RobertaPreTrainedModel.__init__(self, config) 194 | if int(transformers.__version__[0]) ==4 : 195 | config.return_dict = False 196 | self.roberta = RobertaModel(config, add_pooling_layer=False) 197 | if hasattr(config, 
"output_embedding_size"): 198 | self.output_embedding_size = config.output_embedding_size*4 199 | else: 200 | self.output_embedding_size = config.hidden_size*4 201 | print("output_embedding_size", self.output_embedding_size) 202 | self.embeddingHead = nn.Linear(config.hidden_size, self.output_embedding_size) 203 | self.norm = nn.LayerNorm(self.output_embedding_size) 204 | self.apply(self._init_weights) 205 | 206 | def _text_encode(self, input_ids, attention_mask): 207 | outputs1 = self.roberta(input_ids=input_ids, 208 | attention_mask=attention_mask) 209 | return outputs1 210 | 211 | 212 | 213 | class RobertaDot_InBatch(RobertaDot): 214 | def forward(self, input_query_ids, query_attention_mask, 215 | input_doc_ids, doc_attention_mask, 216 | other_doc_ids=None, other_doc_attention_mask=None, 217 | rel_pair_mask=None, hard_pair_mask=None): 218 | return inbatch_train(self.query_emb, self.body_emb, 219 | input_query_ids, query_attention_mask, 220 | input_doc_ids, doc_attention_mask, 221 | other_doc_ids, other_doc_attention_mask, 222 | rel_pair_mask, hard_pair_mask) 223 | 224 | 225 | class RobertaDot_Rand(RobertaDot): 226 | def forward(self, input_query_ids, query_attention_mask, 227 | input_doc_ids, doc_attention_mask, 228 | other_doc_ids=None, other_doc_attention_mask=None, 229 | rel_pair_mask=None, hard_pair_mask=None): 230 | return randneg_train(self.query_emb, self.body_emb, 231 | input_query_ids, query_attention_mask, 232 | input_doc_ids, doc_attention_mask, 233 | other_doc_ids, other_doc_attention_mask, 234 | hard_pair_mask) 235 | 236 | 237 | def inbatch_train(query_encode_func, doc_encode_func, 238 | input_query_ids, query_attention_mask, 239 | input_doc_ids, doc_attention_mask, 240 | other_doc_ids=None, other_doc_attention_mask=None, 241 | rel_pair_mask=None, hard_pair_mask=None): 242 | 243 | query_embs = query_encode_func(input_query_ids, query_attention_mask) 244 | doc_embs = doc_encode_func(input_doc_ids, doc_attention_mask) 245 | 246 | batch_size = query_embs.shape[0] 247 | with autocast(enabled=False): 248 | batch_scores = torch.matmul(query_embs, doc_embs.T) 249 | 250 | single_positive_scores = torch.diagonal(batch_scores, 0) 251 | 252 | positive_scores = single_positive_scores.reshape(-1, 1).repeat(1, batch_size).reshape(-1) 253 | if rel_pair_mask is None: 254 | rel_pair_mask = 1 - torch.eye(batch_size, dtype=batch_scores.dtype, device=batch_scores.device) 255 | 256 | batch_scores = batch_scores.reshape(-1) 257 | logit_matrix = torch.cat([positive_scores.unsqueeze(1), 258 | batch_scores.unsqueeze(1)], dim=1) 259 | 260 | lsm = F.log_softmax(logit_matrix, dim=1) 261 | loss = -1.0 * lsm[:, 0] * rel_pair_mask.reshape(-1) 262 | 263 | first_loss, first_num = loss.sum(), rel_pair_mask.sum() 264 | 265 | if other_doc_ids is None: 266 | return (first_loss/first_num,) 267 | 268 | # other_doc_ids: batch size, per query doc, length 269 | other_doc_num = other_doc_ids.shape[0] * other_doc_ids.shape[1] 270 | other_doc_ids = other_doc_ids.reshape(other_doc_num, -1) 271 | other_doc_attention_mask = other_doc_attention_mask.reshape(other_doc_num, -1) 272 | other_doc_embs = doc_encode_func(other_doc_ids, other_doc_attention_mask) 273 | 274 | with autocast(enabled=False): 275 | other_batch_scores = torch.matmul(query_embs, other_doc_embs.T) 276 | other_batch_scores = other_batch_scores.reshape(-1) 277 | positive_scores = single_positive_scores.reshape(-1, 1).repeat(1, other_doc_num).reshape(-1) 278 | other_logit_matrix = torch.cat([positive_scores.unsqueeze(1), 279 | 
other_batch_scores.unsqueeze(1)], dim=1) 280 | 281 | other_lsm = F.log_softmax(other_logit_matrix, dim=1) 282 | other_loss = -1.0 * other_lsm[:, 0] 283 | 284 | if hard_pair_mask is not None: 285 | hard_pair_mask = hard_pair_mask.reshape(-1) 286 | other_loss = other_loss * hard_pair_mask 287 | second_loss, second_num = other_loss.sum(), hard_pair_mask.sum() 288 | else: 289 | second_loss, second_num = other_loss.sum(), len(other_loss) 290 | 291 | return ((first_loss+second_loss)/(first_num+second_num),) 292 | 293 | 294 | def randneg_train(query_encode_func, doc_encode_func, 295 | input_query_ids, query_attention_mask, 296 | input_doc_ids, doc_attention_mask, 297 | other_doc_ids=None, other_doc_attention_mask=None, 298 | hard_pair_mask=None): 299 | 300 | query_embs = query_encode_func(input_query_ids, query_attention_mask) 301 | doc_embs = doc_encode_func(input_doc_ids, doc_attention_mask) 302 | 303 | with autocast(enabled=False): 304 | batch_scores = torch.matmul(query_embs, doc_embs.T) 305 | single_positive_scores = torch.diagonal(batch_scores, 0) 306 | # other_doc_ids: batch size, per query doc, length 307 | other_doc_num = other_doc_ids.shape[0] * other_doc_ids.shape[1] 308 | other_doc_ids = other_doc_ids.reshape(other_doc_num, -1) 309 | other_doc_attention_mask = other_doc_attention_mask.reshape(other_doc_num, -1) 310 | other_doc_embs = doc_encode_func(other_doc_ids, other_doc_attention_mask) 311 | 312 | with autocast(enabled=False): 313 | other_batch_scores = torch.matmul(query_embs, other_doc_embs.T) 314 | other_batch_scores = other_batch_scores.reshape(-1) 315 | positive_scores = single_positive_scores.reshape(-1, 1).repeat(1, other_doc_num).reshape(-1) 316 | other_logit_matrix = torch.cat([positive_scores.unsqueeze(1), 317 | other_batch_scores.unsqueeze(1)], dim=1) 318 | # print(logit_matrix) 319 | other_lsm = F.log_softmax(other_logit_matrix, dim=1) 320 | other_loss = -1.0 * other_lsm[:, 0] 321 | if hard_pair_mask is not None: 322 | hard_pair_mask = hard_pair_mask.reshape(-1) 323 | other_loss = other_loss * hard_pair_mask 324 | second_loss, second_num = other_loss.sum(), hard_pair_mask.sum() 325 | else: 326 | second_loss, second_num = other_loss.sum(), len(other_loss) 327 | return (second_loss/second_num,) -------------------------------------------------------------------------------- /model_search_cos.py: -------------------------------------------------------------------------------- 1 | 2 | from cProfile import label 3 | import email 4 | import sys 5 | sys.path += ["./"] 6 | import os 7 | import time 8 | import torch 9 | import random 10 | import faiss 11 | import joblib 12 | import logging 13 | import argparse 14 | import subprocess 15 | import numpy as np 16 | import pandas as pd 17 | import pickle as pkl 18 | from construct_tree import TreeInitialize,TreeNode 19 | from torch import nn 20 | from tqdm import tqdm, trange 21 | from numba import jit 22 | 23 | 24 | @jit(nopython=False) 25 | def candidates_generator(embeddings,node_dict,topk): #qid 26 | """layer-wise retrieval algorithm in prediction.""" 27 | root = node_dict['0'] 28 | Q, A = root.children, [] 29 | layer = 0 30 | embedding = embeddings.reshape(1,768).cpu().numpy() 31 | 32 | while Q: 33 | layer = layer+1 34 | B = [] 35 | for node in Q: 36 | if node.isleaf is True: #如果是叶节点 37 | A.append(node) 38 | B.append(node) 39 | for node in B: 40 | Q.remove(node) 41 | 42 | if(len(Q) == 0): 43 | break 44 | 45 | probs = [] 46 | embeddings = [] 47 | for node in Q: 48 | embeddings.append(node.embedding) 49 | 50 | embeddings 
=np.array(embeddings) 51 | 52 | probs = np.dot(embedding, embeddings.T).reshape(-1,).tolist() 53 | prob_list = list(zip(Q, probs)) 54 | prob_list = sorted(prob_list, key=lambda x: x[1], reverse=True) 55 | 56 | I = [] 57 | if len(prob_list) > topk: 58 | for i in range(topk): 59 | I.append(prob_list[i][0]) 60 | else: 61 | for p in prob_list: 62 | I.append(p[0]) 63 | 64 | 65 | Q = [] 66 | while I: 67 | node = I.pop() 68 | for child in node.children: 69 | Q.append(child) 70 | 71 | # A = [] 72 | # for i in range(topk): 73 | # A.append(prob_list[i][0].val) 74 | 75 | # return A 76 | probs = [] 77 | leaf_embeddings = [] 78 | for leaf in A: 79 | leaf_embeddings.append(leaf.embedding) 80 | leaf_embeddings =np.array(leaf_embeddings) 81 | 82 | probs = np.dot(embedding, leaf_embeddings.T).reshape(-1,).tolist() 83 | prob_list = list(zip(A, probs)) 84 | prob_list = sorted(prob_list, key=lambda x: x[1], reverse=True) 85 | A = [] 86 | for i in range(topk): 87 | A.append(prob_list[i][0].val) #pid 88 | return A 89 | 90 | @numba.jit(nopython=True) 91 | def metrics_count(embeddings,node_dict,topk): #(vtest, tree.root, 10, model 92 | rank_list = [] 93 | size = embeddings.shape[0] 94 | for i in range(size): 95 | cands = candidates_generator(embeddings,node_dict,topk) #返回的节点 96 | rank_list.append(cands) 97 | return rank_list -------------------------------------------------------------------------------- /reorganize_clusters_tree.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | from cgi import print_environ 3 | import email 4 | from json import load 5 | from lib2to3.pytree import Node 6 | import sys 7 | from tkinter import Y 8 | from traceback import print_tb 9 | sys.path += ["./"] 10 | import os 11 | import time 12 | import torch 13 | import random 14 | import faiss 15 | import joblib 16 | import logging 17 | import argparse 18 | import subprocess 19 | import numpy as np 20 | import pandas as pd 21 | from numba import njit 22 | from numba.core import types 23 | from numba.typed import Dict 24 | import scipy.sparse as smat 25 | import pickle as pkl 26 | from construct_tree import TreeInitialize,TreeNode 27 | from torch import nn 28 | from tqdm import tqdm, trange 29 | from model_search_cos import metrics_count 30 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 31 | 32 | 33 | def load_object(path): 34 | with open(path, 'rb') as f: 35 | obj = pkl.load(f) 36 | return obj 37 | 38 | def save_object(obj, path): 39 | with open(path,'wb') as f: 40 | pkl.dump(obj,f) 41 | 42 | 43 | 44 | def _define_node_emebedding(arr_node, node_dict): 45 | for key in arr_node[5]: 46 | node = node_dict[key] 47 | if node.isleaf == False: 48 | # print(node.val) 49 | embedding = [0 for _ in range(768)] 50 | num = 0 51 | for child in node.children: 52 | embedding = np.sum([embedding,child.embedding],axis=0) 53 | num += 1 54 | print(num) 55 | node.embedding = [ x/num for x in embedding] 56 | else: 57 | pass 58 | for key in arr_node[4]: 59 | node = node_dict[key] 60 | if node.isleaf == False: 61 | embedding = [0 for _ in range(768)] 62 | num = 0 63 | for child in node.children: 64 | embedding = np.sum([embedding,child.embedding],axis=0) 65 | num += 1 66 | node.embedding = [ x/num for x in embedding] 67 | for key in arr_node[3]: 68 | node = node_dict[key] 69 | if node.isleaf == False: 70 | embedding = [0 for _ in range(768)] 71 | num = 0 72 | for child in node.children: 73 | embedding = np.sum([embedding,child.embedding],axis=0) 74 | num += 1 75 | node.embedding = [ x/num for x in 
embedding] 76 | for key in arr_node[2]: 77 | node = node_dict[key] 78 | if node.isleaf == False: 79 | embedding = [0 for _ in range(768)] 80 | num = 0 81 | for child in node.children: 82 | embedding = np.sum([embedding,child.embedding],axis=0) 83 | num += 1 84 | node.embedding = [ x/num for x in embedding] 85 | for key in arr_node[1]: 86 | node = node_dict[key] 87 | if node.isleaf == False: 88 | embedding = [0 for _ in range(768)] 89 | num = 0 90 | for child in node.children: 91 | embedding = np.sum([embedding,child.embedding],axis=0) 92 | num += 1 93 | node.embedding = [ x/num for x in embedding] 94 | return node_dict 95 | 96 | 97 | 98 | 99 | 100 | 101 | def _node_list(root): 102 | def node_val(node): 103 | if(node.isleaf == False): 104 | return node.val 105 | else: 106 | return node.val 107 | 108 | node_queue = [root] 109 | arr_arr_node = [] 110 | arr_arr_node.append([node_val(node_queue[0])]) 111 | while node_queue: 112 | tmp = [] 113 | tmp_val = [] 114 | for node in node_queue: 115 | for child in node.children: 116 | tmp.append(child) 117 | tmp_val.append(node_val(child)) 118 | if len(tmp_val) > 0: 119 | arr_arr_node.append(tmp_val) 120 | node_queue = tmp 121 | return arr_arr_node 122 | 123 | @njit 124 | def construct_new_C_and_Y( 125 | counts_rows, 126 | counts_cols, 127 | counts, 128 | row_ids, 129 | row_ranges, 130 | C_rows, 131 | sort_idx, 132 | nr_labels, 133 | max_cluster_size, 134 | n_copies, 135 | ): 136 | """Determine the new clustering matrix and the new label matrix given the couting matrix. 137 | 138 | This function implements Eq.(10) in our paper. I.e. given the couting matrix C = Y^T * M, 139 | we select the correct cluster id for each label one by one, in descending order of C entries, 140 | possibly assign a label multiple times (`n_copies`) to different clusters. Finally, the new 141 | cluster and new label matrix is returned. Notice that Numba is used here, this prevents us 142 | from passing scipy sparse matrix directly. 143 | 144 | Args: 145 | counts_rows, counts_cols, counts: The counting matrix in COO format. 146 | row_ids, row_ranges: The indices and indptr of original Y matrix in CSC format. 147 | C_rows: Clustering matrix C in LIL format, converted to list of numpy arrays. 148 | sort_idx: Index of counts_{rows,cols} to sort them in decending order. 149 | nr_labels: Number of original labels. 150 | max_cluster_size: (Unused for now) Hard constraints to limit the number of labels 151 | in each cluster (to balance cluster size). 152 | n_copies: Max number of copies for each label (\lambda in our paper). 153 | 154 | Returns: 155 | New cluster matrix (`new_C_*`), new label matrix (`new_Y_*`), the replicated label 156 | assignment (`C_overlap_*`), number of duplicated labels (`nr_copied_labels`), a map 157 | from new label id to the underlying label id (`mapper`), unused labels that never 158 | show up in training (`unused_labels`), number of lightly used labels (`nr_tail_labels`). 
159 | """ 160 | # construct empty cluster matrix and label matrix 161 | nr_copied_labels = 0 162 | new_C_cols = [] 163 | new_C_rows = [] 164 | new_C_data = [] 165 | new_Y_rows = [] 166 | labels_included = set() 167 | mapper = Dict.empty(key_type=types.int64, value_type=types.int64,) 168 | cluster_size = Dict.empty(key_type=types.int64, value_type=types.int64,) 169 | pseudo_label_count = Dict.empty(key_type=types.int64, value_type=types.int64,) 170 | # results 171 | C_overlap_rows, C_overlap_cols = [], [] 172 | max_count = n_copies 173 | # adding labels to clusters one by one in descending frequency 174 | for idx in sort_idx: 175 | label_id = counts_rows[idx] 176 | leaf_id = counts_cols[idx] 177 | if label_id in pseudo_label_count and pseudo_label_count[label_id] >= max_count: 178 | continue 179 | # If you need to contrain the max cluster size, then 180 | # uncomment following two lines 181 | # if label_count[leaf_id] >= max_cluster_size: 182 | # continue 183 | if leaf_id not in cluster_size: 184 | cluster_size[leaf_id] = 1 185 | else: 186 | cluster_size[leaf_id] += 1 187 | 188 | if label_id not in pseudo_label_count: 189 | pseudo_label_count[label_id] = 1 190 | else: 191 | pseudo_label_count[label_id] += 1 192 | 193 | if label_id in labels_included: # 194 | # add a pseudo label that duplicates label_id 195 | pseudo_label_id = nr_copied_labels + nr_labels 196 | mapper[pseudo_label_id] = label_id 197 | # add one more row to C (in lil format) 198 | new_C_rows.append(nr_copied_labels) 199 | new_C_cols.append(leaf_id) 200 | new_C_data.append(1.0) 201 | # add one more column to Yt 202 | examples = row_ids[row_ranges[label_id] : row_ranges[label_id + 1]] 203 | new_Y_rows.append(examples) 204 | nr_copied_labels += 1 205 | else: 206 | # add a new label 207 | labels_included.add(label_id) 208 | C_overlap_rows.append(label_id) 209 | C_overlap_cols.append(leaf_id) 210 | 211 | # exit early if we have too many effective labels 212 | if len(mapper) >= max_count * nr_labels: 213 | break 214 | # add missing labels back to clusters 215 | nr_tail_labels = 0 216 | for label_id in range(nr_labels): 217 | if label_id not in labels_included: 218 | original_leaf_id = C_rows[label_id][0] # 219 | C_overlap_rows.append(label_id) 220 | C_overlap_cols.append(original_leaf_id) 221 | labels_included.add(label_id) 222 | nr_tail_labels += 1 223 | 224 | unused_labels = set() 225 | for label_id in range(nr_labels): 226 | if label_id not in labels_included: 227 | unused_labels.add(label_id) 228 | 229 | # new_Y elements 230 | new_Y_indptr = [0] 231 | new_Y_indices = [] 232 | for rows in new_Y_rows: 233 | new_Y_indptr.append(new_Y_indptr[-1] + len(rows)) 234 | new_Y_indices.extend(rows) 235 | new_Y_data = np.ones(len(new_Y_indices), dtype=np.int32) 236 | return ( 237 | np.array(new_C_cols), 238 | np.array(new_C_rows), 239 | np.array(new_C_data), 240 | new_Y_data, 241 | new_Y_indices, 242 | new_Y_indptr, 243 | C_overlap_cols, 244 | C_overlap_rows, 245 | nr_copied_labels, 246 | mapper, 247 | unused_labels, 248 | nr_tail_labels, 249 | ) 250 | 251 | 252 | def get_matching_matrix(rank_output, dict_value): 253 | path_to_rank = rank_output 254 | row = [] 255 | col = [] 256 | with open(path_to_rank,'r') as f: 257 | for l in f: 258 | try: 259 | l = l.strip().split('\t') 260 | qid = int(l[0]) 261 | labelid = l[1] 262 | row.append(qid) 263 | col.append(dict_value[labelid]) 264 | except: 265 | raise IOError('\"%s\" is not valid format' % l) 266 | 267 | row = np.array(row) 268 | print(row.shape) 269 | col = np.array(col) 270 | row_num = 
int(len(row)/5) 271 | data = np.ones((len(col),)) 272 | data_matrix = smat.csr_matrix((data,(row,col)),shape=(row_num,len(dict_value))) 273 | 274 | return data_matrix 275 | 276 | def get_cluster_matrix(leaf_dict,dict_value): 277 | row = [] 278 | col = [] 279 | for node in leaf_dict.values(): 280 | for pid in node.pids: 281 | row.append(pid) 282 | col.append(dict_value[node.val]) 283 | 284 | row = np.array(row) 285 | col = np.array(col) 286 | data = np.ones((len(col),)) 287 | data_matrix = smat.csr_matrix((data,(row,col)),shape=(len(row),len(dict_value))) 288 | return data_matrix 289 | 290 | def Get_New_C(args): 291 | 292 | leaf_dict = load_object(f'{args.leaf_path}') 293 | 294 | value_dict = {} 295 | dict_value = {} 296 | i = 0 297 | for node in leaf_dict.values(): 298 | value_dict[i] = node.val 299 | dict_value[node.val] = i 300 | i = i + 1 301 | 302 | ## get matric M 303 | path_to_rank = f"{args.M_path}" 304 | M = get_matching_matrix(path_to_rank,dict_value) 305 | print("M shape") 306 | 307 | print(M.shape) 308 | 309 | # get matric C 310 | C = get_cluster_matrix(leaf_dict,dict_value) 311 | 312 | print(C.shape) 313 | 314 | 315 | 316 | qrel_path = f'{args.Y_path}' 317 | row = [] 318 | col = [] 319 | data_dev = pd.read_csv(qrel_path, header=None, names=["qid","pid",'rank'], sep='\t') 320 | for qid,pid in zip(data_dev.qid,data_dev.pid): 321 | row.append(qid) 322 | col.append(pid) 323 | row = np.array(row) 324 | col = np.array(col) 325 | data = np.ones((len(col),)) 326 | Y = smat.csr_matrix((data,(row,col)),shape=(M.shape[0],C.shape[0])) 327 | print(Y.shape) 328 | counts = Y.transpose().dot(M).tocoo() 329 | counts.eliminate_zeros() 330 | counts_rows, counts_cols, counts = counts.row, counts.col, counts.data 331 | print(len(counts)) 332 | 333 | sort_idx = np.argsort(counts)[::-1] 334 | 335 | 336 | Yt_csc = Y.tocsc() 337 | row_ranges = Yt_csc.indptr 338 | row_ids = Yt_csc.indices 339 | 340 | C = C.tolil() 341 | C_rows = C.rows 342 | 343 | max_cluster_size = int(1.0 * C.shape[0] / C.shape[1]) 344 | ( 345 | new_C_cols, 346 | new_C_rows, 347 | new_C_data, 348 | new_Y_data, 349 | new_Y_indices, 350 | new_Y_indptr, 351 | C_overlap_cols, 352 | C_overlap_rows, 353 | out_labels, 354 | mapper, 355 | unused_labels, 356 | nr_tail_labels, 357 | ) = construct_new_C_and_Y( 358 | np.asarray(counts_rows, dtype=np.int32), 359 | np.asarray(counts_cols, dtype=np.int32), 360 | np.asarray(counts, dtype=np.int32), 361 | np.asarray(row_ids, dtype=np.int32), 362 | np.asarray(row_ranges, dtype=np.int32), 363 | [np.asarray(row, dtype=np.int32) for row in C_rows], 364 | sort_idx, 365 | Y.shape[1], 366 | max_cluster_size, 367 | args.overlap, 368 | ) 369 | C_overlap = smat.coo_matrix( 370 | (np.ones_like(C_overlap_cols), (C_overlap_rows, C_overlap_cols)), 371 | shape=C.shape, 372 | dtype=C.dtype, 373 | ).tocsr() 374 | 375 | print(f"#copied labels: {out_labels}, #tail labels: {nr_tail_labels}") 376 | 377 | new_C = smat.csr_matrix((new_C_data,(new_C_rows,new_C_cols)),shape=(out_labels, C.shape[1]),dtype=C.dtype) 378 | C_new = smat.vstack((C_overlap, new_C), format="csc") 379 | 380 | new_Y = smat.csc_matrix( 381 | (new_Y_data, new_Y_indices, new_Y_indptr), 382 | shape=(Y.shape[0], len(new_Y_indptr) - 1), 383 | dtype=Y.dtype, 384 | ) 385 | Y = smat.hstack((Y, new_Y), format="csr") 386 | smat.save_npz(f'{args.save_dir}/C.npz', C_new, compressed=True) 387 | smat.save_npz(f'{args.save_dir}/Y.npz', Y, compressed=True) 388 | smat.save_npz(f'{args.save_dir}/M.npz', M, compressed=True) 389 | 
save_object(dict(mapper),f'{args.save_dir}/mapper.pkl') 390 | 391 | def tree_update(args): 392 | # tree_path = f"{args.raw_tree_path}" 393 | tree = load_object(args.raw_tree_path) 394 | # tree_leaf_dict = load_object(args.leaf_path) 395 | print(len(tree.leaf_dict)) 396 | for leaf in tree.leaf_dict: 397 | node = tree.leaf_dict[leaf] 398 | node.pids = [] 399 | 400 | 401 | # leaf_dict = load_object(f'{args.leaf_path}') 402 | value_dict = {} 403 | dict_value = {} 404 | i = 0 405 | for node in tree.leaf_dict.values(): 406 | value_dict[i] = node.val 407 | dict_value[node.val] = i 408 | i = i + 1 409 | 410 | fname = f'{args.save_dir}/C.npz' 411 | C = smat.load_npz(fname).tocoo() 412 | 413 | C.eliminate_zeros() 414 | pids, cluster_ids = C.row, C.col 415 | 416 | for pid,cluster_id in zip(pids,cluster_ids): 417 | tree.leaf_dict[value_dict[cluster_id]].pids.append(pid) 418 | 419 | 420 | 421 | dict_label = {} 422 | for leaf in tree.leaf_dict: 423 | node = tree.leaf_dict[leaf] 424 | pids = node.pids 425 | 426 | for pid in pids: 427 | dict_label[pid] = str(node.val) 428 | 429 | # print("Update embedding") 430 | # for leaf in tree.leaf_dict: 431 | # node = tree.leaf_dict[leaf] 432 | # pids = node.pids 433 | # num = 0 434 | # embedding = [0 for _ in range(768)] 435 | # for pid in pids: 436 | # dict_label[pid] = str(node.val) 437 | 438 | # try: 439 | # embedding = np.sum([embedding,pid_embeddings_all[pid]],axis=0) 440 | # num = num+1 441 | # except: 442 | # pass 443 | # node.embedding = [ x/num for x in embedding] 444 | 445 | root = tree.root 446 | node_list = _node_list(root) 447 | save_object(node_list,f'{args.save_dir}/node_list.pkl') 448 | 449 | node_dict = {} 450 | node_queue = [tree.root] 451 | while node_queue: 452 | current_node = node_queue.pop(0) 453 | node_dict[current_node.val] = current_node 454 | for child in current_node.children: 455 | node_queue.append(child) 456 | 457 | 458 | save_object(tree.leaf_dict,f'{args.save_dir}/leaf_dict.pkl') 459 | df = pd.DataFrame.from_dict(dict_label, orient='index',columns=['labels']) 460 | df = df.reset_index().rename(columns = {'index':'pid'}) 461 | df.to_csv(f'{args.save_dir}/pid_labelid.memmap',header=False, index=False) 462 | save_object(node_dict,f'{args.save_dir}/node_dict.pkl') 463 | 464 | path = 'dev-qrel.tsv' 465 | output_file_path = f'{args.save_dir}/dev_label.tsv' 466 | 467 | outfile = open(output_file_path,'w') 468 | with open(path,'r') as f: 469 | for l in f: 470 | try: 471 | l = l.strip().split('\t') 472 | qid = int(l[0]) 473 | pid = int(l[2]) 474 | rel = int(l[3]) 475 | label = dict_label[pid][:] 476 | if rel != 0: 477 | outfile.write(f"{qid}\t0\t{label}\t{rel}\n") 478 | except: 479 | raise IOError('\"%s\" is not valid format' % l) 480 | 481 | 482 | 483 | if __name__ == "__main__": 484 | parser = argparse.ArgumentParser() 485 | ## Required parameters 486 | parser.add_argument("--output_dir", type=str, default=f"../tree/doc/new_tree") 487 | parser.add_argument("--type", choices=["doc", "passage"], default="doc") 488 | parser.add_argument("--overlap", type=int, default=2) 489 | parser.add_argument("--raw_tree_path", type=str, default=f"tree.pkl") 490 | parser.add_argument("--leaf_path", type=str, default=f"leaf_dict.pkl") 491 | parser.add_argument("--M_path", type=str, default=f"recall_train_5.tsv") 492 | parser.add_argument("--Y_path", type=str, default=f"train.rank_100.tsv") 493 | 494 | args = parser.parse_args() 495 | args.save_dir = f"{args.output_dir}/{args.overlap}" 496 | 497 | if not os.path.exists(args.save_dir): 498 | 
os.makedirs(args.save_dir) 499 | Get_New_C(args) 500 | tree_update(args) 501 | 502 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import email 3 | from lib2to3.pytree import Node 4 | import sys 5 | sys.path += ["./"] 6 | import os 7 | import time 8 | import torch 9 | import random 10 | import faiss 11 | import joblib 12 | import logging 13 | import argparse 14 | import subprocess 15 | import numpy as np 16 | import pandas as pd 17 | import pickle as pkl 18 | from construct_tree import TreeInitialize,TreeNode 19 | from torch import nn 20 | from tqdm import tqdm, trange 21 | 22 | from model_search_cos import metrics_count 23 | os.environ["CUDA_VISIBLE_DEVICES"] = '6' 24 | from torch.utils.tensorboard import SummaryWriter 25 | from torch.utils.data import DataLoader, RandomSampler 26 | from transformers import (AdamW, get_linear_schedule_with_warmup, 27 | RobertaConfig) 28 | 29 | from dataset import TextTokenIdsCache, SequenceDataset, load_rel, pack_tensor_2D 30 | from model import RobertaDot,embedding_model 31 | 32 | logger = logging.getLogger(__name__) 33 | logging.basicConfig(format = '%(asctime)s-%(levelname)s-%(name)s- %(message)s', 34 | datefmt = '%d %H:%M:%S', 35 | level = logging.INFO) 36 | 37 | 38 | def set_seed(args): 39 | random.seed(args.seed) 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | if args.n_gpu > 0: 43 | torch.cuda.manual_seed_all(args.seed) 44 | 45 | def load_object(path): 46 | with open(path, 'rb') as f: 47 | obj = pkl.load(f) 48 | return obj 49 | 50 | def save_object(obj, path): 51 | with open(path,'wb') as f: 52 | pkl.dump(obj,f) 53 | 54 | 55 | def save_model(model, output_dir, save_name, args, optimizer=None): 56 | save_dir = os.path.join(output_dir, save_name) 57 | if not os.path.exists(save_dir): 58 | os.makedirs(save_dir) 59 | model_to_save = model.module if hasattr(model, 'module') else model 60 | model_to_save.save_pretrained(save_dir) 61 | torch.save(args, os.path.join(save_dir, 'training_args.bin')) 62 | if optimizer is not None: 63 | torch.save(optimizer.state_dict(), os.path.join(save_dir, "optimizer.bin")) 64 | 65 | 66 | 67 | 68 | class TrainQueryDataset(SequenceDataset): 69 | def __init__(self, queryids_cache, 70 | rel_file, max_query_length): 71 | SequenceDataset.__init__(self, queryids_cache, max_query_length) 72 | self.reldict = load_rel(rel_file) 73 | 74 | def __getitem__(self, item): 75 | ret_val = super().__getitem__(item) 76 | ret_val['rel_ids'] = self.reldict[item] 77 | return ret_val 78 | 79 | 80 | def get_collate_function(mode, max_seq_length): 81 | cnt = 0 82 | def collate_function(batch): 83 | nonlocal cnt 84 | length = None 85 | if cnt < 10: 86 | length = max_seq_length 87 | cnt += 1 88 | 89 | input_ids = [x["input_ids"] for x in batch] 90 | attention_mask = [x["attention_mask"] for x in batch] 91 | data = { 92 | "input_ids": pack_tensor_2D(input_ids, default=1, 93 | dtype=torch.int64, length=length), 94 | "attention_mask": pack_tensor_2D(attention_mask, default=0, 95 | dtype=torch.int64, length=length), 96 | } 97 | 98 | qids = [x['id'] for x in batch] 99 | if(mode == 'train'): 100 | all_rel_pids = [x["rel_ids"] for x in batch] 101 | return data, qids, all_rel_pids 102 | else: 103 | return data, qids 104 | return collate_function 105 | 106 | 107 | gpu_resources = [] 108 | 109 | 110 | def get_kmeans_labels(all_rel_pids, pid_label_dict,layer): 111 | 
all_rel_labels = []
112 | for pids in all_rel_pids:
113 | labels = pid_label_dict[pids[0]][:layer+1]
114 | all_rel_labels.append(labels)
115 | return all_rel_labels
116 | 
117 | 
118 | def rel2label(all_rel_label):
119 | labels = []
120 | for rel in all_rel_label:
121 | label = int(rel[-1:])
122 | labels.append(label)
123 | return labels
124 | 
125 | def train(args, model):
126 | """ Train the model """
127 | if hasattr(torch.cuda, 'empty_cache'):
128 | torch.cuda.empty_cache()
129 | tb_writer = SummaryWriter(os.path.join(args.log_dir,
130 | time.strftime("%b-%d_%H:%M:%S", time.localtime())))
131 | 
132 | 
133 | 
134 | label_path = f"{args.tree_dir}/pid_labelid.memmap"
135 | pid_label = pd.read_csv(label_path,header=None,names = ["pids","labels"], sep=',',dtype = {'pids':np.int32,'labels':str})
136 | pid_label_dict = dict(zip(pid_label['pids'], pid_label['labels']))
137 | 
138 | 
139 | node_dict = load_object(f"{args.tree_dir}/node_dict.pkl")  # map: node id -> TreeNode
140 | node_list = load_object(f"{args.tree_dir}/node_list.pkl")  # nodes grouped by tree level
141 | 
142 | 
143 | node_embeddings = []
144 | label_offest = {}
145 | j = 0
146 | for i in node_dict:
147 | if(i == '0'):
148 | continue
149 | label_offest[i] = torch.tensor(int(j), dtype=torch.long).to(args.model_device)
150 | j = j + 1
151 | embedding = node_dict[i].embedding
152 | node_embeddings.append(embedding)
153 | node_embeddings = np.array(node_embeddings).astype(float)
154 | node_embeddings = torch.FloatTensor(node_embeddings).to(args.model_device)
155 | 
156 | embedding = embedding_model(node_embeddings, label_offest, args.model_device).to(args.model_device)
157 | args.train_batch_size = args.per_gpu_batch_size
158 | train_dataset = TrainQueryDataset(
159 | TextTokenIdsCache(args.preprocess_dir, "train-query"),
160 | os.path.join(args.preprocess_dir, "train-qrel.tsv"),
161 | args.max_seq_length
162 | )
163 | train_sampler = RandomSampler(train_dataset)
164 | collate_fn = get_collate_function(args.task, args.max_seq_length)
165 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
166 | batch_size=args.train_batch_size, collate_fn=collate_fn)
167 | 
168 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
169 | no_decay = ['bias', 'LayerNorm.weight']
170 | 
171 | optimizer_grouped_parameters = [
172 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
173 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
174 | {'params': embedding.parameters(), 'lr': 0.00001}  # also optimize the cluster node embeddings
175 | ]
176 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
177 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
178 | num_training_steps=t_total)
179 | 
180 | 
181 | # Train!
182 | logger.info("***** Running training *****")
183 | logger.info(" Num examples = %d", len(train_dataset))
184 | logger.info(" Num Epochs = %d", args.num_train_epochs)
185 | logger.info(" Total train batch size (w. accumulation) = %d",
186 | args.train_batch_size * args.gradient_accumulation_steps)
187 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
188 | logger.info(" Total optimization steps = %d", t_total)
189 | 
190 | global_step = 0
191 | tr_loss, logging_loss = 0.0, 0.0
192 | model.zero_grad()
193 | embedding.zero_grad()
194 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
195 | set_seed(args) # Added here for reproducibility (even between python 2 and 3)
196 | loss_fct = nn.CrossEntropyLoss().to(args.model_device)
197 | 
198 | 
199 | negative_num = 8
200 | 
201 | for epoch_idx, _ in enumerate(train_iterator):
202 | 
203 | for layer in range(args.layer):  # NOTE: args.layer is an int (see run_parse_args); level-by-level training is assumed here
204 | epoch_iterator = tqdm(train_dataloader, desc="Iteration")
205 | for step, (batch, qids, all_rel_pids) in enumerate(epoch_iterator):
206 | batch = {k:v.to(args.model_device) for k, v in batch.items()}
207 | 
208 | model.train()
209 | query_embeddings = model(
210 | input_ids=batch["input_ids"],
211 | attention_mask=batch["attention_mask"],
212 | is_query=True)
213 | 
214 | all_rel_label = get_kmeans_labels(all_rel_pids, pid_label_dict, args.layer)
215 | 
216 | scores = embedding(query_embeddings, all_rel_label, node_list, node_dict, label_offest, layer, all_node)
217 | labels = torch.zeros(len(scores),).to(args.model_device).long()
218 | loss = loss_fct(scores, labels)
219 | 
220 | loss.backward(retain_graph=True)
221 | tr_loss += loss.item()
222 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
223 | 
224 | 
225 | if (step + 1) % args.gradient_accumulation_steps == 0:
226 | optimizer.step()
227 | scheduler.step()
228 | model.zero_grad()
229 | embedding.zero_grad()
230 | global_step += 1
231 | 
232 | 
233 | if args.logging_steps > 0 and global_step % args.logging_steps == 0:
234 | tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step)
235 | cur_loss = (tr_loss - logging_loss)/args.logging_steps
236 | tb_writer.add_scalar('train/loss', cur_loss, global_step)
237 | logging_loss = tr_loss
238 | 
239 | 
240 | 
241 | save_model(model, args.model_save_dir, 'epoch-model_{}-{}'.format(args.layer,epoch_idx+1), args)
242 | save_dir = os.path.join(args.model_save_dir,'node_dict_{}-{}.pkl'.format(args.layer,epoch_idx+1))
243 | node_dict = embedding.get_embedding(node_dict, node_list, args.layer)  # refresh node_dict with the learned cluster embeddings before saving
244 | save_object(node_dict,save_dir)
245 | 
246 | def evaluate(args, model, node_dict,mode,prefix):
247 | eval_output_dir = args.eval_save_dir
248 | if not os.path.exists(eval_output_dir):
249 | os.makedirs(eval_output_dir)
250 | 
251 | dev_dataset = SequenceDataset(
252 | TextTokenIdsCache(args.preprocess_dir, "dev-query"),
253 | 64)
254 | 
255 | 
256 | 
257 | args.eval_batch_size = args.per_gpu_eval_batch_size
258 | collate_fn = get_collate_function(mode=mode, max_seq_length = args.max_seq_length)
259 | dev_dataloader = DataLoader(dev_dataset, batch_size=args.eval_batch_size, collate_fn=collate_fn)
260 | 
261 | 
262 | # Eval!
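# Each line written below is "<qid>\t<node>\t<rank>"; get_matching_matrix() in
# reorganize_clusters_tree.py later reads column 1 of such a file as a leaf id
# (mapped through dict_value) when building the query-to-cluster matching matrix M.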
263 | logger.info("***** Running evaluation {} *****".format(prefix))
264 | logger.info(" Num examples = %d", len(dev_dataset))
265 | logger.info(" Batch size = %d", args.eval_batch_size)
266 | all_time = 0
267 | num = 0
268 | output_file_path = f"{eval_output_dir}/recall_{mode}_{args.topk}.tsv"
269 | with open(output_file_path, 'w') as outputfile:
270 | for batch, qids in tqdm(dev_dataloader, desc="Evaluating"):
271 | model.eval()
272 | 
273 | with torch.no_grad():
274 | batch = {k:v.to(args.model_device) for k, v in batch.items()}
275 | embeddings = model(
276 | input_ids=batch["input_ids"],
277 | attention_mask=batch["attention_mask"],
278 | is_query=True)
279 | 
280 | num = num + 1
281 | 
282 | scores = metrics_count(embeddings, node_dict, args.topk)
283 | 
284 | for qid, score_one in zip(qids, scores):
285 | index = 0
286 | for score in score_one:
287 | index = index + 1
288 | outputfile.write(f"{qid}\t{score}\t{index}\n")
289 | 
290 | 
291 | print("num %d" % (num))
292 | 
293 | 
294 | def run_parse_args():
295 | parser = argparse.ArgumentParser()
296 | 
297 | ## Required parameters
298 | parser.add_argument("--task", choices=["train", "dev"], required=True)
299 | parser.add_argument("--output_dir", type=str, default=f"output")
300 | parser.add_argument("--init_path", type=str, default=f"model")
301 | parser.add_argument("--pembed_path", type=str, default=f"/passages.memmap")
302 | parser.add_argument("--preprocess_dir", type=str,default=f'/preprocess')
303 | parser.add_argument("--tree_dir", type=str,default=f'/tree')
304 | parser.add_argument("--per_gpu_batch_size", type=int, default=32)
305 | parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int)
306 | parser.add_argument("--max_seq_length", type=int, default=64)
307 | parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
308 | parser.add_argument("--num_train_epochs", default=3, type=int)
309 | parser.add_argument("--warmup_steps", default=2000, type=int)
310 | parser.add_argument("--learning_rate", default=5e-6, type=float)
311 | parser.add_argument("--adam_epsilon", default=1e-8, type=float)
312 | parser.add_argument("--no_cuda", action='store_true', default=False)
313 | parser.add_argument('--seed', type=int, default=42)
314 | parser.add_argument("--max_grad_norm", default=1.0, type=float)
315 | parser.add_argument("--logging_steps", type=int, default=100)
316 | parser.add_argument("--save_steps", type=int, default=5000)
317 | parser.add_argument("--eval_ckpt", type=int, default=5000)
318 | parser.add_argument("--layer", type=int, default=5)
319 | parser.add_argument("--topk", type=int, default=10)
320 | 
321 | args = parser.parse_args()
322 | 
323 | time_stamp = time.strftime("%b-%d_%H:%M:%S", time.localtime())
324 | args.log_dir = f"{args.output_dir}/log/{time_stamp}"
325 | args.model_save_dir = f"{args.output_dir}/models"
326 | args.eval_save_dir = f"{args.output_dir}/eval_results"
327 | 
328 | return args
329 | 
330 | 
331 | 
332 | 
333 | 
334 | def main():
335 | args = run_parse_args()
336 | 
337 | # Setup CUDA, GPU
338 | args.model_device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
339 | args.n_gpu = torch.cuda.device_count()
340 | 
341 | # Setup logging
342 | logger.warning("Model Device: %s, n_gpu: %s", args.model_device, args.n_gpu)
343 | 
344 | # Set seed
345 | set_seed(args)
346 | 
347 | logger.info(f"load from {args.init_path}")
348 | config = RobertaConfig.from_pretrained(args.init_path)
349 | model = RobertaDot.from_pretrained(args.init_path, config=config)
350 | model.to(args.model_device)
351 | logger.info("Training/evaluation parameters %s", args)
352 | 
353 | if args.task == "train":
354 | os.makedirs(args.model_save_dir, exist_ok=True)
355 | train(args, model)
356 | else:
357 | # Assumption: fall back to the node_dict saved with the tree; point this at the
358 | # node_dict_{layer}-{epoch}.pkl written during training to evaluate a trained index.
359 | node_dict = load_object(f"{args.tree_dir}/node_dict.pkl")
360 | evaluate(args, model, node_dict, args.task, prefix=f"ckpt-{args.eval_ckpt}")
361 | 
362 | 
363 | if __name__ == "__main__":
364 | main()
365 | 
--------------------------------------------------------------------------------
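
After `reorganize_clusters_tree.py` finishes, the matrices it saves can be sanity-checked with a short script. The sketch below is illustrative only: the path follows the argparse defaults above (`--output_dir ../tree/doc/new_tree`, `--overlap 2`), and `scipy` is assumed to be installed.

```python
import pickle as pkl
import scipy.sparse as smat

save_dir = "../tree/doc/new_tree/2"  # {output_dir}/{overlap}; adjust to your setup

C = smat.load_npz(f"{save_dir}/C.npz")  # (docs + copied docs) x leaf clusters
Y = smat.load_npz(f"{save_dir}/Y.npz")  # queries x (docs + copied docs)
M = smat.load_npz(f"{save_dir}/M.npz")  # queries x leaf clusters
print("C:", C.shape, "Y:", Y.shape, "M:", M.shape)

# mapper.pkl maps the row index of each copied document in C.npz back to the original doc id
with open(f"{save_dir}/mapper.pkl", "rb") as f:
    mapper = pkl.load(f)
print("documents placed in more than one cluster:", len(set(mapper.values())))
```

The same directory also receives the refreshed `node_list.pkl`, `node_dict.pkl`, `leaf_dict.pkl`, and `pid_labelid.memmap`, so `--tree_dir` in `train.py` can be pointed at it for the next training round.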