├── model ├── __init__.py └── pro_model.py ├── dataset ├── __init__.py └── dataset.py ├── img └── Figure_1.png ├── LEGAL.md ├── utils └── common_utils.py ├── README.md ├── train.sh ├── preprocess ├── save_logits.py ├── save_hardneg_bm25.py └── save_hardnrg_bi.py ├── LICENSE.md └── train.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /img/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/D2LLM/main/img/Figure_1.png -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pathlib 4 | import numpy as np 5 | from scipy.stats import pearsonr, spearmanr 6 | import torch 7 | from loguru import logger 8 | import shutil 9 | from torch.utils.tensorboard import SummaryWriter 10 | import pickle 11 | import linecache 12 | import tracemalloc 13 | 14 | def set_seed(seed): 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | if torch.cuda.is_available(): 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | def save_model(model_engine, ckpt_dir, client_state): 22 | model_engine.save_checkpoint(ckpt_dir, client_state=client_state, exclude_frozen_parameters=True) 23 | 24 | def remove_earlier_ckpt(path, start_name, current_step_num, max_save_num): 25 | 26 | filenames=os.listdir(path) 27 | ckpts = [dir_name for dir_name in filenames if dir_name.startswith(start_name) and int(dir_name.split('-')[1])<=current_step_num] 28 | 29 | current_ckpt_num = len(ckpts) 30 | for dir_name in filenames: 31 | if dir_name.startswith(start_name) and int(dir_name.split('-')[1]) <= current_step_num and current_ckpt_num > (max_save_num-1): 32 | shutil.rmtree(os.path.join(path, dir_name)) 33 | 34 | 35 | def makedirs(path): 36 | p = pathlib.Path(path) 37 | p.parent.mkdir(parents=True, exist_ok=True) 38 | return path 39 | 40 | def load_pickle(path): 41 | with open(path, "rb") as f: 42 | return pickle.load(f) 43 | 44 | def write_pickle(obj, path:str): 45 | if not os.path.exists(path): 46 | makedirs(path) 47 | with open(path, "wb") as f: 48 | return pickle.dump(obj, f) 49 | 50 | def write_tensorboard(summary_writer, log_dict, completed_steps): 51 | for key, value in log_dict.items(): 52 | summary_writer.add_scalar(f'{key}', value, completed_steps) 53 | 54 | def cos_sim(a, b): 55 | 56 | if not isinstance(a, torch.Tensor): 57 | a = torch.tensor(a) 58 | 59 | if not isinstance(b, torch.Tensor): 60 | b = torch.tensor(b) 61 | 62 | if 
len(a.shape) == 1: 63 | a = a.unsqueeze(0) 64 | 65 | if len(b.shape) == 1: 66 | b = b.unsqueeze(0) 67 | 68 | a_norm = torch.nn.functional.normalize(a, p=2, dim=1) 69 | b_norm = torch.nn.functional.normalize(b, p=2, dim=1) 70 | return torch.mm(a_norm, b_norm.transpose(0, 1)) 71 | 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # D2LLM: Decomposed and Distilled Large Language Models for Semantic Search 2 | 3 | This is the PyTorch implementation of D2LLM, as described in the ACL'24 paper: D2LLM: Decomposed and Distilled Large Language Models for Semantic Search. 4 | 5 | ![The network architecture of D2LLM.](./img/Figure_1.png) 6 |
Figure 1. The network architecture of D2LLM.
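As Figure 1 shows, the student model encodes the query and the passage separately, pools the token embeddings of each side into a single vector (the PMA module), and then scores the pair with the IEM module rather than a plain cosine similarity. The snippet below is a minimal, self-contained sketch of this scoring path, not the repository code: the actual PMA and IEM implementations live in model/pro_model.py, the embedding size here is purely illustrative, and only hidden_dim follows the default in train.sh.

    import torch
    import torch.nn as nn

    class PairScorer(nn.Module):
        """Illustrative stand-in for the IEM: an MLP over the concatenated pair embedding."""
        def __init__(self, emb_dim, hidden_dim):
            super().__init__()
            self.mlp = nn.Sequential(
                nn.Linear(2 * emb_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )

        def forward(self, emb_a, emb_b):
            # Score each (query, passage) pair from its concatenated embeddings.
            return self.mlp(torch.cat([emb_a, emb_b], dim=-1)).squeeze(-1)

    emb_dim, hidden_dim = 1024, 512          # emb_dim is illustrative; hidden_dim matches HIDDEN_DIM=512 in train.sh
    scorer = PairScorer(emb_dim, hidden_dim)
    query_emb = torch.randn(8, emb_dim)      # pooled query embeddings (e.g. the output of PMA)
    passage_emb = torch.randn(8, emb_dim)    # pooled passage embeddings
    scores = scorer(query_emb, passage_emb)  # one relevance score per pair, used in place of cosine similarity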
7 | 8 | ## Requirements 9 | 10 | * Ubuntu OS 11 | * python==3.10 12 | * torch==2.0.1 13 | * cuda==11.7 14 | * transformers==4.37.0 15 | * deepspeed==0.14.2 16 | * flash-attn==2.3.6 17 | * peft==0.7.0 18 | 19 | Dependencies can be installed by: 20 | 21 | pip install -r requirements.txt 22 | 23 | 24 | The overall directory structure is as follows: 25 | 26 | ${CODE_ROOT} 27 | ...... 28 | |-- preprocess 29 | |-- save_hardneg_bm25.py 30 | |-- save_hardneg_bi.py 31 | |-- save_logits.py 32 | |-- dataset 33 | |-- dataset.py 34 | |-- model 35 | |-- pro_model.py 36 | |-- utils 37 | |-- common_utils.py 38 | |-- train.py 39 | |-- train.sh 40 | 41 | 42 | 43 | ## Data preparation 44 | 45 | The six datasets (SNLI-zh, NLI-zh, T2Ranking, DuReader, cMedQA2 and mMARCO) used in this paper can be downloaded from the following links: 46 | 47 | * [SNLI-zh](https://huggingface.co/datasets/shibing624/snli-zh) 48 | * [NLI-zh](https://huggingface.co/datasets/shibing624/nli_zh) 49 | * [T2Ranking](https://github.com/THUIR/T2Ranking) 50 | * [DuReader](https://github.com/baidu/DuReader) 51 | * [cMedQA2](https://github.com/zhangsheng93/cMedQA2) 52 | * [mMARCO](https://huggingface.co/datasets/unicamp-dl/mmarco) 53 | 54 | Before training, we mine hard negatives with BM25 and with a bi-encoder using the scripts save_hardneg_bm25.py and save_hardneg_bi.py. Then, we use the script save_logits.py to score the in-batch negatives and hard negatives with the teacher LLM. 55 | 56 | ## Train 57 | 58 | To start training, adjust the parameters in train.sh and run: 59 | 60 | sh train.sh 61 | 62 | ## Evaluate 63 | 64 | Evaluation can be done through the MTEB toolkit. Note that the cosine similarity should be replaced by the IEM module. 65 | 66 | ## Citation 67 | 68 | @inproceedings{ 69 | anonymous2024dllm, 70 | title={D2{LLM}: Decomposed and Distilled Large Language Models for Semantic Search}, 71 | author={Anonymous}, 72 | booktitle={The 62nd Annual Meeting of the Association for Computational Linguistics}, 73 | year={2024} 74 | } 75 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | BASE_MODEL_DIR="PATH_OF_BASE_MODEL" 3 | TRAIN_DATA_LIST="TRAIN_DATASETS" 4 | POS_DIR="PATH_TO_POS_LOGITS" 5 | NEG_DIR="PATH_TO_NEG_LOGITS" 6 | DATA_DIR="DATASET_DIR" 7 | INBATCH_PKL_PATH_DIR="PATH_TO_INBATCH_LOGITS_PKL" 8 | FEATURE_PKL_PATH_DIR="PATH_TO_FEATURE_PKL" 9 | BATCH_SIZE=32 10 | NEG_K=8 11 | NUM_HEADS=32 12 | HIDDEN_DIM=512 13 | OUTPUT_DIM=1 14 | LN="True" 15 | NORM="False" 16 | PADDING_SIDE="right" 17 | NUM_EPOCHS=5 18 | MAX_SEQ_LENGTH=250 19 | LR=1e-4 20 | ALPHA=1 21 | BETA=1 22 | GAMMA=0.01 23 | ETA=0.001 24 | TEMPERATURE_IN_BATCH=1 25 | TEMPERATURE_HARDNEG=1 26 | TEMPERATURE_TEACHER_HARDNEG=1 27 | SCALE_PARAM=1 28 | LOG_INTERVAL=10 29 | EVAL_INTERVAL=300 30 | TB_DIR="PATH_TO_TENSORBOARD_PATH" 31 | PATIENCE=5 32 | NUM_CKPT=4 33 | TRAINING_LOG="PATH_TO_TRAINING_LOG" 34 | OUTPUT_DIR="PATH_TO_OUTPUT_MODEL" 35 | 36 | WORLD_SIZE=${WORLD_SIZE:-1} 37 | NODE_RANK=${RANK:-0} 38 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 39 | MASTER_PORT=${MASTER_PORT:-12346} 40 | 41 | python -m torch.distributed.run --nproc_per_node=$gpus --nnodes=$WORLD_SIZE --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ 42 | train.py --base_model_dir $BASE_MODEL_DIR \ 43 | --train_data_list $TRAIN_DATA_LIST \ 44 | --pos_dir $POS_DIR \ 45 | --neg_dir $NEG_DIR \ 46 | --data_dir $DATA_DIR \ 47 
| --inbatch_pkl_path_dir $INBATCH_PKL_PATH_DIR \ 48 | --feature_pkl_path_dir $FEATURE_PKL_PATH_DIR \ 49 | --batch_size $BATCH_SIZE 50 | --neg_K $NEG_K \ 51 | --num_heads $NUM_HEADS \ 52 | --hidden_dim $HIDDEN_DIM \ 53 | --output_dim $OUTPUT_DIM \ 54 | --ln $LN \ 55 | --norm $NORM \ 56 | --num_epochs $NUM_EPOCHS \ 57 | --padding_side $PADDING_SIDE \ 58 | --max_seq_length $MAX_SEQ_LENGTH \ 59 | --lr $LR \ 60 | --alpha $ALPHA \ 61 | --beta $BETA \ 62 | --gamma $GAMMA \ 63 | --eta $ETA \ 64 | --temperature_in_batch $TEMPERATURE_IN_BATCH \ 65 | --temperature_hardneg $TEMPERATURE_HARDNEG \ 66 | --temperature_teacher_hardneg $TEMPERATURE_TEACHER_HARDNEG \ 67 | --scale_param $SCALE_PARAM \ 68 | --log_interval $LOG_INTERVAL \ 69 | --eval_interval $EVAL_INTERVAL \ 70 | --tb_dir $TB_DIR \ 71 | --patience $PATIENCE \ 72 | --num_ckpt $NUM_CKPT \ 73 | --training_log $TRAINING_LOG \ 74 | --output_dir $OUTPUT_DIR \ -------------------------------------------------------------------------------- /preprocess/save_logits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | from tqdm import tqdm, trange 6 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig 7 | from utils.common_utils import load_pickle, write_pickle 8 | 9 | 10 | def sts_template_v5(text1, text2): 11 | return f'#P和#H将分别描述一种事件或问题,它们可能并无关系。仅使用此描述和您对世界的了解,判断#H是不是一个关于#P中的事件绝对正确的句子,或者#H是不是绝对正确地描述了#P的事件或问题,请回答是或不是,若您不确定,请回答不是。\n#P:{text1}\n#H:{text2}\n回答:' 12 | 13 | def context_template_v5(text1, text2): 14 | return f'#Q将描述一个问题,#A将描述一个网络段落,它们可能并没有关系。仅依据这些描述和您对世界的了解,判断#A能不能正确地回答#Q中提出的问题,请回答能或不能。\n#Q:{text1}\n#A:{text2}\n回答:' 15 | 16 | 17 | def generate_logits(model_dir, neg_pkl_file, task_type, bs, teacher_max_seq_length, num_shards, id_shard): 18 | bm_25_dict = load_pickle(neg_pkl_file) 19 | all_sample_list = [] 20 | len_dict = {} 21 | all_logits = [] 22 | res_dict = {} 23 | lenth_one = len(list(bm_25_dict.keys()))/num_shards 24 | for i, query in enumerate(bm_25_dict): 25 | if i >= lenth_one*id_shard and i < lenth_one*(id_shard+1): 26 | doc_list = bm_25_dict[query] 27 | len_dict[i] = len(doc_list) 28 | if task_type == 'context': 29 | qry_doc_list = [context_template_v5(query, d) for d in doc_list] 30 | elif task_type == 'sts': 31 | qry_doc_list = [sts_template_v5(query, d) for d in doc_list] 32 | all_sample_list.extend(qry_doc_list) 33 | teacher_tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, pad_token='<|endoftext|>', truncation_side='right', padding_side='left') 34 | teacher_tokenizer.pad_token_id = teacher_tokenizer.eod_id 35 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True).to('cuda') 36 | model.eval() 37 | if task_type == 'sts': 38 | yes_id = teacher_tokenizer.encode('是')[0] 39 | no_id = teacher_tokenizer.encode('不是')[0] 40 | elif task_type == 'context': 41 | yes_id = teacher_tokenizer.encode('能')[0] 42 | no_id = teacher_tokenizer.encode('不能')[0] 43 | else: 44 | raise ValueError(f'Error: No Task Type {task_type}') 45 | with torch.no_grad(): 46 | for start_index in trange(0, len(all_sample_list), bs, disable=False): 47 | print(start_index) 48 | cross_sentence_batch = all_sample_list[start_index: start_index+bs] 49 | cross_sentence_inputs = teacher_tokenizer(text=cross_sentence_batch, padding='max_length', max_length=teacher_max_seq_length, truncation=True, return_tensors='pt').to('cuda') 50 | outputs_logits = model(**cross_sentence_inputs).logits 51 | 
outputs_logits = outputs_logits[:, -1, [yes_id, no_id]].cpu().float().numpy().tolist() 52 | all_logits.extend(outputs_logits) 53 | assert len(all_logits) == len(all_sample_list) 54 | start = 0 55 | for i, query in enumerate(bm_25_dict): 56 | if i >= lenth_one*id_shard and i < lenth_one*(id_shard+1): 57 | end = start + len_dict[i] 58 | doc_list = bm_25_dict[query] 59 | logits_list = all_logits[start:end] 60 | assert len(doc_list) == len(logits_list) 61 | res_doc_logits = list(zip(doc_list, logits_list)) 62 | res_dict[query] = res_doc_logits 63 | start = end 64 | return res_dict 65 | 66 | 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('--model_dir', default='', type=str) 71 | parser.add_argument('--hardneg_dir', default='', type=str) 72 | parser.add_argument('--output_pkl', default='', type=str) 73 | parser.add_argument('--dataset', default='', type=str) 74 | parser.add_argument('--task_type', default='', type=str) 75 | parser.add_argument('--bs', default=140, type=int) 76 | parser.add_argument('--K', type=int) 77 | parser.add_argument('--teacher_max_seq_length', default=500, type=int) 78 | parser.add_argument('--num_shards', default=8, type=int) 79 | parser.add_argument('--id_shard', default=0, type=int) 80 | args = parser.parse_args() 81 | 82 | neg_pkl_file = args.hardneg_dir 83 | output_pkl_path = args.output_pkl 84 | res_dict = generate_logits(args.model_dir, neg_pkl_file, args.task_type, args.bs, args.teacher_max_seq_length, args.num_shards, args.id_shard) 85 | write_pickle(res_dict, output_pkl_path) 86 | 87 | -------------------------------------------------------------------------------- /preprocess/save_hardneg_bm25.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import csv 3 | import time 4 | import os 5 | import jieba 6 | import pickle 7 | import argparse 8 | from rank_bm25 import BM25Okapi 9 | from collections import defaultdict 10 | from tqdm import tqdm 11 | from datasets import load_dataset 12 | 13 | def write_pickle(obj, file): 14 | with open(file, 'wb') as f: 15 | pickle.dump(obj, f) 16 | 17 | def load_pickle(file): 18 | with open(file, 'rb') as f: 19 | obj = pickle.load(f) 20 | return obj 21 | 22 | 23 | def load_snli_zh(path): 24 | queries = [] 25 | corpus = [] 26 | 27 | pos_sample_dict = defaultdict(list) 28 | 29 | with open(path, encoding='utf-8') as f: 30 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 31 | for id, row in enumerate(reader): 32 | text_a = row['sentence1'] 33 | text_b = row['sentence2'] 34 | label = row['gold_label'] 35 | 36 | if isinstance(text_b, str): 37 | corpus.append(text_b) 38 | 39 | if label == 'entailment': 40 | if isinstance(text_a, str): 41 | queries.append(text_a) 42 | 43 | pos_sample_dict[text_a].append(text_b) 44 | 45 | return queries, list(set(corpus)), pos_sample_dict 46 | 47 | 48 | def load_sts_zh(path): 49 | queries = [] 50 | corpus = [] 51 | pos_sample_dict = defaultdict(list) 52 | dataset = load_dataset(path, split='train') 53 | for id, row in enumerate(dataset): 54 | text_a = row['sentence1'] 55 | text_b = row['sentence2'] 56 | label = row['label'] 57 | if isinstance(text_b, str): 58 | corpus.append(text_b) 59 | if path.split('/')[-1] != 'STS-B': 60 | if label == 1: 61 | if isinstance(text_a, str): 62 | queries.append(text_a) 63 | 64 | pos_sample_dict[text_a].append(text_b) 65 | else: 66 | if label >= 4: 67 | if isinstance(text_a, str) : 68 | queries.append(text_a) 69 | 
pos_sample_dict[text_a].append(text_b) 70 | return queries, list(set(corpus)), pos_sample_dict 71 | 72 | def load_t2(path): 73 | queries = [] 74 | corpus = [] 75 | 76 | pos_sample_dict = defaultdict(list) 77 | with open(path, 'r', encoding='utf-8') as f: 78 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 79 | for id, row in enumerate(reader): 80 | text_a = row['sentece1'] 81 | text_b = row['sentence2'] 82 | 83 | if isinstance(text_b, str): 84 | corpus.append(text_b[:320]) 85 | 86 | if isinstance(text_a, str): 87 | queries.append(text_a) 88 | 89 | pos_sample_dict[text_a].append(text_b[:320]) 90 | 91 | 92 | 93 | return queries, list(set(corpus)), pos_sample_dict 94 | 95 | 96 | 97 | def main(): 98 | 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--data_name', default='', type=str) 101 | parser.add_argument('--K', default=10, type=int) 102 | parser.add_argument('--num', default=50, type=int) 103 | 104 | args = parser.parse_args() 105 | 106 | 107 | stopwords = [] 108 | with open('STOPWORDS_PATH', 'r', encoding='utf8') as f: 109 | for line in f: 110 | line = line.strip('\n') 111 | stopwords.append(line) 112 | output_dir = "OUTPUTS_NEG_BM25_PATH" 113 | if args.data_name == 'snli-zh': 114 | queries, corpus, pos_sample_dict = load_snli_zh("NLI_DATA_PATH") 115 | output_pickle = os.path.join(output_dir, args.data_name+'.pkl') 116 | if args.data_name in ['ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STS-B']: 117 | queries, corpus, pos_sample_dict = load_sts_zh("STS_DATA_PATH") 118 | output_pickle = os.path.join(output_dir, args.data_name+'.pkl') 119 | if args.data_name == 't2': 120 | queries, corpus, pos_sample_dict = load_t2("T2_DATA_PATH") 121 | output_pickle = os.path.join(output_dir, args.data_name+'.pkl') 122 | tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus] 123 | tokenized_corpus = [list(set(tokenized_doc).difference(set(stopwords))) for tokenized_doc in tokenized_corpus] 124 | bm25 = BM25Okapi(tokenized_corpus) 125 | 126 | 127 | tokenized_queries = [list(jieba.cut(q)) for q in queries] 128 | tokenized_queries = [list(set(tokenized_query).difference(set(stopwords))) for tokenized_query in tokenized_queries] 129 | assert len(queries) == len(tokenized_queries) 130 | 131 | hard_neg_sample_dict = defaultdict(list) 132 | for i,tokenized_query in enumerate(tqdm(tokenized_queries)): 133 | doc_scores = bm25.get_scores(tokenized_query) 134 | res_docs = bm25.get_top_n(tokenized_query, corpus, n=args.K) 135 | for pos in pos_sample_dict[queries[i]]: 136 | while pos in res_docs: 137 | res_docs.remove(pos) 138 | 139 | hard_neg_sample_dict[queries[i]] = res_docs 140 | 141 | 142 | if not os.path.exists(output_dir): 143 | os.makedirs(output_dir) 144 | 145 | write_pickle(hard_neg_sample_dict, output_pickle) 146 | -------------------------------------------------------------------------------- /model/pro_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from tqdm import tqdm, trange 9 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig 10 | 11 | class MAB(nn.Module): 12 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 13 | super(MAB, self).__init__() 14 | self.dim_V = dim_V 15 | self.num_heads = num_heads 16 | self.fc_q = nn.Linear(dim_Q, dim_V) 17 | self.fc_k = nn.Linear(dim_K, dim_V) 18 | self.fc_v = nn.Linear(dim_K, dim_V) 19 | 20 | 
if ln: 21 | self.ln0 = nn.LayerNorm(dim_V) 22 | self.ln1 = nn.LayerNorm(dim_V) 23 | self.fc_o = nn.Linear(dim_V, dim_V) 24 | nn.init.xavier_uniform_(self.fc_q.weight) 25 | nn.init.xavier_uniform_(self.fc_k.weight) 26 | nn.init.xavier_uniform_(self.fc_v.weight) 27 | nn.init.xavier_uniform_(self.fc_o.weight) 28 | 29 | class PMA(nn.Module): 30 | def __init__(self, dim, num_heads, num_seeds, ln=False): 31 | super(PMA, self).__init__() 32 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 33 | nn.init.xavier_uniform_(self.S) 34 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 35 | def forward(self, X, pad_mask): 36 | if self.S.dtype != torch.bfloat16: 37 | X = X.float() 38 | return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask) 39 | 40 | def forward(self, Q, K, pad_mask=None): 41 | 42 | Q_ = self.fc_q(Q) 43 | K_, V_ = self.fc_k(K), self.fc_v(K) 44 | dim_split = self.dim_V // self.num_heads 45 | Q_ = torch.cat(Q_.split(dim_split, 2), 0) 46 | K_ = torch.cat(K_.split(dim_split, 2), 0) 47 | V_ = torch.cat(V_.split(dim_split, 2), 0) 48 | pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1) 49 | score = Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V) 50 | score = score.masked_fill(pad_mask == 0, -1e12) 51 | A = torch.softmax(score, 2) 52 | A = A * pad_mask 53 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) 54 | O = Q + O 55 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 56 | O = O + F.relu(self.fc_o(O)) 57 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 58 | return O 59 | 60 | 61 | class IEM(nn.Module): 62 | 63 | def __init__(self, d_model, hidden, d_output, drop_prob=0.0): 64 | super(IEM, self).__init__() 65 | self.linear1 = nn.Linear(2*d_model, hidden) 66 | self.proj0 = nn.Linear(hidden, hidden) 67 | self.proj1 = nn.Linear(hidden, hidden) 68 | self.linear2 = nn.Linear(hidden, d_output) 69 | nn.init.xavier_uniform_(self.linear1.weight) 70 | nn.init.xavier_uniform_(self.proj0.weight) 71 | nn.init.xavier_uniform_(self.proj1.weight) 72 | nn.init.xavier_uniform_(self.linear2.weight) 73 | self.relu = nn.ReLU() 74 | self.dropout = nn.Dropout(p=drop_prob) 75 | self.sftmx = nn.Softmax(dim=-1) 76 | 77 | def forward(self, emb_a, emb_b): 78 | x = torch.cat((emb_a, emb_b), dim=-1) 79 | x = self.linear1(x) 80 | x = self.relu(x) 81 | x = self.dropout(x) 82 | x0 = self.proj0(x) 83 | x1 = self.proj1(x) 84 | x0 = self.relu(x0) 85 | x1 = self.relu(x1) 86 | rep = torch.stack((x0,x1),dim=0) 87 | logits0 = self.linear2(x0) 88 | logits1 = self.linear2(x1) 89 | logits = torch.cat((logits0, logits1), dim=-1) 90 | return logits, rep 91 | 92 | 93 | 94 | class Mymodel(nn.Module): 95 | def __init__(self, 96 | model_name_or_path = None, 97 | alias = None, 98 | max_seq_length = 256, 99 | args = None 100 | ): 101 | super(Mymodel, self).__init__() 102 | self.alias = alias 103 | if self.alias == None: 104 | self.alias = model_name_or_path 105 | self.args = args 106 | self.max_seq_length = max_seq_length 107 | self.model_name_or_path = model_name_or_path 108 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True, pad_token='<|endoftext|>', truncation_side='right', padding_side=self.args.padding_side) 109 | self.tokenizer.pad_token_id = self.tokenizer.eod_id 110 | self.plm_model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, trust_remote_code=True) 111 | self.emb_dim = self.plm_model.transformer.wte.weight.size(1) 112 | self.num_heads = args.num_heads 113 | self.ln = args.ln 114 | self.norm = args.norm 115 | 
self.mha_pma = PMA(self.emb_dim, self.num_heads, 1, ln=self.ln) 116 | self.iem = IEM(self.emb_dim, self.hidden_dim, self.output_dim) 117 | 118 | def forward(self, inputs_all, task_ids, mode): 119 | if mode == 'train': 120 | output_embeddings_all = self.get_sentence_embedding(**inputs_all).reshape(2+self.args.neg_K, -1, self.emb_dim) 121 | output_embeddings_hardneg = output_embeddings_all[2:] 122 | elif mode == 'eval': 123 | output_embeddings_all = self.get_sentence_embedding(**inputs_all).reshape(2, -1, self.emb_dim) 124 | else: 125 | raise ValueError('Error of mode value') 126 | 127 | output_embeddings_a = output_embeddings_all[0] 128 | output_embeddings_b = output_embeddings_all[1] 129 | 130 | bs = output_embeddings_a.size(0) 131 | a_expand_emb = output_embeddings_a.unsqueeze(1).expand(-1, bs, -1).reshape(-1, self.emb_dim) 132 | b_expand_emb = output_embeddings_b.unsqueeze(0).expand(bs, -1, -1).reshape(-1, self.emb_dim) 133 | 134 | task_expand = task_ids.unsqueeze(1).expand(-1, bs).reshape(-1,1).squeeze() 135 | output_in_batch, _ = self.iem(a_expand_emb, b_expand_emb) # (bs*bs, 2) 136 | output_in_batch_specific_task = output_in_batch[range(task_expand.size(0)), task_expand].squeeze().reshape(bs, -1) 137 | 138 | if mode == 'train': 139 | pos_neg_emb = torch.cat([output_embeddings_b.unsqueeze(0), output_embeddings_hardneg], dim=0) 140 | achr_emb = output_embeddings_a.unsqueeze(0).expand(pos_neg_emb.size(0),-1,-1) 141 | output_hardneg, output_pos_hardneg_rep = self.iem(achr_emb, pos_neg_emb) 142 | task_id_gather = task_ids.unsqueeze(0).unsqueeze(-1).expand(pos_neg_emb.size(0), -1, -1) 143 | output_hardneg_specific_task = torch.gather(output_hardneg, -1, task_id_gather).squeeze().t() 144 | output_pos_hardneg_rep_specific_task = output_pos_hardneg_rep[task_ids[0]] 145 | elif mode == 'eval': 146 | output_hardneg_specific_task = None 147 | output_pos_hardneg_rep_specific_task = None 148 | 149 | return output_in_batch_specific_task, output_hardneg_specific_task, output_pos_hardneg_rep_specific_task 150 | 151 | def pma_embedding(self, A, mask): 152 | res = self.mha_pma(A, mask).squeeze(1) 153 | return res 154 | 155 | def get_sentence_embedding(self, **inputs): 156 | outputs = self.plm_model(**inputs, output_hidden_states=True) 157 | embedding = outputs.hidden_states[self.keep_max_layer] 158 | attention_mask = inputs['attention_mask'] 159 | res_embedding = self.pma_embedding(embedding, attention_mask) 160 | 161 | if self.norm: 162 | res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None) 163 | return res_embedding 164 | 165 | def encode(self, sentences, batch_size=64, convert_to_numpy=True, 166 | convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs): 167 | 168 | if max_seq_length is None: 169 | max_seq_length = self.max_seq_length 170 | 171 | input_is_string = False 172 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"): 173 | sentences = [sentences] 174 | input_is_string = True 175 | 176 | all_embeddings = [] 177 | length_sorted_idx = np.argsort([-len(s) for s in sentences]) 178 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 179 | with torch.no_grad(): 180 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): 181 | sentences_batch = sentences_sorted[start_index: start_index + batch_size] 182 | with torch.no_grad(): 183 | inputs = self.tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, 
return_tensors='pt').to(self.plm_model.device) 184 | embeddings = self.get_sentence_embedding(**inputs) 185 | embeddings = embeddings.detach() 186 | if convert_to_numpy: 187 | if embeddings.dtype == torch.bfloat16: 188 | embeddings = embeddings.cpu().to(torch.float32) 189 | else: 190 | embeddings = embeddings.cpu() 191 | all_embeddings.extend(embeddings) 192 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 193 | if convert_to_tensor: 194 | all_embeddings = torch.stack(all_embeddings) 195 | elif convert_to_numpy: 196 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 197 | 198 | if input_is_string: 199 | all_embeddings = all_embeddings[0] 200 | return all_embeddings 201 | 202 | -------------------------------------------------------------------------------- /preprocess/save_hardnrg_bi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import csv 4 | import pathlib 5 | import json 6 | import argparse 7 | import warnings 8 | import deepspeed 9 | from enum import Enum 10 | from typing import Union, List 11 | from datasets import load_dataset 12 | from tqdm import tqdm, trange 13 | from collections import defaultdict 14 | from utils.common_utils import * 15 | warnings.filterwarnings('ignore') 16 | from mteb.mteb import MTEB 17 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig 18 | from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType 19 | maxInt = sys.maxsize 20 | 21 | while True: 22 | try: 23 | csv.field_size_limit(maxInt) 24 | break 25 | except OverflowError: 26 | maxInt = int(maxInt/10) 27 | 28 | def makedirs(path): 29 | p = pathlib.Path(path) 30 | p.parent.mkdir(parents=True, exist_ok=True) 31 | return path 32 | 33 | class EncoderType(Enum): 34 | FIRST_LAST_AVG = 0 35 | LAST_AVG = 1 36 | CLS = 2 37 | POOLER = 3 38 | MEAN = 4 39 | 40 | def __str__(self): 41 | return self.name 42 | 43 | @staticmethod 44 | def from_string(s): 45 | try: 46 | return EncoderType[s] 47 | except KeyError: 48 | raise ValueError() 49 | 50 | class BaseBertModel: 51 | def __init__( 52 | self, 53 | model_name_or_path = None, 54 | max_seq_length = 512, 55 | encoder_type = 'CLS', 56 | alias = None 57 | ): 58 | self.model_name_or_path = model_name_or_path 59 | encoder_type = EncoderType.from_string(encoder_type) if isinstance(encoder_type, str) else encoder_type 60 | if encoder_type not in list(EncoderType): 61 | raise ValueError(f'encoder_type must be in {list(EncoderType)}') 62 | self.encoder_type = encoder_type 63 | self.max_seq_length = max_seq_length 64 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side='right', padding_side='right') 65 | self.plm_model = AutoModel.from_pretrained(model_name_or_path) 66 | self.results = {} 67 | device = "cuda" if torch.cuda.is_available() else "cpu" 68 | self.device = torch.device(device) 69 | self.plm_model.to(self.device) 70 | 71 | 72 | 73 | def get_sentence_embeddings(self, input_ids, attention_mask, token_type_ids=None): 74 | model_output = self.plm_model(input_ids, attention_mask, token_type_ids, output_hidden_states=True) 75 | 76 | if self.encoder_type == EncoderType.FIRST_LAST_AVG: 77 | first = model_output.hidden_states[1] 78 | last = model_output.hidden_states[-1] 79 | seq_length = first.size(1) 80 | 81 | first_avg = torch.avg_pool1d(first.transpose(1, 2), kernel_size=seq_length).squeeze(-1) 82 | last_avg = torch.avg_pool1d(last.transpose(1, 2), 
kernel_size=seq_length).squeeze(-1) 83 | final_encoding = torch.avg_pool1d( 84 | torch.cat([first_avg.unsqueeze(1), last_avg.unsqueeze(1)], dim=1).transpose(1, 2), 85 | kernel_size=2).squeeze(-1) 86 | return final_encoding 87 | 88 | if self.encoder_type == EncoderType.LAST_AVG: 89 | sequence_output = model_output.last_hidden_state 90 | seq_length = sequence_output.size(1) 91 | final_encoding = torch.avg_pool1d(sequence_output.transpose(1, 2), kernel_size=seq_length).squeeze(-1) 92 | return final_encoding 93 | 94 | if self.encoder_type == EncoderType.CLS: 95 | sequence_output = model_output.last_hidden_state 96 | return sequence_output[:, 0] 97 | 98 | if self.encoder_type == EncoderType.POOLER: 99 | return model_output.pooler_output 100 | 101 | if self.encoder_type == EncoderType.MEAN: 102 | token_embeddings = model_output.last_hidden_state # Contains all token embeddings 103 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 104 | final_encoding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( 105 | input_mask_expanded.sum(1), min=1e-9) 106 | return final_encoding # [batch, hid_size] 107 | 108 | def batch_to_device(self, batch, device): 109 | for key in batch: 110 | if isinstance(batch[key], torch.Tensor): 111 | batch[key] = batch[key].to(device) 112 | return batch 113 | 114 | 115 | def encode( 116 | self, 117 | sentences: Union[str, List[str]], 118 | batch_size: int = 32, 119 | show_progress_bar: bool = False, 120 | convert_to_numpy: bool = True, 121 | convert_to_tensor: bool = False, 122 | device: str = None, 123 | normalize_embeddings: bool = True, 124 | max_seq_length: int = None, 125 | ): 126 | self.plm_model.eval() 127 | if device is None: 128 | device = self.device 129 | self.plm_model.to(device) 130 | 131 | if max_seq_length is None: 132 | max_seq_length = self.max_seq_length 133 | if convert_to_tensor: 134 | convert_to_numpy = False 135 | input_is_string = False 136 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"): 137 | sentences = [sentences] 138 | input_is_string = True 139 | 140 | all_embeddings = [] 141 | length_sorted_idx = np.argsort([-len(s) for s in sentences]) 142 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 143 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): 144 | sentences_batch = sentences_sorted[start_index: start_index + batch_size] 145 | with torch.no_grad(): 146 | features = self.tokenizer( 147 | sentences_batch, max_length=max_seq_length, 148 | padding=True, truncation=True, return_tensors='pt' 149 | ) 150 | features = self.batch_to_device(features, device) 151 | embeddings = self.get_sentence_embeddings(**features) 152 | embeddings = embeddings.detach() 153 | if normalize_embeddings: 154 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 155 | 156 | if convert_to_numpy: 157 | embeddings = embeddings.cpu() 158 | all_embeddings.extend(embeddings) 159 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 160 | if convert_to_tensor: 161 | all_embeddings = torch.stack(all_embeddings) 162 | elif convert_to_numpy: 163 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 164 | 165 | if input_is_string: 166 | all_embeddings = all_embeddings[0] 167 | 168 | return all_embeddings 169 | 170 | def write_t2_corpus(model, output_dir): 171 | makedirs(output_dir) 172 | corpus = set() 173 | corpus_path = "PATH_TO_SAVED_CORPUS" 174 | with open(corpus_path, 
'r', encoding='utf-8') as f: 175 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 176 | for id, row in enumerate(reader): 177 | corpus.add(row['text'][:320]) 178 | 179 | corpus = list(corpus) 180 | 181 | corpus_psg_id_dict = {psg:id for id, psg in enumerate(corpus)} 182 | corpus_id_psg_dict = {id:psg for id, psg in enumerate(corpus)} 183 | 184 | corpus_psg_id_dict_path = "PATH_TO_SAVED_PSG_ID_DICT" 185 | corpus_id_psg_dict_path = "PATH_TO_SAVED_ID_PSG_DICT" 186 | corpus_rep_path = "PATH_TO_SAVED_REP" 187 | corpus_rep = model.encode(corpus, batch_size=1500, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True, max_seq_length=250).to('cpu') 188 | 189 | write_pickle(corpus_psg_id_dict, corpus_psg_id_dict_path) 190 | write_pickle(corpus_id_psg_dict, corpus_id_psg_dict_path) 191 | write_pickle(corpus_rep, corpus_rep_path) 192 | 193 | 194 | def write_t2_qry(model, corpus_psg_id_dict_path, corpus_id_psg_dict_path, corpus_rep_path, output_dir, K): 195 | res = defaultdict(list) 196 | queries = [] 197 | pos_sample_dict = defaultdict(list) 198 | corpus_psg_id_dict = load_pickle(corpus_psg_id_dict_path) 199 | corpus_id_psg_dict = load_pickle(corpus_id_psg_dict_path) 200 | corpus_rep = load_pickle(corpus_rep_path) 201 | query_path = f'DATA_PATH' 202 | data_all_path = f'ALL_DATA_PATH' 203 | 204 | with open(data_all_path, 'r', encoding='utf-8') as f: 205 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 206 | for id, row in enumerate(reader): 207 | text_a = row['sentence1'] 208 | text_b = row['sentence2'][:320] 209 | pos_sample_dict[text_a].append(text_b) 210 | 211 | with open(query_path, 'r', encoding='utf-8') as f: 212 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 213 | for id, row in enumerate(reader): 214 | text_a = row['sentence1'] 215 | queries.append(text_a) 216 | 217 | makedirs("QUERY_PATH") 218 | if not os.path.exists("QUERY_PKL_PATH"): 219 | queries_rep = model.encode(queries, batch_size=1500, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True, max_seq_length=100).to('cpu') 220 | write_pickle(queries_rep, "QUERY_PKL_PATH") 221 | queries_rep = load_pickle("QUERY_PKL_PATH") 222 | 223 | 224 | qry_chunk_size = 2000 225 | qry_num = queries_rep.size(0) 226 | corpus_num = corpus_rep.size(0) 227 | for start in trange(0, qry_num, qry_chunk_size, disable=False): 228 | end = min(start+qry_chunk_size, qry_num) 229 | qry_bch_rep = queries_rep[start:end, :] 230 | score_bch = cos_sim(qry_bch_rep, corpus_rep) 231 | _, ids = torch.topk(score_bch, min(K+1, score_bch.size(1)), dim=1, largest=True,sorted=True) 232 | ids = ids.tolist() 233 | for qry_id in range(start, end): 234 | id_from_zero = qry_id - start 235 | qry_text = queries[qry_id] 236 | pos_text_list = pos_sample_dict[qry_text] 237 | for sub_id in ids[id_from_zero][-100:]: 238 | hardneg_text = corpus_id_psg_dict[sub_id] 239 | if hardneg_text not in pos_text_list and hardneg_text not in res[qry_text]: 240 | res[qry_text].append(hardneg_text) 241 | 242 | res_path = "FINAL_RES_PATH" 243 | write_pickle(res, res_path) 244 | 245 | 246 | 247 | def main(): 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument('--dataset', default='', type=str) 250 | parser.add_argument('--data_path', default='', type=str) 251 | parser.add_argument('--output_dir', default='', type=str) 252 | parser.add_argument('--ratio', default=0.5, type=float) 253 | parser.add_argument('--K', default=100, type=int) 254 | parser.add_argument('--base_model_dir', default='', 
type=str) 255 | parser.add_argument('--max_seq_len', default=250, type=int, help='max sequence length') 256 | parser.add_argument('--seed', default=2023, type=int) 257 | args = parser.parse_args() 258 | set_seed(args.seed) 259 | args.output_corpus_path = os.path.join(args.data_path, 'corpus') 260 | makedirs(args.output_corpus_path) 261 | model = BaseBertModel(model_name_or_path=args.base_model_dir, 262 | alias=None, 263 | encoder_type = 'CLS', 264 | max_seq_length=args.max_seq_len) 265 | 266 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 267 | device = torch.device(device) 268 | model.plm_model.to(device) 269 | model.plm_model.eval() 270 | 271 | if not os.path.exists(os.path.join(f'{args.output_corpus_path}', 'corpus_rep.pkl')): 272 | if args.dataset == 'T2Ranking': 273 | write_t2_corpus(model, f'{args.output_corpus_path}') 274 | corpus_psg_id_dict_path = os.path.join(f'{args.output_corpus_path}', 'corpus_psg_id_dict.pkl') 275 | corpus_id_psg_dict_path = os.path.join(f'{args.output_corpus_path}', 'corpus_id_psg_dict.pkl') 276 | corpus_rep_path = os.path.join(f'{args.output_corpus_path}', 'corpus_rep.pkl') 277 | if args.dataset == 'T2Ranking': 278 | write_t2_qry(args.ratio, model, corpus_psg_id_dict_path, corpus_id_psg_dict_path, corpus_rep_path, args.output_dir, args.K) 279 | 280 | 281 | if __name__ == '__main__': 282 | main() -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright [2023] [Ant Group] 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | 13 | 14 | Apache License 15 | Version 2.0, January 2004 16 | http://www.apache.org/licenses/ 17 | 18 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 19 | 20 | 1. Definitions. 21 | 22 | "License" shall mean the terms and conditions for use, reproduction, 23 | and distribution as defined by Sections 1 through 9 of this document. 24 | 25 | "Licensor" shall mean the copyright owner or entity authorized by 26 | the copyright owner that is granting the License. 27 | 28 | "Legal Entity" shall mean the union of the acting entity and all 29 | other entities that control, are controlled by, or are under common 30 | control with that entity. For the purposes of this definition, 31 | "control" means (i) the power, direct or indirect, to cause the 32 | direction or management of such entity, whether by contract or 33 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 34 | outstanding shares, or (iii) beneficial ownership of such entity. 35 | 36 | "You" (or "Your") shall mean an individual or Legal Entity 37 | exercising permissions granted by this License. 38 | 39 | "Source" form shall mean the preferred form for making modifications, 40 | including but not limited to software source code, documentation 41 | source, and configuration files. 
42 | 43 | "Object" form shall mean any form resulting from mechanical 44 | transformation or translation of a Source form, including but 45 | not limited to compiled object code, generated documentation, 46 | and conversions to other media types. 47 | 48 | "Work" shall mean the work of authorship, whether in Source or 49 | Object form, made available under the License, as indicated by a 50 | copyright notice that is included in or attached to the work 51 | (an example is provided in the Appendix below). 52 | 53 | "Derivative Works" shall mean any work, whether in Source or Object 54 | form, that is based on (or derived from) the Work and for which the 55 | editorial revisions, annotations, elaborations, or other modifications 56 | represent, as a whole, an original work of authorship. For the purposes 57 | of this License, Derivative Works shall not include works that remain 58 | separable from, or merely link (or bind by name) to the interfaces of, 59 | the Work and Derivative Works thereof. 60 | 61 | "Contribution" shall mean any work of authorship, including 62 | the original version of the Work and any modifications or additions 63 | to that Work or Derivative Works thereof, that is intentionally 64 | submitted to Licensor for inclusion in the Work by the copyright owner 65 | or by an individual or Legal Entity authorized to submit on behalf of 66 | the copyright owner. For the purposes of this definition, "submitted" 67 | means any form of electronic, verbal, or written communication sent 68 | to the Licensor or its representatives, including but not limited to 69 | communication on electronic mailing lists, source code control systems, 70 | and issue tracking systems that are managed by, or on behalf of, the 71 | Licensor for the purpose of discussing and improving the Work, but 72 | excluding communication that is conspicuously marked or otherwise 73 | designated in writing by the copyright owner as "Not a Contribution." 74 | 75 | "Contributor" shall mean Licensor and any individual or Legal Entity 76 | on behalf of whom a Contribution has been received by Licensor and 77 | subsequently incorporated within the Work. 78 | 79 | 2. Grant of Copyright License. Subject to the terms and conditions of 80 | this License, each Contributor hereby grants to You a perpetual, 81 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 82 | copyright license to reproduce, prepare Derivative Works of, 83 | publicly display, publicly perform, sublicense, and distribute the 84 | Work and such Derivative Works in Source or Object form. 85 | 86 | 3. Grant of Patent License. Subject to the terms and conditions of 87 | this License, each Contributor hereby grants to You a perpetual, 88 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 89 | (except as stated in this section) patent license to make, have made, 90 | use, offer to sell, sell, import, and otherwise transfer the Work, 91 | where such license applies only to those patent claims licensable 92 | by such Contributor that are necessarily infringed by their 93 | Contribution(s) alone or by combination of their Contribution(s) 94 | with the Work to which such Contribution(s) was submitted. 
If You 95 | institute patent litigation against any entity (including a 96 | cross-claim or counterclaim in a lawsuit) alleging that the Work 97 | or a Contribution incorporated within the Work constitutes direct 98 | or contributory patent infringement, then any patent licenses 99 | granted to You under this License for that Work shall terminate 100 | as of the date such litigation is filed. 101 | 102 | 4. Redistribution. You may reproduce and distribute copies of the 103 | Work or Derivative Works thereof in any medium, with or without 104 | modifications, and in Source or Object form, provided that You 105 | meet the following conditions: 106 | 107 | (a) You must give any other recipients of the Work or 108 | Derivative Works a copy of this License; and 109 | 110 | (b) You must cause any modified files to carry prominent notices 111 | stating that You changed the files; and 112 | 113 | (c) You must retain, in the Source form of any Derivative Works 114 | that You distribute, all copyright, patent, trademark, and 115 | attribution notices from the Source form of the Work, 116 | excluding those notices that do not pertain to any part of 117 | the Derivative Works; and 118 | 119 | (d) If the Work includes a "NOTICE" text file as part of its 120 | distribution, then any Derivative Works that You distribute must 121 | include a readable copy of the attribution notices contained 122 | within such NOTICE file, excluding those notices that do not 123 | pertain to any part of the Derivative Works, in at least one 124 | of the following places: within a NOTICE text file distributed 125 | as part of the Derivative Works; within the Source form or 126 | documentation, if provided along with the Derivative Works; or, 127 | within a display generated by the Derivative Works, if and 128 | wherever such third-party notices normally appear. The contents 129 | of the NOTICE file are for informational purposes only and 130 | do not modify the License. You may add Your own attribution 131 | notices within Derivative Works that You distribute, alongside 132 | or as an addendum to the NOTICE text from the Work, provided 133 | that such additional attribution notices cannot be construed 134 | as modifying the License. 135 | 136 | You may add Your own copyright statement to Your modifications and 137 | may provide additional or different license terms and conditions 138 | for use, reproduction, or distribution of Your modifications, or 139 | for any such Derivative Works as a whole, provided Your use, 140 | reproduction, and distribution of the Work otherwise complies with 141 | the conditions stated in this License. 142 | 143 | 5. Submission of Contributions. Unless You explicitly state otherwise, 144 | any Contribution intentionally submitted for inclusion in the Work 145 | by You to the Licensor shall be under the terms and conditions of 146 | this License, without any additional terms or conditions. 147 | Notwithstanding the above, nothing herein shall supersede or modify 148 | the terms of any separate license agreement you may have executed 149 | with Licensor regarding such Contributions. 150 | 151 | 6. Trademarks. This License does not grant permission to use the trade 152 | names, trademarks, service marks, or product names of the Licensor, 153 | except as required for reasonable and customary use in describing the 154 | origin of the Work and reproducing the content of the NOTICE file. 155 | 156 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 157 | agreed to in writing, Licensor provides the Work (and each 158 | Contributor provides its Contributions) on an "AS IS" BASIS, 159 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 160 | implied, including, without limitation, any warranties or conditions 161 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 162 | PARTICULAR PURPOSE. You are solely responsible for determining the 163 | appropriateness of using or redistributing the Work and assume any 164 | risks associated with Your exercise of permissions under this License. 165 | 166 | 8. Limitation of Liability. In no event and under no legal theory, 167 | whether in tort (including negligence), contract, or otherwise, 168 | unless required by applicable law (such as deliberate and grossly 169 | negligent acts) or agreed to in writing, shall any Contributor be 170 | liable to You for damages, including any direct, indirect, special, 171 | incidental, or consequential damages of any character arising as a 172 | result of this License or out of the use or inability to use the 173 | Work (including but not limited to damages for loss of goodwill, 174 | work stoppage, computer failure or malfunction, or any and all 175 | other commercial damages or losses), even if such Contributor 176 | has been advised of the possibility of such damages. 177 | 178 | 9. Accepting Warranty or Additional Liability. While redistributing 179 | the Work or Derivative Works thereof, You may choose to offer, 180 | and charge a fee for, acceptance of support, warranty, indemnity, 181 | or other liability obligations and/or rights consistent with this 182 | License. However, in accepting such obligations, You may act only 183 | on Your own behalf and on Your sole responsibility, not on behalf 184 | of any other Contributor, and only if You agree to indemnify, 185 | defend, and hold each Contributor harmless for any liability 186 | incurred by, or claims asserted against, such Contributor by reason 187 | of your accepting any such warranty or additional liability. 188 | 189 | END OF TERMS AND CONDITIONS 190 | 191 | APPENDIX: How to apply the Apache License to your work. 192 | 193 | To apply the Apache License to your work, attach the following 194 | boilerplate notice, with the fields enclosed by brackets "[]" 195 | replaced with your own identifying information. (Don't include 196 | the brackets!) The text should be enclosed in the appropriate 197 | comment syntax for the file format. We also recommend that a 198 | file or class name and description of purpose be included on the 199 | same "printed page" as the copyright notice for easier 200 | identification within third-party archives. 201 | 202 | Copyright [yyyy] [name of copyright owner] 203 | 204 | Licensed under the Apache License, Version 2.0 (the "License"); 205 | you may not use this file except in compliance with the License. 206 | You may obtain a copy of the License at 207 | 208 | http://www.apache.org/licenses/LICENSE-2.0 209 | 210 | Unless required by applicable law or agreed to in writing, software 211 | distributed under the License is distributed on an "AS IS" BASIS, 212 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 213 | See the License for the specific language governing permissions and 214 | limitations under the License. 
-------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from datasets import load_dataset 8 | from transformers import PreTrainedTokenizer 9 | import csv 10 | from loguru import logger 11 | import random 12 | from utils.common_utils import load_pickle 13 | DATASET_ID_DICT = {'snli-zh':1,'sts':2,'t2-05':3,'du-10':4,'mmarco':5,'cmedqa':6} 14 | def load_text_dataset(name, pos_dir, neg_dir, file_path, neg_K, res_data, split): 15 | data = [] 16 | if split == 'train': 17 | hard_neg_house = load_pickle(neg_dir) 18 | pos_logis = load_pickle(pos_dir) 19 | with open(file_path, encoding='utf-8') as f: 20 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 21 | for id, row in enumerate(reader): 22 | text_a = row['sentence1'] 23 | text_b = row['sentence2'] 24 | score = row['gold_label'] 25 | if score == 'entailment': 26 | if split == 'train': 27 | if len(hard_neg_house[text_a]) < neg_K: 28 | num = math.ceil(neg_K / len(hard_neg_house[text_a])) 29 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K) 30 | else: 31 | negs_logits = random.sample(hard_neg_house[text_a], neg_K) 32 | hardnegs, hardneg_logits = zip(*negs_logits) 33 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits) 34 | elif split == 'validation': 35 | hardnegs = [] 36 | hardneg_logits = [] 37 | pos_logits = [] 38 | hardnegs = [sample[:100] for sample in hardnegs] 39 | data.append((text_a[:100], text_b[:100], pos_logits, hardnegs, hardneg_logits, 0)) 40 | if split == 'train': 41 | split_data = data[:-10000] 42 | sample_num = len(split_data) 43 | elif split == 'validation': 44 | split_data = data[-10000:] 45 | sample_num = len(split_data) 46 | res_data.extend(split_data) 47 | 48 | return res_data, sample_num 49 | 50 | 51 | def load_sts_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data): 52 | data = [] 53 | pos_logis = load_pickle(pos_dir) 54 | hard_neg_house = load_pickle(neg_dir) 55 | with open(file_path, encoding='utf-8') as f: 56 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 57 | for id, row in enumerate(reader): 58 | text_a = row['sentence1'] 59 | text_b = row['sentence2'] 60 | if len(hard_neg_house[text_a]) < neg_K: 61 | num = math.ceil(neg_K / len(hard_neg_house[text_a])) 62 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K) 63 | else: 64 | negs_logits = random.sample(hard_neg_house[text_a], neg_K) 65 | hardnegs, hardneg_logits = zip(*negs_logits) 66 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits) 67 | hardnegs = [sample[:100] for sample in hardnegs] 68 | data.append((text_a[:100], text_b[:100], pos_logits, hardnegs, hardneg_logits, 0)) 69 | 70 | sample_num = len(data) 71 | res_data.extend(data) 72 | 73 | return res_data, sample_num 74 | 75 | def load_sts_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data): 76 | data = [] 77 | with open(file_path, encoding='utf-8') as f: 78 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 79 | for id, row in enumerate(reader): 80 | text_a = row['sentence1'] 81 | text_b = row['sentence2'] 82 | data.append((text_a[:100], text_b[:100], [], [], [], 0)) 83 | 84 | sample_num = len(data) 85 | res_data.extend(data) 86 | 87 | return res_data, sample_num 88 | 89 | def load_t2_dataset_train(name, pos_dir, neg_dir, 
file_path, neg_K, res_data): 90 | data = [] 91 | pos_logis = load_pickle(pos_dir) 92 | hard_neg_house = load_pickle(neg_dir) 93 | with open(file_path, encoding='utf-8') as f: 94 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 95 | for id, row in enumerate(reader): 96 | text_a = row['sentence1'] 97 | text_b = row['sentence2'] 98 | if len(hard_neg_house[text_a]) < neg_K: 99 | num = math.ceil(neg_K / len(hard_neg_house[text_a])) 100 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K) 101 | else: 102 | negs_logits = random.sample(hard_neg_house[text_a], neg_K) 103 | hardnegs, hardneg_logits = zip(*negs_logits) 104 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits) 105 | hardnegs = [sample[:320] for sample in hardnegs] 106 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1)) 107 | 108 | sample_num = len(data) 109 | res_data.extend(data) 110 | 111 | return res_data, sample_num 112 | 113 | def load_t2_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data): 114 | data = [] 115 | with open(file_path, encoding='utf-8') as f: 116 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 117 | for id, row in enumerate(reader): 118 | text_a = row['sentence1'] 119 | text_b = row['sentence2'] 120 | data.append((text_a[:50], text_b[:320], [], [], [], 1)) 121 | 122 | sample_num = len(data) 123 | res_data.extend(data) 124 | 125 | return res_data, sample_num 126 | 127 | 128 | def load_du_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data): 129 | data = [] 130 | pos_logits = load_pickle(pos_dir) 131 | hard_neg_house = load_pickle(neg_dir) 132 | with open(file_path, encoding='utf-8') as f: 133 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 134 | for id, row in enumerate(reader): 135 | text_a = row['sentence1'] 136 | text_b = row['sentence2'] 137 | if len(hard_neg_house[text_a]) < neg_K: 138 | num = math.ceil(neg_K / len(hard_neg_house[text_a])) 139 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K) 140 | else: 141 | negs_logits = random.sample(hard_neg_house[text_a], neg_K) 142 | hardnegs, hardneg_logits = zip(*negs_logits) 143 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits) 144 | hardnegs = [sample[:320] for sample in hardnegs] 145 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1)) 146 | 147 | sample_num = len(data) 148 | res_data.extend(data) 149 | 150 | return res_data, sample_num 151 | 152 | def load_du_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data): 153 | data = [] 154 | with open(file_path, encoding='utf-8') as f: 155 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 156 | for id, row in enumerate(reader): 157 | text_a = row['sentence1'] 158 | text_b = row['sentence2'] 159 | data.append((text_a[:50], text_b[:320], [], [], [], 1)) 160 | 161 | sample_num = len(data) 162 | res_data.extend(data) 163 | 164 | return res_data, sample_num 165 | 166 | def load_mmarco_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data): 167 | data = [] 168 | pos_logis = load_pickle(pos_dir) 169 | hard_neg_house = load_pickle(neg_dir) 170 | with open(file_path, encoding='utf-8') as f: 171 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 172 | for id, row in enumerate(reader): 173 | text_a = row['sentence1'] 174 | text_b = row['sentence2'] 175 | if len(hard_neg_house[text_a]) < neg_K: 176 | num = math.ceil(neg_K / len(hard_neg_house[text_a])) 177 | negs_logits = 
random.sample(hard_neg_house[text_a] * num, neg_K) 178 | else: 179 | negs_logits = random.sample(hard_neg_house[text_a], neg_K) 180 | hardnegs, hardneg_logits = zip(*negs_logits) 181 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits) 182 | hardnegs = [sample[:320] for sample in hardnegs] 183 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1)) 184 | 185 | sample_num = len(data) 186 | res_data.extend(data) 187 | 188 | return res_data, sample_num 189 | 190 | def load_mmarco_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data): 191 | data = [] 192 | with open(file_path, encoding='utf-8') as f: 193 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 194 | for id, row in enumerate(reader): 195 | text_a = row['sentence1'] 196 | text_b = row['sentence2'] 197 | data.append((text_a[:50], text_b[:320], [], [], [], 1)) 198 | 199 | sample_num = len(data) 200 | res_data.extend(data) 201 | 202 | return res_data, sample_num 203 | 204 | def load_cmedqa_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data): 205 | data = [] 206 | pos_logis = load_pickle(pos_dir) 207 | hard_neg_house = load_pickle(neg_dir) 208 | with open(file_path, encoding='utf-8') as f: 209 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 210 | for id, row in enumerate(reader): 211 | text_a = row['sentence1'] 212 | text_b = row['sentence2'] 213 | if len(hard_neg_house[text_a]) < neg_K: 214 | num = math.ceil(neg_K / len(hard_neg_house[text_a])) 215 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K) 216 | else: 217 | negs_logits = random.sample(hard_neg_house[text_a], neg_K) 218 | hardnegs, hardneg_logits = zip(*negs_logits) 219 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits) 220 | hardnegs = [sample[:320] for sample in hardnegs] 221 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1)) 222 | 223 | sample_num = len(data) 224 | res_data.extend(data) 225 | 226 | return res_data, sample_num 227 | 228 | def load_cmedqa_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data): 229 | data = [] 230 | with open(file_path, encoding='utf-8') as f: 231 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE) 232 | for id, row in enumerate(reader): 233 | text_a = row['sentence1'] 234 | text_b = row['sentence2'] 235 | data.append((text_a[:50], text_b[:320], [], [], [], 1)) 236 | 237 | sample_num = len(data) 238 | res_data.extend(data) 239 | 240 | return res_data, sample_num 241 | 242 | 243 | def collate_fn(data): 244 | res_s_a = [] 245 | res_s_b = [] 246 | res_pos_logits = [] 247 | res_neg_K = [] 248 | res_neg_logits = [] 249 | res_task_id = [] 250 | 251 | for d in data[0]: 252 | res_s_a.append(d[0]) 253 | res_s_b.append(d[1]) 254 | res_pos_logits.append(d[2]) 255 | res_neg_K.append(d[3]) 256 | res_neg_logits.extend(d[4]) 257 | res_task_id.append(int(d[5])) 258 | 259 | res_neg_K = [list(group) for group in zip(*res_neg_K)] 260 | res_neg_K = [e for l in res_neg_K for e in l] 261 | 262 | 263 | return res_s_a, res_s_b, torch.FloatTensor(res_pos_logits), res_neg_K, torch.FloatTensor(res_neg_logits), torch.LongTensor(res_task_id) 264 | 265 | 266 | 267 | class TrainDataset(Dataset): 268 | 269 | def __init__(self, tokenizer, pos_dir, neg_dir, datadir, names=None, batch_size=32, neg_K=8, process_index=0, num_processes=1, seed=2023): 270 | self.dataset_id_dict = DATASET_ID_DICT 271 | self.tokenizer = tokenizer 272 | self.data = [] 273 | self.batch_size = batch_size 274 | 
267 | class TrainDataset(Dataset):
268 | 
269 |     def __init__(self, tokenizer, pos_dir, neg_dir, datadir, names=None, batch_size=32, neg_K=8, process_index=0, num_processes=1, seed=2023):
270 |         self.dataset_id_dict = DATASET_ID_DICT
271 |         self.tokenizer = tokenizer
272 |         self.data = []
273 |         self.batch_size = batch_size
274 |         self.sample_stas = dict()
275 |         self.dataset_indices_range = dict()
276 |         self.process_index = process_index
277 |         self.num_processes = num_processes
278 |         self.neg_K = neg_K
279 |         self.deterministic_generator = np.random.default_rng(seed)
280 |         names.sort(reverse=True)
281 |         for name in names:
282 |             if name in ['snli-zh']:
283 |                 if name == 'snli-zh':
284 |                     start_id = len(self.data)
285 |                     self.data, sample_num = load_text_dataset(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data, 'train')
286 |                     end_id = len(self.data)
287 |                     self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
288 |                     self.sample_stas[name] = sample_num
289 |             elif name in ['sts']:
290 |                 if name == 'sts':
291 |                     start_id = len(self.data)
292 |                     self.data, sample_num = load_sts_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), datadir, self.neg_K, self.data)
293 |                     end_id = len(self.data)
294 |                     self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
295 |                     self.sample_stas[name] = sample_num
296 |             elif name in ['t2', 'du', 'mmarco', 'cmedqa']:
297 |                 if name == 't2':
298 |                     start_id = len(self.data)
299 |                     self.data, sample_num = load_t2_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
300 |                     end_id = len(self.data)
301 |                     self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
302 |                     self.sample_stas[name] = sample_num
303 |                 if name == 'du':
304 |                     start_id = len(self.data)
305 |                     self.data, sample_num = load_du_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
306 |                     end_id = len(self.data)
307 |                     self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
308 |                     self.sample_stas[name] = sample_num
309 |                 if name == 'mmarco':
310 |                     start_id = len(self.data)
311 |                     self.data, sample_num = load_mmarco_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
312 |                     end_id = len(self.data)
313 |                     self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
314 |                     self.sample_stas[name] = sample_num
315 |                 if name == 'cmedqa':
316 |                     start_id = len(self.data)
317 |                     self.data, sample_num = load_cmedqa_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
318 |                     end_id = len(self.data)
319 |                     self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
320 |                     self.sample_stas[name] = sample_num
321 |             else:
322 |                 logger.debug('Unknown dataset: {}'.format(name))
323 | 
324 |         self.create_epoch()
325 | 
326 |     def __len__(self):
327 |         return self.steps_per_epoch * self.num_processes
328 | 
329 |     def create_epoch(self):
330 |         epoch = []
331 |         self.steps_per_epoch = 0
332 |         for k, v in self.dataset_indices_range.items():
333 |             dataset_range = np.arange(*v)
334 |             num_batches, remainer = divmod(len(dataset_range), self.batch_size * self.num_processes)
335 |             if remainer != 0:
336 |                 dataset_range = dataset_range[:num_batches * self.batch_size * self.num_processes]
337 |             self.deterministic_generator.shuffle(dataset_range)
338 |             batches = dataset_range.reshape(num_batches * self.num_processes,
self.batch_size).tolist() 339 | epoch.extend(batches) 340 | self.steps_per_epoch += num_batches 341 | self.deterministic_generator.shuffle(epoch) 342 | self.epoch = epoch 343 | self.step = 0 344 | 345 | 346 | def __getitem__(self, index: int): 347 | if self.step > (self.steps_per_epoch - 1): 348 | self.step = 0 349 | batch_indices = self.epoch[self.step*self.num_processes+self.process_index] 350 | batch_data = np.array(self.data)[batch_indices].tolist() 351 | self.step += 1 352 | 353 | return batch_data 354 | 355 | 356 | 357 | class ValDataset(Dataset): 358 | 359 | def __init__(self, tokenizer, pos_dir, neg_dir, datadir, names=None, batch_size=32, neg_K=8, process_index=0, num_processes=1, seed=2023): 360 | self.dataset_id_dict = DATASET_ID_DICT 361 | self.tokenizer = tokenizer 362 | self.data = [] 363 | self.batch_size = batch_size 364 | self.neg_K = neg_K 365 | self.sample_stas = dict() 366 | self.dataset_indices_range = dict() 367 | self.process_index = process_index 368 | self.num_processes = num_processes 369 | self.deterministic_generator = np.random.default_rng(seed) 370 | names.sort(reverse=True) 371 | for name in names: 372 | if name in ['snli-zh']: 373 | if name == 'snli-zh': 374 | start_id = len(self.data) 375 | self.data, sample_num = load_text_dataset(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data, 'validation') 376 | end_id = len(self.data) 377 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id) 378 | self.sample_stas[name] = sample_num 379 | elif name in ['sts']: 380 | if name == 'sts': 381 | start_id = len(self.data) 382 | self.data, sample_num = load_sts_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data) 383 | end_id = len(self.data) 384 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id) 385 | self.sample_stas[name] = sample_num 386 | elif name in ['t2', 'du', 'mmarco', 'cmedqa']: 387 | if name == 't2': 388 | start_id = len(self.data) 389 | self.data, sample_num = load_t2_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data) 390 | end_id = len(self.data) 391 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id) 392 | self.sample_stas[name] = sample_num 393 | if name == 'du': 394 | start_id = len(self.data) 395 | self.data, sample_num = load_du_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data) 396 | end_id = len(self.data) 397 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id) 398 | self.sample_stas[name] = sample_num 399 | if name == 'mmarco': 400 | start_id = len(self.data) 401 | self.data, sample_num = load_mmarco_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data) 402 | end_id = len(self.data) 403 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id) 404 | self.sample_stas[name] = sample_num 405 | if name == 'cmedqa': 406 | start_id = len(self.data) 407 | self.data, sample_num = load_cmedqa_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), 
self.neg_K, self.data) 408 | end_id = len(self.data) 409 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id) 410 | self.sample_stas[name] = sample_num 411 | else: 412 | logger.debug('Unknown dataset: {}'.format(name)) 413 | self.create_epoch() 414 | 415 | 416 | def __len__(self): 417 | return self.steps_per_epoch * self.num_processes 418 | 419 | def create_epoch(self): 420 | epoch = [] 421 | self.steps_per_epoch = 0 422 | for k, v in self.dataset_indices_range.items(): 423 | dataset_range = np.arange(*v) 424 | num_batches, remainer = divmod(len(dataset_range), self.batch_size * self.num_processes) 425 | if remainer != 0: 426 | dataset_range = dataset_range[:num_batches * self.batch_size * self.num_processes] 427 | self.deterministic_generator.shuffle(dataset_range) 428 | batches = dataset_range.reshape(num_batches * self.num_processes, self.batch_size).tolist() 429 | epoch.extend(batches) 430 | self.steps_per_epoch += num_batches 431 | self.deterministic_generator.shuffle(epoch) 432 | self.epoch = epoch 433 | self.step = 0 434 | 435 | 436 | def __getitem__(self, index: int): 437 | 438 | if self.step > self.steps_per_epoch - 1: 439 | self.step = 0 440 | batch_indices = self.epoch[self.step*self.num_processes+self.process_index] 441 | batch_data = np.array(self.data)[batch_indices].tolist() 442 | self.step += 1 443 | return batch_data 444 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import warnings 4 | import json 5 | import logging 6 | import argparse 7 | import random 8 | import time 9 | import tracemalloc 10 | from collections import defaultdict 11 | from copy import deepcopy 12 | import deepspeed 13 | import transformers 14 | import torch 15 | import torch.nn as nn 16 | import torch.distributed as dist 17 | from transformers import AutoTokenizer 18 | from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType 19 | from torch.utils.data import DataLoader, Dataset, RandomSampler 20 | from torch.utils.data.distributed import DistributedSampler 21 | from tqdm import tqdm, trange 22 | from torch.utils.tensorboard import SummaryWriter 23 | from dataset.dataset import * 24 | from model.pro_model import * 25 | from utils.common_utils import * 26 | logging.getLogger().setLevel(logging.INFO) 27 | warnings.filterwarnings('ignore') 28 | 29 | 30 | def cal_loss_in_batch(args, student_logits, temperature, criterion): 31 | 32 | bs = student_logits.size(0) 33 | logits = student_logits/temperature 34 | labels = torch.arange(bs, device=logits.device) 35 | loss_bs = criterion(logits, labels) 36 | 37 | return (loss_bs.sum())/ (bs * bs) 38 | 39 | 40 | def cal_loss_hardneg(args, teacher_logits, student_logits, temperature_teacher, temperature, nll_criterion): 41 | 42 | loss_hardneg_weight = args.alpha 43 | 44 | def softmax(X, temp): 45 | X = (X/temp).exp() 46 | res = X / (X.sum(-1, keepdims=True)+1e-20) 47 | return res 48 | 49 | bs = teacher_logits.size(0) 50 | neg_K = teacher_logits.size(1)-1 51 | teacher_logits = softmax(teacher_logits, temperature_teacher)[:,:, 0] 52 | teacher_logits[:, 1:] = 1 - teacher_logits[:, 1:] 53 | inputs = (softmax(student_logits*teacher_logits, temperature)).log() 54 | labels = torch.zeros(bs, dtype=torch.long, device=student_logits.device) 55 | loss_bs = nll_criterion(inputs, labels) 56 | 57 | 58 | loss_bs = loss_bs * loss_hardneg_weight 59 | return loss_bs.sum() / 
(bs * neg_K)
60 | 
61 | 
62 | def cal_loss_rd(args, teacher_logits, student_logits, teacher_temperature):
63 | 
64 |     loss_pearson_weight = args.beta
65 | 
66 |     def softmax(X, temp):
67 |         X = (X/temp).exp()
68 |         res = X / (X.sum(-1, keepdims=True)+1e-20)
69 |         return res
70 | 
71 |     def pearsonr(x, y, batch_first=True):
72 |         assert x.shape == y.shape
73 |         if batch_first:
74 |             dim = -1
75 |         else:
76 |             dim = 0
77 |         assert x.shape[dim] > 1
78 |         centered_x = x - x.mean(dim=dim, keepdim=True)
79 |         centered_y = y - y.mean(dim=dim, keepdim=True)
80 |         covariance = (centered_x * centered_y).sum(dim=dim, keepdim=True)
81 |         bessel_corrected_covariance = covariance / (x.shape[dim] - 1)
82 |         x_std = x.std(dim=dim, keepdim=True)
83 |         y_std = y.std(dim=dim, keepdim=True)
84 |         corr = bessel_corrected_covariance / ((x_std * y_std)+1e-8)
85 |         return corr
86 | 
87 | 
88 | 
89 |     bs = student_logits.size(0)
90 |     teacher_logits = softmax(teacher_logits, teacher_temperature)[:, :, 0]
91 |     spearson = pearsonr(student_logits, teacher_logits).squeeze()
92 | 
93 |     loss_bs = 1 - spearson
94 | 
95 |     loss_bs = loss_bs * loss_pearson_weight
96 | 
97 |     return loss_bs.sum() / bs
98 | 
99 | 
100 | 
101 | def cal_loss_rd2(args, teacher_logits_pos_hardneg, teacher_logits_pos_inbatch, teacher_temperature, student_logits_pos_hardneg, student_logits_pos_inbatch, sigmoid, scale_param):
102 | 
103 |     loss_bpr_weight = args.gamma
104 | 
105 |     def softmax(X, temp):
106 |         X = (X/temp).exp()
107 |         res = X / (X.sum(-1, keepdims=True)+1e-20)
108 |         return res
109 | 
110 | 
111 |     teacher_logits_pos_hardneg = softmax(teacher_logits_pos_hardneg, teacher_temperature)[:, :, 0]
112 |     teacher_logits_pos_inbatch = softmax(teacher_logits_pos_inbatch, teacher_temperature)[:, :, 0]
113 | 
114 |     bs = student_logits_pos_hardneg.size(0)
115 |     neg_K = student_logits_pos_hardneg.size(1)-1
116 |     inbatch = student_logits_pos_inbatch.size(1)-1
117 |     student_logits_hardneg = student_logits_pos_hardneg[:, 1:]
118 |     eye = torch.eye(bs, dtype=torch.bool)
119 |     student_logits_inbatch = student_logits_pos_inbatch[~eye].reshape(bs, -1)
120 |     loss_hardneg_inbatch = -((sigmoid(student_logits_hardneg.view(bs, neg_K, 1).expand(-1, -1, inbatch).reshape(bs, -1) - student_logits_inbatch.unsqueeze(1).expand(-1, neg_K, -1).reshape(bs, -1))+1e-8).log())
121 |     weight_hardneg_inbatch = teacher_logits_pos_hardneg[:, 1:].repeat_interleave(inbatch, dim=1) - teacher_logits_pos_inbatch[~eye].reshape(bs, -1).repeat((1, neg_K))  # teacher tensors sliced the same way as the student logits above
122 |     weight_hardneg_inbatch = torch.clamp(weight_hardneg_inbatch, min=0) / scale_param
123 |     loss_bs = (loss_hardneg_inbatch * weight_hardneg_inbatch).sum(-1)
124 |     loss_bs = loss_bs * loss_bpr_weight
125 | 
126 |     return loss_bs.sum() / (bs * neg_K * inbatch)
127 | 
128 | 
129 | def cal_feat_loss(args, teacher_feat_cos, student_feature_pos_hardneg):
130 | 
131 |     loss_feat_weight = args.eta
132 |     neg_K = teacher_feat_cos.size(1)
133 |     student_feature_pos_hardneg = student_feature_pos_hardneg.transpose(0, 1)
134 |     student_feature_pos_hardneg = student_feature_pos_hardneg / student_feature_pos_hardneg.norm(dim=-1, keepdim=True)
135 |     student_feat_cos = torch.matmul(student_feature_pos_hardneg, student_feature_pos_hardneg.transpose(-2, -1))
136 |     loss_bs = ((teacher_feat_cos - student_feat_cos) ** 2).sum((-1, -2))
137 | 
138 |     loss_bs = loss_bs * loss_feat_weight
139 | 
140 |     return loss_bs.sum() / (neg_K * neg_K)
141 | 
142 | 
143 | def str2bool(v):
144 |     return v.lower() in ('yes', 'true', 't', '1')
145 | 
146 | def main():
147 | 
148 |     parser = argparse.ArgumentParser()
149 |     parser.add_argument('--base_model_dir',
default='/mnt/user/415350/download_models/Qwen-7B-Chat', type=str, help='Model directory') 150 | parser.add_argument('--train_data_list', nargs='+') 151 | parser.add_argument('--pos_dir', default='PATH_TO_POS_LOGITS', type=str) 152 | parser.add_argument('--neg_dir', default='PATH_TO_HARDNEG_LOGITS', type=str) 153 | parser.add_argument('--data_dir', default='', type=str) 154 | parser.add_argument('--inbatch_pkl_path_dir', default='PATH_TO_INBATCH_LOGITS_PKL') 155 | parser.add_argument('--feature_pkl_path_dir', default='PATH_TO_FEATURE_PKL') 156 | parser.add_argument('--batch_size', default=32, type=int, help='bs') 157 | parser.add_argument('--neg_K', default=8, type=int, help='num of hard negs') 158 | parser.add_argument('--num_heads', default=32, type=int, help='num_heads of pma') 159 | parser.add_argument('--hidden_dim', default=512, type=int, help='hidden dim of my mlp') 160 | parser.add_argument('--output_dim', default=1, type=int, help='output dim of my mlp') 161 | parser.add_argument('--ln', default=True, type=str2bool, help='layer norm for pma') 162 | parser.add_argument('--norm', default=False, type=str2bool, help='norm after sentence pooling') 163 | parser.add_argument('--num_epochs', default=5, type=int, help='training epochs') 164 | parser.add_argument('--padding_side', default='right', type=str, help='padding side') 165 | parser.add_argument('--max_seq_length', default=250, type=int, help='max_seq_len') 166 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 167 | parser.add_argument('--alpha', default=1, type=float, help='trade-off param') 168 | parser.add_argument('--beta', default=1, type=float, help='trade-off param') 169 | parser.add_argument('--gamma', default=0.01, type=float, help='trade-off param') 170 | parser.add_argument('--eta', default=0.001, type=float, help='trade-off param') 171 | parser.add_argument('--temperature_in_batch', default=1, type=float, help='temperature in in-batch') 172 | parser.add_argument('--temperature_hardneg', default=1, type=float, help='temperature in hardneg') 173 | parser.add_argument('--temperature_teacher_hardneg', default=1, type=float, help='temperature in teacher logits') 174 | parser.add_argument('--scale_param', default=1, type=float, help='scale param') 175 | parser.add_argument('--log_interval', default=20, type=int) 176 | parser.add_argument('--eval_interval', default=200, type=int) 177 | parser.add_argument('--tb_dir', default='PATH_TO_TENSORBOARD_PATH', type=str) 178 | parser.add_argument('--patience', default=5, type=int) 179 | parser.add_argument('--num_ckpt', default=5, type=int) 180 | parser.add_argument('--training_log', default='PATH_TO_TRAINING_LOG') 181 | parser.add_argument('--output_dir', default='PATH_TO_OUTPUT_MODEL', type=str, help='Model output directory') 182 | parser.add_argument('--weight_decay', default=0.01, type=float, help='weight decay') 183 | parser.add_argument('--gradient_clipping', default=1.0, type=float, help='max_grad_norm') 184 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='gradient accumulation steps') 185 | parser.add_argument('--seed', default=2023, type=int) 186 | parser.add_argument('--bf16', default=True, type=str2bool) 187 | parser.add_argument('--verbose', default=True, type=str2bool) 188 | parser.add_argument('--device', default='cuda', type=str) 189 | parser.add_argument('--local_rank', type=int, default=-1, help='ds') 190 | parser.add_argument('--global_rank', type=int, default=-1, help='ds') 191 | parser = 
deepspeed.add_config_arguments(parser)
192 |     args = parser.parse_args()
193 |     args.world_size = int(os.getenv('WORLD_SIZE', '0'))
194 | 
195 |     sigmoid = nn.Sigmoid()
196 |     tanh = nn.Tanh()
197 | 
198 |     os.makedirs(args.output_dir, exist_ok=True)
199 |     logging.basicConfig(filename=f'{args.training_log}')
200 | 
201 |     if args.seed is not None:
202 |         set_seed(args.seed)
203 |         transformers.set_seed(args.seed)
204 | 
205 |     micro_bs = args.batch_size
206 | 
207 |     model = Mymodel(model_name_or_path=args.base_model_dir,
208 |                     alias=None,
209 |                     max_seq_length=args.max_seq_length,
210 |                     args=args)
211 |     model.plm_model.gradient_checkpointing_enable()
212 | 
213 |     summary_writer = SummaryWriter(log_dir=args.tb_dir)
214 | 
215 |     train_data_flag = False
216 |     lora_config = LoraConfig(
217 |         r=8,
218 |         lora_alpha=8,
219 |         target_modules=['c_attn', 'c_proj', 'w1', 'w2'],
220 |         layers_to_transform=list(range(0, 32)),
221 |         lora_dropout=0.05,
222 |         bias="none",
223 |         inference_mode=False,
224 |         task_type=TaskType.CAUSAL_LM
225 |     )
226 |     model.plm_model = get_peft_model(model.plm_model, lora_config)
227 | 
228 |     update_parameters = filter(lambda p: p.requires_grad, model.parameters())
229 |     param_optimizer = list([(n, p) for n, p in model.named_parameters() if p.requires_grad])
230 | 
231 |     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
232 |     optimizer_grouped_parameters = [
233 |         {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
234 |          'lr': args.lr, 'weight_decay': args.weight_decay, 'betas': [0.8, 0.999], 'eps': 1e-6, 'name': 'd'},
235 |         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
236 |          'lr': args.lr, 'weight_decay': 0.0, 'betas': [0.8, 0.999], 'eps': 1e-6, 'name': 'nd'}]
237 | 
238 |     ds_config = {
239 |         "bfloat16": {
240 |             "enabled": args.bf16
241 |         },
242 |         "zero_optimization": {
243 |             "stage": 2,
244 |             "offload_optimizer": {
245 |                 "device": "cpu",
246 |                 "pin_memory": True
247 |             },
248 |             "allgather_partitions": True,
249 |             "allgather_bucket_size": 2e8,
250 |             "overlap_comm": True,
251 |             "reduce_scatter": True,
252 |             "reduce_bucket_size": 2e8,
253 |             "contiguous_gradients": True
254 |         },
255 |         "gradient_accumulation_steps": args.gradient_accumulation_steps,
256 |         "gradient_clipping": args.gradient_clipping,
257 |         "train_batch_size": args.world_size,
258 |         "train_micro_batch_size_per_gpu": 1,
259 |         "steps_per_print": 1e5
260 |     }
261 | 
262 |     fake_bs = ds_config['train_micro_batch_size_per_gpu']
263 |     optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(optimizer_grouped_parameters)
264 |     scheduler = deepspeed.runtime.lr_schedules.WarmupLR(optimizer, warmup_min_lr=[0, 0], warmup_max_lr=[args.lr, args.lr],
265 |                                                         warmup_num_steps=1000)
266 | 
267 |     model_engine, optimizer, _, scheduler = deepspeed.initialize(args=args, model=model, model_parameters=update_parameters, optimizer=optimizer, lr_scheduler=scheduler, config=ds_config)
268 |     device = torch.device(args.local_rank)
269 |     args.device = device
270 |     args.global_rank = torch.distributed.get_rank()
271 | 
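    # -----------------------------------------------------------------------
    # Note on batching (illustrative): batching is done inside the datasets
    # rather than by the DataLoader. Each TrainDataset.__getitem__ call returns
    # a full per-rank batch of `batch_size` samples, so the DataLoader below
    # runs with batch_size = fake_bs = 1 and DeepSpeed's train_batch_size is
    # simply world_size. For example, with world_size = 8, --batch_size 32 and
    # the default gradient_accumulation_steps of 1, every optimizer step still
    # consumes 8 * 32 = 256 sentence pairs even though DeepSpeed sees a
    # nominal batch size of 8.
    # -----------------------------------------------------------------------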
272 |     train_dataset = TrainDataset(model.tokenizer, pos_dir=args.pos_dir, neg_dir=args.neg_dir, datadir=args.data_dir, names=args.train_data_list, batch_size=micro_bs, neg_K=args.neg_K, process_index=args.global_rank, num_processes=args.world_size)
273 |     val_dataset = ValDataset(model.tokenizer, pos_dir=args.pos_dir, neg_dir=args.neg_dir, datadir=args.data_dir, names=args.train_data_list, batch_size=micro_bs, neg_K=args.neg_K, process_index=args.global_rank, num_processes=args.world_size)
274 | 
275 |     if args.global_rank == -1:
276 |         train_sampler = RandomSampler(train_dataset)
277 |         val_sampler = RandomSampler(val_dataset)
278 |     else:
279 |         train_sampler = DistributedSampler(train_dataset)
280 |         val_sampler = DistributedSampler(val_dataset)
281 |     train_dataloader = DataLoader(train_dataset, batch_size=fake_bs, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=0)
282 |     val_dataloader = DataLoader(val_dataset, batch_size=fake_bs, shuffle=False, sampler=val_sampler, collate_fn=collate_fn, num_workers=0)
283 |     if len(train_dataset) > 0:
284 |         train_data_flag = True
285 | 
286 |     if not train_data_flag:
287 |         raise ValueError("No training data was loaded; check --train_data_list and --data_dir")
288 | 
289 |     all_dataset_id = train_dataset.dataset_id_dict
290 |     all_dataset_id_reverse = {v: k for k, v in train_dataset.dataset_id_dict.items()}
291 |     rel_dataset_id = [all_dataset_id[dataset_name] for dataset_name in args.train_data_list]
292 |     os.makedirs(args.output_dir, exist_ok=True)
293 | 
294 |     train_loader_size = len(train_dataloader)
295 |     val_loader_size = len(val_dataloader)
296 | 
297 |     criterion = nn.CrossEntropyLoss(reduction='none')
298 |     nll_criterion = nn.NLLLoss(reduction='none')
299 | 
300 |     global_step = 0
301 |     best_eval_metric = 0
302 |     trained_epochs = 0
303 |     min_reduce_loss_eval = float('inf')
304 |     best_epoch = 0
305 |     stop = 0
306 | 
307 |     teacher_feature_cos_dict = load_pickle(args.feature_pkl_path_dir)
308 |     teacher_inbatch = load_pickle(args.inbatch_pkl_path_dir)
309 | 
310 |     reduce_loss = 0
311 |     reduce_loss_eval = 0
312 |     reduce_loss_in_batch = 0
313 |     reduce_loss_in_batch_eval = 0
314 |     reduce_loss_hardneg = 0
315 |     reduce_loss_rd = 0
316 |     reduce_loss_rd2 = 0
317 |     reduce_loss_feat = 0
318 |     reduce_inbatch_sample_num = {}
319 | 
320 | 
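    # -----------------------------------------------------------------------
    # Assumed tensor shapes for the loss terms in the loop below, inferred
    # from how the tensors are sliced there (bs = micro batch size,
    # K = args.neg_K):
    #   logits_student_in_batch : (bs, bs)        query vs. in-batch passages
    #   logits_student_hardneg  : (bs, 1 + K)     positive followed by K hard negatives
    #   logits_teacher_pos      : (bs, 2)         teacher yes/no logits for the positive
    #   logits_teacher_hardneg  : (bs, 1 + K, 2)  after concatenation with the positive
    #   teacher in-batch / feature pickles: one entry per rank ('global_rank{r}')
    #   and per step, presumably produced by the scripts under preprocess/.
    # -----------------------------------------------------------------------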
321 |     for current_epoch in trange(int(args.num_epochs), desc="Epoch", disable=(args.global_rank != 0), mininterval=0):
322 |         if stop >= args.patience:
323 |             logging.info(f'Early Stop at {current_epoch+1}-th epoch {global_step}-th step')
324 |             logging.info(f'Model trained!\nThe best model at {best_epoch+1}-th epoch {best_step}-th step')
325 |             break
326 |         torch.cuda.empty_cache()
327 |         model_engine.train()
328 | 
329 |         loss_epoch_eval = 0
330 | 
331 |         batch_iterator = tqdm(train_dataloader,
332 |                               desc=f"Running Epoch {current_epoch + 1} of {args.num_epochs}",
333 |                               disable=(args.global_rank != 0),
334 |                               mininterval=0)
335 |         for step, batch in enumerate(batch_iterator):
336 |             sentence_a, sentence_b, logits_teacher_pos, sentence_hardneg, logits_teacher_hardneg, task_id = batch
337 |             sentence_all = sentence_a + sentence_b + sentence_hardneg
338 |             bs = logits_teacher_pos.size(0)
339 |             key = 'global_rank' + str(args.global_rank)
340 |             logits_teacher_inbatch = teacher_inbatch[key][step].to(device)
341 |             feature_teacher_cos = teacher_feature_cos_dict[key][step].to(device)
342 | 
343 |             inputs_all = model.tokenizer(sentence_all, padding='max_length', max_length=args.max_seq_length, truncation=True, return_tensors='pt')
344 |             inputs_all = inputs_all.to(device)
345 |             task_id = task_id.to(device)
346 |             logits_student_in_batch, logits_student_hardneg, rep_student_pos_hardneg = model_engine(inputs_all, task_id, 'train')
347 | 
348 |             loss_in_batch = cal_loss_in_batch(args, logits_student_in_batch, args.temperature_in_batch, criterion)
349 |             logits_teacher_pos = logits_teacher_pos.to(args.device)
350 |             logits_teacher_hardneg = logits_teacher_hardneg.reshape(micro_bs, args.neg_K, 2).to(args.device)
351 |             logits_teacher_hardneg = torch.cat([logits_teacher_pos.unsqueeze(1), logits_teacher_hardneg], dim=1)
352 |             loss_hardneg = cal_loss_hardneg(args, logits_teacher_hardneg, logits_student_hardneg, args.temperature_teacher_hardneg, args.temperature_hardneg, nll_criterion)
353 | 
354 |             loss_rd = cal_loss_rd(args, logits_teacher_hardneg, logits_student_hardneg, args.temperature_teacher_hardneg)
355 | 
356 |             loss_rd2 = cal_loss_rd2(args, logits_teacher_hardneg, logits_teacher_inbatch, args.temperature_teacher_hardneg, logits_student_hardneg, logits_student_in_batch, sigmoid, args.scale_param)
357 | 
358 |             loss_feat = cal_feat_loss(args, feature_teacher_cos, rep_student_pos_hardneg)
359 | 
360 |             loss_batch = loss_in_batch + loss_hardneg + loss_rd + loss_rd2 + loss_feat
361 |             if args.verbose:
362 |                 batch_iterator.set_description(
363 |                     f"Epoch: {current_epoch + 1}/{args.num_epochs}, Batch:{step}/{len(train_dataloader)}, Loss: {loss_batch:9.4f}")
364 | 
365 |             model_engine.backward(loss_batch)
366 |             model_engine.step()
367 | 
368 |             if (step + 1) % args.gradient_accumulation_steps == 0:
369 |                 global_step += 1
370 | 
371 |             reduce_loss += loss_batch.detach()
372 |             reduce_loss_in_batch += loss_in_batch.detach()
373 |             reduce_loss_hardneg += loss_hardneg.detach()
374 |             reduce_loss_rd += loss_rd.detach()
375 |             reduce_loss_rd2 += loss_rd2.detach()
376 |             reduce_loss_feat += loss_feat.detach()
377 | 
378 |             if global_step % args.log_interval == 0:
379 |                 dist.all_reduce(reduce_loss, op=dist.ReduceOp.SUM)
380 |                 dist.all_reduce(reduce_loss_in_batch, op=dist.ReduceOp.SUM)
381 |                 dist.all_reduce(reduce_loss_hardneg, op=dist.ReduceOp.SUM)
382 |                 dist.all_reduce(reduce_loss_rd, op=dist.ReduceOp.SUM)
383 |                 dist.all_reduce(reduce_loss_rd2, op=dist.ReduceOp.SUM)
384 |                 dist.all_reduce(reduce_loss_feat, op=dist.ReduceOp.SUM)
385 | 
386 |                 reduce_loss = reduce_loss.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
387 |                 reduce_loss_in_batch = reduce_loss_in_batch.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
388 |                 reduce_loss_hardneg = reduce_loss_hardneg.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
389 |                 reduce_loss_rd = reduce_loss_rd.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
390 |                 reduce_loss_rd2 = reduce_loss_rd2.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
391 |                 reduce_loss_feat = reduce_loss_feat.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
392 | 
393 |                 if args.global_rank == 0:
394 |                     train_log_dict = {}
395 |                     train_log_dict['loss_overall'] = reduce_loss
396 |                     train_log_dict['loss_inbatch'] = reduce_loss_in_batch
397 |                     train_log_dict['loss_hardneg'] = reduce_loss_hardneg
398 |                     train_log_dict['loss_rd'] = reduce_loss_rd
399 |                     train_log_dict['loss_rd2'] = reduce_loss_rd2
400 |                     train_log_dict['loss_feat'] = reduce_loss_feat
401 |                     write_tensorboard(summary_writer, train_log_dict, global_step)
402 | 
403 |                 reduce_loss = 0
404 |                 reduce_loss_hardneg = 0
405 |                 reduce_loss_rd = 0
406 |                 reduce_loss_rd2 = 0
407 |                 reduce_loss_feat = 0
408 |                 reduce_loss_in_batch = 0
409 | 
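            # ---------------------------------------------------------------
            # Note (illustrative): the trade-off weights alpha, beta, gamma and
            # eta are applied inside cal_loss_hardneg, cal_loss_rd, cal_loss_rd2
            # and cal_feat_loss respectively, so the sum above is already the
            # weighted training objective. The evaluation block below computes
            # only the in-batch contrastive loss: the validation loaders in
            # dataset/dataset.py emit empty placeholders instead of teacher
            # logits, so the distillation terms are not defined at validation
            # time.
            # ---------------------------------------------------------------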
410 |             if global_step % args.eval_interval == 0:
411 |                 model_engine.eval()
412 |                 batch_iterator_eval = tqdm(val_dataloader,
413 |                                            disable=(args.global_rank != 0),
414 |                                            mininterval=0)
415 | 
416 |                 with torch.no_grad():
417 |                     for step, batch in enumerate(batch_iterator_eval):
418 |                         sentence_a, sentence_b, _, _, _, task_id = batch
419 |                         sentence_all = sentence_a + sentence_b
420 |                         bs = task_id.size(0)
421 | 
422 |                         key = 'global_rank' + str(args.global_rank)
423 | 
424 |                         inputs_all = model.tokenizer(sentence_all, padding='max_length', max_length=args.max_seq_length, truncation=True, return_tensors='pt')
425 | 
426 |                         inputs_all = inputs_all.to(device)
427 |                         task_id = task_id.to(device)
428 |                         logits_student_in_batch_eval, _, _ = model_engine(inputs_all, task_id, 'eval')
429 | 
430 |                         loss_in_batch_dict_eval = cal_loss_in_batch(args, logits_student_in_batch_eval, args.temperature_in_batch, criterion)
431 | 
432 |                         loss_batch_eval = loss_in_batch_dict_eval.detach()
433 |                         if args.verbose:
434 |                             batch_iterator_eval.set_description(
435 |                                 f"Epoch: {current_epoch + 1}/{args.num_epochs}, Batch:{step}/{len(val_dataloader)}, Loss: {loss_batch_eval:9.4f}")
436 | 
437 | 
438 |                         reduce_loss_eval += loss_batch_eval
439 | 
440 |                 dist.all_reduce(reduce_loss_eval, op=dist.ReduceOp.SUM)
441 |                 reduce_loss_eval = reduce_loss_eval.item() / (val_loader_size * args.world_size)
442 | 
443 |                 if args.global_rank == 0:
444 |                     eval_log_dict = {'loss_eval': reduce_loss_eval}
445 |                     write_tensorboard(summary_writer, eval_log_dict, global_step)
446 | 
447 |                 save_flag = False
448 | 
449 |                 if stop >= args.patience:
450 |                     break
451 | 
452 |                 if reduce_loss_eval <= min_reduce_loss_eval:
453 |                     min_reduce_loss_eval = reduce_loss_eval
454 |                     best_epoch = current_epoch
455 |                     best_step = global_step
456 |                     stop = 0
457 | 
458 |                     path = args.output_dir
459 |                     start_name = 'checkpoint'
460 |                     current_step_num = global_step
461 |                     max_save_num = 2
462 |                     if args.global_rank == 0:
463 |                         print('removing')
464 |                         try:
465 |                             remove_earlier_ckpt(path, start_name, current_step_num, max_save_num)
466 |                         except:
467 |                             print('No ckpt to remove.')
468 |                 else:
469 |                     stop += 1
470 | 
471 |                 if stop < args.num_ckpt:
472 |                     save_flag = True
473 | 
474 | 
475 |                 if save_flag:
476 |                     output_dir_current = os.path.join(args.output_dir, "checkpoint-{}-epoch-{}".format(global_step, current_epoch+1))
477 |                     client_sd = dict()
478 | 
479 |                     save_model(model_engine, output_dir_current, client_state=client_sd)
480 | 
481 |                 reduce_loss_eval = 0
482 |                 model_engine.train()
483 | 
484 | 
485 | if __name__ == '__main__':
486 |     main()
487 | 
--------------------------------------------------------------------------------
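A note on the teacher-logit convention assumed by the distillation losses in `train.py`: each teacher score is a pair of logits (index 0 appearing to correspond to the teacher's "yes"/match class), and `cal_loss_hardneg`, `cal_loss_rd` and `cal_loss_rd2` first soften them with a temperature softmax and keep only that first probability. A minimal, self-contained sketch of the transform (shapes and the helper name are illustrative, not taken from the repository):

```python
import torch

def teacher_match_prob(teacher_logits, temp=1.0):
    # teacher_logits: (bs, 1 + K, 2) logits for the positive followed by K hard negatives
    return torch.softmax(teacher_logits / temp, dim=-1)[:, :, 0]  # (bs, 1 + K)

bs, K = 2, 3
logits = torch.randn(bs, 1 + K, 2)
p = teacher_match_prob(logits)
print(p.shape)  # torch.Size([2, 4])
# cal_loss_hardneg additionally flips the negatives (p[:, 1:] = 1 - p[:, 1:]),
# so a confident teacher "no" on a hard negative also yields a large weight.
```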