├── model
│   ├── __init__.py
│   └── pro_model.py
├── dataset
│   ├── __init__.py
│   └── dataset.py
├── img
│   └── Figure_1.png
├── LEGAL.md
├── utils
│   └── common_utils.py
├── README.md
├── train.sh
├── preprocess
│   ├── save_logits.py
│   ├── save_hardneg_bm25.py
│   └── save_hardnrg_bi.py
├── LICENSE.md
└── train.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/img/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/D2LLM/main/img/Figure_1.png
--------------------------------------------------------------------------------
/LEGAL.md:
--------------------------------------------------------------------------------
1 | Legal Disclaimer
2 |
3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comments in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
4 |
5 | 法律免责声明
6 |
7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
--------------------------------------------------------------------------------
/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import pathlib
4 | import numpy as np
5 | from scipy.stats import pearsonr, spearmanr
6 | import torch
7 | from loguru import logger
8 | import shutil
9 | from torch.utils.tensorboard import SummaryWriter
10 | import pickle
11 | import linecache
12 | import tracemalloc
13 |
14 | def set_seed(seed):
15 | random.seed(seed)
16 | np.random.seed(seed)
17 | torch.manual_seed(seed)
18 | if torch.cuda.is_available():
19 | torch.cuda.manual_seed_all(seed)
20 |
21 | def save_model(model_engine, ckpt_dir, client_state):
22 | model_engine.save_checkpoint(ckpt_dir, client_state=client_state, exclude_frozen_parameters=True)
23 |
24 | def remove_earlier_ckpt(path, start_name, current_step_num, max_save_num):
25 |
26 | filenames=os.listdir(path)
27 | ckpts = [dir_name for dir_name in filenames if dir_name.startswith(start_name) and int(dir_name.split('-')[1])<=current_step_num]
28 |
29 | current_ckpt_num = len(ckpts)
30 | for dir_name in filenames:
31 | if dir_name.startswith(start_name) and int(dir_name.split('-')[1]) <= current_step_num and current_ckpt_num > (max_save_num-1):
32 | shutil.rmtree(os.path.join(path, dir_name))
33 |
34 |
35 | def makedirs(path):
36 | p = pathlib.Path(path)
37 | p.parent.mkdir(parents=True, exist_ok=True)
38 | return path
39 |
40 | def load_pickle(path):
41 | with open(path, "rb") as f:
42 | return pickle.load(f)
43 |
44 | def write_pickle(obj, path:str):
45 | if not os.path.exists(path):
46 | makedirs(path)
47 | with open(path, "wb") as f:
48 | return pickle.dump(obj, f)
49 |
50 | def write_tensorboard(summary_writer, log_dict, completed_steps):
51 | for key, value in log_dict.items():
52 | summary_writer.add_scalar(f'{key}', value, completed_steps)
53 |
54 | def cos_sim(a, b):
55 |
56 | if not isinstance(a, torch.Tensor):
57 | a = torch.tensor(a)
58 |
59 | if not isinstance(b, torch.Tensor):
60 | b = torch.tensor(b)
61 |
62 | if len(a.shape) == 1:
63 | a = a.unsqueeze(0)
64 |
65 | if len(b.shape) == 1:
66 | b = b.unsqueeze(0)
67 |
68 | a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
69 | b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
70 | return torch.mm(a_norm, b_norm.transpose(0, 1))
71 |
72 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # D2LLM: Decomposed and Distilled Large Language Models for Semantic Search
2 |
3 | This is the PyTorch implementation of D2LLM, as described in the ACL'24 paper "D2LLM: Decomposed and Distilled Large Language Models for Semantic Search".
4 |
5 | 
6 |
Figure 1. The network architecture of D2LLM.
7 |
8 | ## Requirements
9 |
10 | * Ubuntu OS
11 | * python==3.10
12 | * torch==2.0.1
13 | * cuda==11.7
14 | * transformers==4.37.0
15 | * deepspeed==0.14.2
16 | * flash-attn==2.3.6
17 | * peft==0.7.0
18 |
19 | Dependencies can be installed by:
20 |
21 | pip install -r requirements.txt
22 |
23 |
24 | The overall directory structure is as follows:
25 |
26 | ${CODE_ROOT}
27 | ......
28 | |-- preprocess
29 | |   |-- save_hardneg_bm25.py
30 | |   |-- save_hardneg_bi.py
31 | |   |-- save_logits.py
32 | |-- dataset
33 | |   |-- dataset.py
34 | |-- model
35 | |   |-- pro_model.py
36 | |-- utils
37 | |   |-- common_utils.py
38 | |-- train.py
39 | |-- train.sh
41 |
42 |
43 | ## Data preparation
44 |
45 | The six datasets (SNLI-zh, NLI-zh, T2Ranking, DuReader, cMedQA2 and mMARCO) used in this paper can be downloaded from the following links:
46 |
47 | * [SNLI-zh](https://huggingface.co/datasets/shibing624/snli-zh)
48 | * [NLI-zh](https://huggingface.co/datasets/shibing624/nli_zh)
49 | * [T2Ranking](https://github.com/THUIR/T2Ranking)
50 | * [DuReader](https://github.com/baidu/DuReader)
51 | * [cMedQA2](https://github.com/zhangsheng93/cMedQA2)
52 | * [mMARCO](https://huggingface.co/datasets/unicamp-dl/mmarco)
53 |
54 | Before training, we mine hard negatives with BM25 and with a bi-encoder using the scripts save_hardneg_bm25.py and save_hardneg_bi.py. We then use save_logits.py to have the teacher LLM score both the in-batch negatives and the hard negatives.
55 |
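For orientation, the BM25 mining step in preprocess/save_hardneg_bm25.py boils down to the following minimal sketch; the toy strings here stand in for the real corpus files, query sets, and stopword filtering:

    import jieba
    from rank_bm25 import BM25Okapi

    # toy data standing in for the dataset files
    corpus = ["今天天气很好", "我喜欢打篮球", "天气预报说明天下雨"]
    queries = {"明天的天气怎么样": ["天气预报说明天下雨"]}  # query -> known positives

    # tokenize the corpus with jieba and build the BM25 index
    tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)

    hard_negatives = {}
    for query, positives in queries.items():
        top_docs = bm25.get_top_n(list(jieba.cut(query)), corpus, n=2)
        # drop known positives so only hard negatives remain
        hard_negatives[query] = [doc for doc in top_docs if doc not in positives]
    print(hard_negatives)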
56 | ## Train
57 |
58 | To perform training, adjust the parameters in train.sh and run:
59 |
60 | sh train.sh
61 |
62 | ## Evaluate
63 |
64 | Evaluation can be done through the MTEB toolkit. Note that the cosine similarity should be replaced by the IEM module.
65 |
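Roughly, scoring a pair with the IEM instead of cosine similarity looks like the sketch below. Here `model` stands for a trained Mymodel instance from model/pro_model.py, the example sentences are placeholders, and TASK_ID selects the IEM output head (0 for NLI/STS-style pairs, 1 for retrieval-style pairs, mirroring the task ids in dataset/dataset.py):

    import torch

    TASK_ID = 1  # retrieval-style head; use 0 for NLI/STS-style pairs
    query = "明天的天气怎么样"      # placeholder query
    passage = "天气预报说明天下雨"  # placeholder passage

    with torch.no_grad():
        # encode() pools the hidden states with the PMA module into one vector per sentence
        emb_q = model.encode([query], convert_to_tensor=True, show_progress_bar=False)
        emb_p = model.encode([passage], convert_to_tensor=True, show_progress_bar=False)

        # move the embeddings onto the IEM's device/dtype before scoring
        iem_param = next(model.iem.parameters())
        emb_q = emb_q.to(iem_param.device, iem_param.dtype)
        emb_p = emb_p.to(iem_param.device, iem_param.dtype)

        logits, _ = model.iem(emb_q, emb_p)  # shape (1, 2): one logit per task head
        score = logits[0, TASK_ID].item()    # use this score in place of cos_sim(emb_q, emb_p)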
66 | ## Citation
67 |
68 | @inproceedings{
69 | anonymous2024dllm,
70 | title={D2{LLM}: Decomposed and Distilled Large Language Models for Semantic Search},
71 | author={Anonymous},
72 | booktitle={The 62nd Annual Meeting of the Association for Computational Linguistics},
73 | year={2024}
74 | }
75 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | BASE_MODEL_DIR="PATH_OF_BASE_MODEL"
3 | TRAIN_DATA_LIST="TRAIN_DATASETS"
4 | POS_DIR="PATH_TO_POS_LOGITS"
5 | NEG_DIR="PATH_TO_NEG_LOGITS"
6 | DATA_DIR="DATASET_DIR"
7 | INBATCH_PKL_PATH_DIR="PATH_TO_INBATCH_LOGITS_PKL"
8 | FEATURE_PKL_PATH_DIR="PATH_TO_FEATURE_PKL"
9 | BATCH_SIZE=32
10 | NEG_K=8
11 | NUM_HEADS=32
12 | HIDDEN_DIM=512
13 | OUTPUT_DIM=1
14 | LN="True"
15 | NORM="False"
16 | PADDING_SIDE="right"
17 | NUM_EPOCHS=5
18 | MAX_SEQ_LENGTH=250
19 | LR=1e-4
20 | ALPHA=1
21 | BETA=1
22 | GAMMA=0.01
23 | ETA=0.001
24 | TEMPERATURE_IN_BATCH=1
25 | TEMPERATURE_HARDNEG=1
26 | TEMPERATURE_TEACHER_HARDNEG=1
27 | SCALE_PARAM=1
28 | LOG_INTERVAL=10
29 | EVAL_INTERVAL=300
30 | TB_DIR="PATH_TO_TENSORBOARD_PATH"
31 | PATIENCE=5
32 | NUM_CKPT=4
33 | TRAINING_LOG="PATH_TO_TRAINING_LOG"
34 | OUTPUT_DIR="PATH_TO_OUTPUT_MODEL"
35 |
36 | WORLD_SIZE=${WORLD_SIZE:-1}
37 | NODE_RANK=${RANK:-0}
38 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
39 | MASTER_PORT=${MASTER_PORT:-12346}
40 | GPUS_PER_NODE=${GPUS_PER_NODE:-1}  # number of GPUs per node used by torchrun
41 | python -m torch.distributed.run --nproc_per_node=$GPUS_PER_NODE --nnodes=$WORLD_SIZE --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
42 | train.py --base_model_dir $BASE_MODEL_DIR \
43 | --train_data_list $TRAIN_DATA_LIST \
44 | --pos_dir $POS_DIR \
45 | --neg_dir $NEG_DIR \
46 | --data_dir $DATA_DIR \
47 | --inbatch_pkl_path_dir $INBATCH_PKL_PATH_DIR \
48 | --feature_pkl_path_dir $FEATURE_PKL_PATH_DIR \
49 |          --batch_size $BATCH_SIZE \
50 | --neg_K $NEG_K \
51 | --num_heads $NUM_HEADS \
52 | --hidden_dim $HIDDEN_DIM \
53 | --output_dim $OUTPUT_DIM \
54 | --ln $LN \
55 | --norm $NORM \
56 | --num_epochs $NUM_EPOCHS \
57 | --padding_side $PADDING_SIDE \
58 | --max_seq_length $MAX_SEQ_LENGTH \
59 | --lr $LR \
60 | --alpha $ALPHA \
61 | --beta $BETA \
62 | --gamma $GAMMA \
63 | --eta $ETA \
64 | --temperature_in_batch $TEMPERATURE_IN_BATCH \
65 | --temperature_hardneg $TEMPERATURE_HARDNEG \
66 | --temperature_teacher_hardneg $TEMPERATURE_TEACHER_HARDNEG \
67 | --scale_param $SCALE_PARAM \
68 | --log_interval $LOG_INTERVAL \
69 | --eval_interval $EVAL_INTERVAL \
70 | --tb_dir $TB_DIR \
71 | --patience $PATIENCE \
72 | --num_ckpt $NUM_CKPT \
73 | --training_log $TRAINING_LOG \
74 |          --output_dir $OUTPUT_DIR
--------------------------------------------------------------------------------
/preprocess/save_logits.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import argparse
5 | from tqdm import tqdm, trange
6 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
7 | from utils.common_utils import load_pickle, write_pickle
8 |
9 |
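# NLI/STS teacher prompt (in Chinese). Rough English gloss: "#P and #H each describe an event or question and
# may be unrelated. Using only this description and your knowledge of the world, judge whether #H is an
# absolutely correct statement about #P; answer 是 (yes) or 不是 (no); if unsure, answer 不是 (no)."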
10 | def sts_template_v5(text1, text2):
11 | return f'#P和#H将分别描述一种事件或问题,它们可能并无关系。仅使用此描述和您对世界的了解,判断#H是不是一个关于#P中的事件绝对正确的句子,或者#H是不是绝对正确地描述了#P的事件或问题,请回答是或不是,若您不确定,请回答不是。\n#P:{text1}\n#H:{text2}\n回答:'
12 |
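# Retrieval teacher prompt (in Chinese). Rough English gloss: "#Q describes a question and #A describes a web
# passage; they may be unrelated. Based only on these descriptions and your knowledge of the world, judge
# whether #A can correctly answer the question in #Q; answer 能 (can) or 不能 (cannot)."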
13 | def context_template_v5(text1, text2):
14 | return f'#Q将描述一个问题,#A将描述一个网络段落,它们可能并没有关系。仅依据这些描述和您对世界的了解,判断#A能不能正确地回答#Q中提出的问题,请回答能或不能。\n#Q:{text1}\n#A:{text2}\n回答:'
15 |
16 |
17 | def generate_logits(model_dir, neg_pkl_file, task_type, bs, teacher_max_seq_length, num_shards, id_shard):
18 | bm_25_dict = load_pickle(neg_pkl_file)
19 | all_sample_list = []
20 | len_dict = {}
21 | all_logits = []
22 | res_dict = {}
23 | lenth_one = len(list(bm_25_dict.keys()))/num_shards
24 | for i, query in enumerate(bm_25_dict):
25 | if i >= lenth_one*id_shard and i < lenth_one*(id_shard+1):
26 | doc_list = bm_25_dict[query]
27 | len_dict[i] = len(doc_list)
28 | if task_type == 'context':
29 | qry_doc_list = [context_template_v5(query, d) for d in doc_list]
30 | elif task_type == 'sts':
31 | qry_doc_list = [sts_template_v5(query, d) for d in doc_list]
32 | all_sample_list.extend(qry_doc_list)
33 | teacher_tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, pad_token='<|endoftext|>', truncation_side='right', padding_side='left')
34 | teacher_tokenizer.pad_token_id = teacher_tokenizer.eod_id
35 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True).to('cuda')
36 | model.eval()
37 | if task_type == 'sts':
38 | yes_id = teacher_tokenizer.encode('是')[0]
39 | no_id = teacher_tokenizer.encode('不是')[0]
40 | elif task_type == 'context':
41 | yes_id = teacher_tokenizer.encode('能')[0]
42 | no_id = teacher_tokenizer.encode('不能')[0]
43 | else:
44 | raise ValueError(f'Error: No Task Type {task_type}')
45 | with torch.no_grad():
46 | for start_index in trange(0, len(all_sample_list), bs, disable=False):
47 | print(start_index)
48 | cross_sentence_batch = all_sample_list[start_index: start_index+bs]
49 | cross_sentence_inputs = teacher_tokenizer(text=cross_sentence_batch, padding='max_length', max_length=teacher_max_seq_length, truncation=True, return_tensors='pt').to('cuda')
50 | outputs_logits = model(**cross_sentence_inputs).logits
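            # keep only the logits of the yes/no answer tokens at the last position as the teacher's soft label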
51 | outputs_logits = outputs_logits[:, -1, [yes_id, no_id]].cpu().float().numpy().tolist()
52 | all_logits.extend(outputs_logits)
53 | assert len(all_logits) == len(all_sample_list)
54 | start = 0
55 | for i, query in enumerate(bm_25_dict):
56 | if i >= lenth_one*id_shard and i < lenth_one*(id_shard+1):
57 | end = start + len_dict[i]
58 | doc_list = bm_25_dict[query]
59 | logits_list = all_logits[start:end]
60 | assert len(doc_list) == len(logits_list)
61 | res_doc_logits = list(zip(doc_list, logits_list))
62 | res_dict[query] = res_doc_logits
63 | start = end
64 | return res_dict
65 |
66 |
67 |
68 | if __name__ == '__main__':
69 | parser = argparse.ArgumentParser()
70 | parser.add_argument('--model_dir', default='', type=str)
71 | parser.add_argument('--hardneg_dir', default='', type=str)
72 | parser.add_argument('--output_pkl', default='', type=str)
73 | parser.add_argument('--dataset', default='', type=str)
74 | parser.add_argument('--task_type', default='', type=str)
75 | parser.add_argument('--bs', default=140, type=int)
76 | parser.add_argument('--K', type=int)
77 | parser.add_argument('--teacher_max_seq_length', default=500, type=int)
78 | parser.add_argument('--num_shards', default=8, type=int)
79 | parser.add_argument('--id_shard', default=0, type=int)
80 | args = parser.parse_args()
81 |
82 | neg_pkl_file = args.hardneg_dir
83 | output_pkl_path = args.output_pkl
84 | res_dict = generate_logits(args.model_dir, neg_pkl_file, args.task_type, args.bs, args.teacher_max_seq_length, args.num_shards, args.id_shard)
85 | write_pickle(res_dict, output_pkl_path)
86 |
87 |
--------------------------------------------------------------------------------
/preprocess/save_hardneg_bm25.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import csv
3 | import time
4 | import os
5 | import jieba
6 | import pickle
7 | import argparse
8 | from rank_bm25 import BM25Okapi
9 | from collections import defaultdict
10 | from tqdm import tqdm
11 | from datasets import load_dataset
12 |
13 | def write_pickle(obj, file):
14 | with open(file, 'wb') as f:
15 | pickle.dump(obj, f)
16 |
17 | def load_pickle(file):
18 | with open(file, 'rb') as f:
19 | obj = pickle.load(f)
20 | return obj
21 |
22 |
23 | def load_snli_zh(path):
24 | queries = []
25 | corpus = []
26 |
27 | pos_sample_dict = defaultdict(list)
28 |
29 | with open(path, encoding='utf-8') as f:
30 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
31 | for id, row in enumerate(reader):
32 | text_a = row['sentence1']
33 | text_b = row['sentence2']
34 | label = row['gold_label']
35 |
36 | if isinstance(text_b, str):
37 | corpus.append(text_b)
38 |
39 | if label == 'entailment':
40 | if isinstance(text_a, str):
41 | queries.append(text_a)
42 |
43 | pos_sample_dict[text_a].append(text_b)
44 |
45 | return queries, list(set(corpus)), pos_sample_dict
46 |
47 |
48 | def load_sts_zh(path):
49 | queries = []
50 | corpus = []
51 | pos_sample_dict = defaultdict(list)
52 | dataset = load_dataset(path, split='train')
53 | for id, row in enumerate(dataset):
54 | text_a = row['sentence1']
55 | text_b = row['sentence2']
56 | label = row['label']
57 | if isinstance(text_b, str):
58 | corpus.append(text_b)
59 | if path.split('/')[-1] != 'STS-B':
60 | if label == 1:
61 | if isinstance(text_a, str):
62 | queries.append(text_a)
63 |
64 | pos_sample_dict[text_a].append(text_b)
65 | else:
66 | if label >= 4:
67 | if isinstance(text_a, str) :
68 | queries.append(text_a)
69 | pos_sample_dict[text_a].append(text_b)
70 | return queries, list(set(corpus)), pos_sample_dict
71 |
72 | def load_t2(path):
73 | queries = []
74 | corpus = []
75 |
76 | pos_sample_dict = defaultdict(list)
77 | with open(path, 'r', encoding='utf-8') as f:
78 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
79 | for id, row in enumerate(reader):
80 |             text_a = row['sentence1']
81 | text_b = row['sentence2']
82 |
83 | if isinstance(text_b, str):
84 | corpus.append(text_b[:320])
85 |
86 | if isinstance(text_a, str):
87 | queries.append(text_a)
88 |
89 | pos_sample_dict[text_a].append(text_b[:320])
90 |
91 |
92 |
93 | return queries, list(set(corpus)), pos_sample_dict
94 |
95 |
96 |
97 | def main():
98 |
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument('--data_name', default='', type=str)
101 | parser.add_argument('--K', default=10, type=int)
102 | parser.add_argument('--num', default=50, type=int)
103 |
104 | args = parser.parse_args()
105 |
106 |
107 | stopwords = []
108 | with open('STOPWORDS_PATH', 'r', encoding='utf8') as f:
109 | for line in f:
110 | line = line.strip('\n')
111 | stopwords.append(line)
112 | output_dir = "OUTPUTS_NEG_BM25_PATH"
113 | if args.data_name == 'snli-zh':
114 | queries, corpus, pos_sample_dict = load_snli_zh("NLI_DATA_PATH")
115 | output_pickle = os.path.join(output_dir, args.data_name+'.pkl')
116 | if args.data_name in ['ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STS-B']:
117 | queries, corpus, pos_sample_dict = load_sts_zh("STS_DATA_PATH")
118 | output_pickle = os.path.join(output_dir, args.data_name+'.pkl')
119 | if args.data_name == 't2':
120 | queries, corpus, pos_sample_dict = load_t2("T2_DATA_PATH")
121 | output_pickle = os.path.join(output_dir, args.data_name+'.pkl')
122 | tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
123 | tokenized_corpus = [list(set(tokenized_doc).difference(set(stopwords))) for tokenized_doc in tokenized_corpus]
124 | bm25 = BM25Okapi(tokenized_corpus)
125 |
126 |
127 | tokenized_queries = [list(jieba.cut(q)) for q in queries]
128 | tokenized_queries = [list(set(tokenized_query).difference(set(stopwords))) for tokenized_query in tokenized_queries]
129 | assert len(queries) == len(tokenized_queries)
130 |
131 | hard_neg_sample_dict = defaultdict(list)
132 | for i,tokenized_query in enumerate(tqdm(tokenized_queries)):
133 | doc_scores = bm25.get_scores(tokenized_query)
134 | res_docs = bm25.get_top_n(tokenized_query, corpus, n=args.K)
135 | for pos in pos_sample_dict[queries[i]]:
136 | while pos in res_docs:
137 | res_docs.remove(pos)
138 |
139 | hard_neg_sample_dict[queries[i]] = res_docs
140 |
141 |
142 | if not os.path.exists(output_dir):
143 | os.makedirs(output_dir)
144 |
145 | write_pickle(hard_neg_sample_dict, output_pickle)
146 |
--------------------------------------------------------------------------------
/model/pro_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import time
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from tqdm import tqdm, trange
9 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
10 |
11 | class MAB(nn.Module):
12 |     def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
13 |         super(MAB, self).__init__()
14 |         self.dim_V = dim_V
15 |         self.num_heads = num_heads
16 |         self.fc_q = nn.Linear(dim_Q, dim_V)
17 |         self.fc_k = nn.Linear(dim_K, dim_V)
18 |         self.fc_v = nn.Linear(dim_K, dim_V)
19 |         if ln:
20 |             self.ln0 = nn.LayerNorm(dim_V)
21 |             self.ln1 = nn.LayerNorm(dim_V)
22 |         self.fc_o = nn.Linear(dim_V, dim_V)
23 |         nn.init.xavier_uniform_(self.fc_q.weight)
24 |         nn.init.xavier_uniform_(self.fc_k.weight)
25 |         nn.init.xavier_uniform_(self.fc_v.weight)
26 |         nn.init.xavier_uniform_(self.fc_o.weight)
27 | 
28 |     def forward(self, Q, K, pad_mask=None):
29 |         Q_ = self.fc_q(Q)
30 |         K_, V_ = self.fc_k(K), self.fc_v(K)
31 |         dim_split = self.dim_V // self.num_heads
32 |         Q_ = torch.cat(Q_.split(dim_split, 2), 0)
33 |         K_ = torch.cat(K_.split(dim_split, 2), 0)
34 |         V_ = torch.cat(V_.split(dim_split, 2), 0)
35 |         pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
36 |         score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
37 |         score = score.masked_fill(pad_mask == 0, -1e12)
38 |         A = torch.softmax(score, 2)
39 |         A = A * pad_mask
40 |         O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
41 |         O = Q + O
42 |         O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
43 |         O = O + F.relu(self.fc_o(O))
44 |         O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
45 |         return O
46 | 
47 | 
48 | class PMA(nn.Module):
49 |     def __init__(self, dim, num_heads, num_seeds, ln=False):
50 |         super(PMA, self).__init__()
51 |         self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
52 |         nn.init.xavier_uniform_(self.S)
53 |         self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
54 | 
55 |     def forward(self, X, pad_mask):
56 |         if self.S.dtype != torch.bfloat16:
57 |             X = X.float()
58 |         return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)
59 | 
60 | 
61 | class IEM(nn.Module):
62 |
63 | def __init__(self, d_model, hidden, d_output, drop_prob=0.0):
64 | super(IEM, self).__init__()
65 | self.linear1 = nn.Linear(2*d_model, hidden)
66 | self.proj0 = nn.Linear(hidden, hidden)
67 | self.proj1 = nn.Linear(hidden, hidden)
68 | self.linear2 = nn.Linear(hidden, d_output)
69 | nn.init.xavier_uniform_(self.linear1.weight)
70 | nn.init.xavier_uniform_(self.proj0.weight)
71 | nn.init.xavier_uniform_(self.proj1.weight)
72 | nn.init.xavier_uniform_(self.linear2.weight)
73 | self.relu = nn.ReLU()
74 | self.dropout = nn.Dropout(p=drop_prob)
75 | self.sftmx = nn.Softmax(dim=-1)
76 |
77 | def forward(self, emb_a, emb_b):
78 | x = torch.cat((emb_a, emb_b), dim=-1)
79 | x = self.linear1(x)
80 | x = self.relu(x)
81 | x = self.dropout(x)
82 | x0 = self.proj0(x)
83 | x1 = self.proj1(x)
84 | x0 = self.relu(x0)
85 | x1 = self.relu(x1)
86 | rep = torch.stack((x0,x1),dim=0)
87 | logits0 = self.linear2(x0)
88 | logits1 = self.linear2(x1)
89 | logits = torch.cat((logits0, logits1), dim=-1)
90 | return logits, rep
91 |
92 |
93 |
94 | class Mymodel(nn.Module):
95 | def __init__(self,
96 | model_name_or_path = None,
97 | alias = None,
98 | max_seq_length = 256,
99 | args = None
100 | ):
101 | super(Mymodel, self).__init__()
102 | self.alias = alias
103 | if self.alias == None:
104 | self.alias = model_name_or_path
105 | self.args = args
106 | self.max_seq_length = max_seq_length
107 | self.model_name_or_path = model_name_or_path
108 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True, pad_token='<|endoftext|>', truncation_side='right', padding_side=self.args.padding_side)
109 | self.tokenizer.pad_token_id = self.tokenizer.eod_id
110 | self.plm_model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, trust_remote_code=True)
111 |         self.emb_dim = self.plm_model.transformer.wte.weight.size(1)
112 |         self.num_heads = args.num_heads
113 |         self.ln, self.norm = args.ln, args.norm
114 |         self.hidden_dim, self.output_dim = args.hidden_dim, args.output_dim
115 |         self.keep_max_layer = getattr(args, 'keep_max_layer', -1)  # hidden-state layer used for pooling; defaults to the last layer
116 |         self.mha_pma = PMA(self.emb_dim, self.num_heads, 1, ln=self.ln)
117 |         self.iem = IEM(self.emb_dim, self.hidden_dim, self.output_dim)
118 | def forward(self, inputs_all, task_ids, mode):
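        # in training, inputs_all stacks the tokenized [anchor, positive, hard_neg_1 ... hard_neg_K] sentences
        # (only [anchor, positive] in eval), so the reshape below yields one (batch, emb_dim) block per group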
119 | if mode == 'train':
120 | output_embeddings_all = self.get_sentence_embedding(**inputs_all).reshape(2+self.args.neg_K, -1, self.emb_dim)
121 | output_embeddings_hardneg = output_embeddings_all[2:]
122 | elif mode == 'eval':
123 | output_embeddings_all = self.get_sentence_embedding(**inputs_all).reshape(2, -1, self.emb_dim)
124 | else:
125 | raise ValueError('Error of mode value')
126 |
127 | output_embeddings_a = output_embeddings_all[0]
128 | output_embeddings_b = output_embeddings_all[1]
129 |
130 | bs = output_embeddings_a.size(0)
131 | a_expand_emb = output_embeddings_a.unsqueeze(1).expand(-1, bs, -1).reshape(-1, self.emb_dim)
132 | b_expand_emb = output_embeddings_b.unsqueeze(0).expand(bs, -1, -1).reshape(-1, self.emb_dim)
133 |
134 | task_expand = task_ids.unsqueeze(1).expand(-1, bs).reshape(-1,1).squeeze()
135 | output_in_batch, _ = self.iem(a_expand_emb, b_expand_emb) # (bs*bs, 2)
136 | output_in_batch_specific_task = output_in_batch[range(task_expand.size(0)), task_expand].squeeze().reshape(bs, -1)
137 |
138 | if mode == 'train':
139 | pos_neg_emb = torch.cat([output_embeddings_b.unsqueeze(0), output_embeddings_hardneg], dim=0)
140 | achr_emb = output_embeddings_a.unsqueeze(0).expand(pos_neg_emb.size(0),-1,-1)
141 | output_hardneg, output_pos_hardneg_rep = self.iem(achr_emb, pos_neg_emb)
142 | task_id_gather = task_ids.unsqueeze(0).unsqueeze(-1).expand(pos_neg_emb.size(0), -1, -1)
143 | output_hardneg_specific_task = torch.gather(output_hardneg, -1, task_id_gather).squeeze().t()
144 | output_pos_hardneg_rep_specific_task = output_pos_hardneg_rep[task_ids[0]]
145 | elif mode == 'eval':
146 | output_hardneg_specific_task = None
147 | output_pos_hardneg_rep_specific_task = None
148 |
149 | return output_in_batch_specific_task, output_hardneg_specific_task, output_pos_hardneg_rep_specific_task
150 |
151 | def pma_embedding(self, A, mask):
152 | res = self.mha_pma(A, mask).squeeze(1)
153 | return res
154 |
155 | def get_sentence_embedding(self, **inputs):
156 | outputs = self.plm_model(**inputs, output_hidden_states=True)
157 | embedding = outputs.hidden_states[self.keep_max_layer]
158 | attention_mask = inputs['attention_mask']
159 | res_embedding = self.pma_embedding(embedding, attention_mask)
160 |
161 | if self.norm:
162 | res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None)
163 | return res_embedding
164 |
165 | def encode(self, sentences, batch_size=64, convert_to_numpy=True,
166 | convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs):
167 |
168 | if max_seq_length is None:
169 | max_seq_length = self.max_seq_length
170 |
171 | input_is_string = False
172 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
173 | sentences = [sentences]
174 | input_is_string = True
175 |
176 | all_embeddings = []
177 | length_sorted_idx = np.argsort([-len(s) for s in sentences])
178 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
179 | with torch.no_grad():
180 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
181 | sentences_batch = sentences_sorted[start_index: start_index + batch_size]
182 | with torch.no_grad():
183 | inputs = self.tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, return_tensors='pt').to(self.plm_model.device)
184 | embeddings = self.get_sentence_embedding(**inputs)
185 | embeddings = embeddings.detach()
186 | if convert_to_numpy:
187 | if embeddings.dtype == torch.bfloat16:
188 | embeddings = embeddings.cpu().to(torch.float32)
189 | else:
190 | embeddings = embeddings.cpu()
191 | all_embeddings.extend(embeddings)
192 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
193 | if convert_to_tensor:
194 | all_embeddings = torch.stack(all_embeddings)
195 | elif convert_to_numpy:
196 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
197 |
198 | if input_is_string:
199 | all_embeddings = all_embeddings[0]
200 | return all_embeddings
201 |
202 |
--------------------------------------------------------------------------------
/preprocess/save_hardnrg_bi.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import csv
4 | import pathlib
5 | import json
6 | import argparse
7 | import warnings
8 | import deepspeed
9 | from enum import Enum
10 | from typing import Union, List
11 | from datasets import load_dataset
12 | from tqdm import tqdm, trange
13 | from collections import defaultdict
14 | from utils.common_utils import *
15 | warnings.filterwarnings('ignore')
16 | from mteb.mteb import MTEB
17 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
18 | from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
19 | maxInt = sys.maxsize
20 |
21 | while True:
22 | try:
23 | csv.field_size_limit(maxInt)
24 | break
25 | except OverflowError:
26 | maxInt = int(maxInt/10)
27 |
28 | def makedirs(path):
29 | p = pathlib.Path(path)
30 | p.parent.mkdir(parents=True, exist_ok=True)
31 | return path
32 |
33 | class EncoderType(Enum):
34 | FIRST_LAST_AVG = 0
35 | LAST_AVG = 1
36 | CLS = 2
37 | POOLER = 3
38 | MEAN = 4
39 |
40 | def __str__(self):
41 | return self.name
42 |
43 | @staticmethod
44 | def from_string(s):
45 | try:
46 | return EncoderType[s]
47 | except KeyError:
48 | raise ValueError()
49 |
50 | class BaseBertModel:
51 | def __init__(
52 | self,
53 | model_name_or_path = None,
54 | max_seq_length = 512,
55 | encoder_type = 'CLS',
56 | alias = None
57 | ):
58 | self.model_name_or_path = model_name_or_path
59 | encoder_type = EncoderType.from_string(encoder_type) if isinstance(encoder_type, str) else encoder_type
60 | if encoder_type not in list(EncoderType):
61 | raise ValueError(f'encoder_type must be in {list(EncoderType)}')
62 | self.encoder_type = encoder_type
63 | self.max_seq_length = max_seq_length
64 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side='right', padding_side='right')
65 | self.plm_model = AutoModel.from_pretrained(model_name_or_path)
66 | self.results = {}
67 | device = "cuda" if torch.cuda.is_available() else "cpu"
68 | self.device = torch.device(device)
69 | self.plm_model.to(self.device)
70 |
71 |
72 |
73 | def get_sentence_embeddings(self, input_ids, attention_mask, token_type_ids=None):
74 | model_output = self.plm_model(input_ids, attention_mask, token_type_ids, output_hidden_states=True)
75 |
76 | if self.encoder_type == EncoderType.FIRST_LAST_AVG:
77 | first = model_output.hidden_states[1]
78 | last = model_output.hidden_states[-1]
79 | seq_length = first.size(1)
80 |
81 | first_avg = torch.avg_pool1d(first.transpose(1, 2), kernel_size=seq_length).squeeze(-1)
82 | last_avg = torch.avg_pool1d(last.transpose(1, 2), kernel_size=seq_length).squeeze(-1)
83 | final_encoding = torch.avg_pool1d(
84 | torch.cat([first_avg.unsqueeze(1), last_avg.unsqueeze(1)], dim=1).transpose(1, 2),
85 | kernel_size=2).squeeze(-1)
86 | return final_encoding
87 |
88 | if self.encoder_type == EncoderType.LAST_AVG:
89 | sequence_output = model_output.last_hidden_state
90 | seq_length = sequence_output.size(1)
91 | final_encoding = torch.avg_pool1d(sequence_output.transpose(1, 2), kernel_size=seq_length).squeeze(-1)
92 | return final_encoding
93 |
94 | if self.encoder_type == EncoderType.CLS:
95 | sequence_output = model_output.last_hidden_state
96 | return sequence_output[:, 0]
97 |
98 | if self.encoder_type == EncoderType.POOLER:
99 | return model_output.pooler_output
100 |
101 | if self.encoder_type == EncoderType.MEAN:
102 | token_embeddings = model_output.last_hidden_state # Contains all token embeddings
103 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
104 | final_encoding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
105 | input_mask_expanded.sum(1), min=1e-9)
106 | return final_encoding # [batch, hid_size]
107 |
108 | def batch_to_device(self, batch, device):
109 | for key in batch:
110 | if isinstance(batch[key], torch.Tensor):
111 | batch[key] = batch[key].to(device)
112 | return batch
113 |
114 |
115 | def encode(
116 | self,
117 | sentences: Union[str, List[str]],
118 | batch_size: int = 32,
119 | show_progress_bar: bool = False,
120 | convert_to_numpy: bool = True,
121 | convert_to_tensor: bool = False,
122 | device: str = None,
123 | normalize_embeddings: bool = True,
124 | max_seq_length: int = None,
125 | ):
126 | self.plm_model.eval()
127 | if device is None:
128 | device = self.device
129 | self.plm_model.to(device)
130 |
131 | if max_seq_length is None:
132 | max_seq_length = self.max_seq_length
133 | if convert_to_tensor:
134 | convert_to_numpy = False
135 | input_is_string = False
136 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
137 | sentences = [sentences]
138 | input_is_string = True
139 |
140 | all_embeddings = []
141 | length_sorted_idx = np.argsort([-len(s) for s in sentences])
142 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
143 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
144 | sentences_batch = sentences_sorted[start_index: start_index + batch_size]
145 | with torch.no_grad():
146 | features = self.tokenizer(
147 | sentences_batch, max_length=max_seq_length,
148 | padding=True, truncation=True, return_tensors='pt'
149 | )
150 | features = self.batch_to_device(features, device)
151 | embeddings = self.get_sentence_embeddings(**features)
152 | embeddings = embeddings.detach()
153 | if normalize_embeddings:
154 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
155 |
156 | if convert_to_numpy:
157 | embeddings = embeddings.cpu()
158 | all_embeddings.extend(embeddings)
159 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
160 | if convert_to_tensor:
161 | all_embeddings = torch.stack(all_embeddings)
162 | elif convert_to_numpy:
163 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
164 |
165 | if input_is_string:
166 | all_embeddings = all_embeddings[0]
167 |
168 | return all_embeddings
169 |
170 | def write_t2_corpus(model, output_dir):
171 | makedirs(output_dir)
172 | corpus = set()
173 | corpus_path = "PATH_TO_SAVED_CORPUS"
174 | with open(corpus_path, 'r', encoding='utf-8') as f:
175 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
176 | for id, row in enumerate(reader):
177 | corpus.add(row['text'][:320])
178 |
179 | corpus = list(corpus)
180 |
181 | corpus_psg_id_dict = {psg:id for id, psg in enumerate(corpus)}
182 | corpus_id_psg_dict = {id:psg for id, psg in enumerate(corpus)}
183 |
184 | corpus_psg_id_dict_path = "PATH_TO_SAVED_PSG_ID_DICT"
185 | corpus_id_psg_dict_path = "PATH_TO_SAVED_ID_PSG_DICT"
186 | corpus_rep_path = "PATH_TO_SAVED_REP"
187 | corpus_rep = model.encode(corpus, batch_size=1500, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True, max_seq_length=250).to('cpu')
188 |
189 | write_pickle(corpus_psg_id_dict, corpus_psg_id_dict_path)
190 | write_pickle(corpus_id_psg_dict, corpus_id_psg_dict_path)
191 | write_pickle(corpus_rep, corpus_rep_path)
192 |
193 |
194 | def write_t2_qry(model, corpus_psg_id_dict_path, corpus_id_psg_dict_path, corpus_rep_path, output_dir, K):
195 | res = defaultdict(list)
196 | queries = []
197 | pos_sample_dict = defaultdict(list)
198 | corpus_psg_id_dict = load_pickle(corpus_psg_id_dict_path)
199 | corpus_id_psg_dict = load_pickle(corpus_id_psg_dict_path)
200 | corpus_rep = load_pickle(corpus_rep_path)
201 | query_path = f'DATA_PATH'
202 | data_all_path = f'ALL_DATA_PATH'
203 |
204 | with open(data_all_path, 'r', encoding='utf-8') as f:
205 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
206 | for id, row in enumerate(reader):
207 | text_a = row['sentence1']
208 | text_b = row['sentence2'][:320]
209 | pos_sample_dict[text_a].append(text_b)
210 |
211 | with open(query_path, 'r', encoding='utf-8') as f:
212 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
213 | for id, row in enumerate(reader):
214 | text_a = row['sentence1']
215 | queries.append(text_a)
216 |
217 | makedirs("QUERY_PATH")
218 | if not os.path.exists("QUERY_PKL_PATH"):
219 | queries_rep = model.encode(queries, batch_size=1500, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True, max_seq_length=100).to('cpu')
220 | write_pickle(queries_rep, "QUERY_PKL_PATH")
221 | queries_rep = load_pickle("QUERY_PKL_PATH")
222 |
223 |
224 | qry_chunk_size = 2000
225 | qry_num = queries_rep.size(0)
226 | corpus_num = corpus_rep.size(0)
227 | for start in trange(0, qry_num, qry_chunk_size, disable=False):
228 | end = min(start+qry_chunk_size, qry_num)
229 | qry_bch_rep = queries_rep[start:end, :]
230 | score_bch = cos_sim(qry_bch_rep, corpus_rep)
231 | _, ids = torch.topk(score_bch, min(K+1, score_bch.size(1)), dim=1, largest=True,sorted=True)
232 | ids = ids.tolist()
233 | for qry_id in range(start, end):
234 | id_from_zero = qry_id - start
235 | qry_text = queries[qry_id]
236 | pos_text_list = pos_sample_dict[qry_text]
237 | for sub_id in ids[id_from_zero][-100:]:
238 | hardneg_text = corpus_id_psg_dict[sub_id]
239 | if hardneg_text not in pos_text_list and hardneg_text not in res[qry_text]:
240 | res[qry_text].append(hardneg_text)
241 |
242 | res_path = "FINAL_RES_PATH"
243 | write_pickle(res, res_path)
244 |
245 |
246 |
247 | def main():
248 | parser = argparse.ArgumentParser()
249 | parser.add_argument('--dataset', default='', type=str)
250 | parser.add_argument('--data_path', default='', type=str)
251 | parser.add_argument('--output_dir', default='', type=str)
252 | parser.add_argument('--ratio', default=0.5, type=float)
253 | parser.add_argument('--K', default=100, type=int)
254 | parser.add_argument('--base_model_dir', default='', type=str)
255 | parser.add_argument('--max_seq_len', default=250, type=int, help='max sequence length')
256 | parser.add_argument('--seed', default=2023, type=int)
257 | args = parser.parse_args()
258 | set_seed(args.seed)
259 | args.output_corpus_path = os.path.join(args.data_path, 'corpus')
260 | makedirs(args.output_corpus_path)
261 | model = BaseBertModel(model_name_or_path=args.base_model_dir,
262 | alias=None,
263 | encoder_type = 'CLS',
264 | max_seq_length=args.max_seq_len)
265 |
266 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
267 | device = torch.device(device)
268 | model.plm_model.to(device)
269 | model.plm_model.eval()
270 |
271 | if not os.path.exists(os.path.join(f'{args.output_corpus_path}', 'corpus_rep.pkl')):
272 | if args.dataset == 'T2Ranking':
273 | write_t2_corpus(model, f'{args.output_corpus_path}')
274 | corpus_psg_id_dict_path = os.path.join(f'{args.output_corpus_path}', 'corpus_psg_id_dict.pkl')
275 | corpus_id_psg_dict_path = os.path.join(f'{args.output_corpus_path}', 'corpus_id_psg_dict.pkl')
276 | corpus_rep_path = os.path.join(f'{args.output_corpus_path}', 'corpus_rep.pkl')
277 | if args.dataset == 'T2Ranking':
278 |         write_t2_qry(model, corpus_psg_id_dict_path, corpus_id_psg_dict_path, corpus_rep_path, args.output_dir, args.K)
279 |
280 |
281 | if __name__ == '__main__':
282 | main()
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright [2023] [Ant Group]
2 | Licensed under the Apache License, Version 2.0 (the "License");
3 | you may not use this file except in compliance with the License.
4 | You may obtain a copy of the License at
5 | http://www.apache.org/licenses/LICENSE-2.0
6 |
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 |
13 |
14 | Apache License
15 | Version 2.0, January 2004
16 | http://www.apache.org/licenses/
17 |
18 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
19 |
20 | 1. Definitions.
21 |
22 | "License" shall mean the terms and conditions for use, reproduction,
23 | and distribution as defined by Sections 1 through 9 of this document.
24 |
25 | "Licensor" shall mean the copyright owner or entity authorized by
26 | the copyright owner that is granting the License.
27 |
28 | "Legal Entity" shall mean the union of the acting entity and all
29 | other entities that control, are controlled by, or are under common
30 | control with that entity. For the purposes of this definition,
31 | "control" means (i) the power, direct or indirect, to cause the
32 | direction or management of such entity, whether by contract or
33 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
34 | outstanding shares, or (iii) beneficial ownership of such entity.
35 |
36 | "You" (or "Your") shall mean an individual or Legal Entity
37 | exercising permissions granted by this License.
38 |
39 | "Source" form shall mean the preferred form for making modifications,
40 | including but not limited to software source code, documentation
41 | source, and configuration files.
42 |
43 | "Object" form shall mean any form resulting from mechanical
44 | transformation or translation of a Source form, including but
45 | not limited to compiled object code, generated documentation,
46 | and conversions to other media types.
47 |
48 | "Work" shall mean the work of authorship, whether in Source or
49 | Object form, made available under the License, as indicated by a
50 | copyright notice that is included in or attached to the work
51 | (an example is provided in the Appendix below).
52 |
53 | "Derivative Works" shall mean any work, whether in Source or Object
54 | form, that is based on (or derived from) the Work and for which the
55 | editorial revisions, annotations, elaborations, or other modifications
56 | represent, as a whole, an original work of authorship. For the purposes
57 | of this License, Derivative Works shall not include works that remain
58 | separable from, or merely link (or bind by name) to the interfaces of,
59 | the Work and Derivative Works thereof.
60 |
61 | "Contribution" shall mean any work of authorship, including
62 | the original version of the Work and any modifications or additions
63 | to that Work or Derivative Works thereof, that is intentionally
64 | submitted to Licensor for inclusion in the Work by the copyright owner
65 | or by an individual or Legal Entity authorized to submit on behalf of
66 | the copyright owner. For the purposes of this definition, "submitted"
67 | means any form of electronic, verbal, or written communication sent
68 | to the Licensor or its representatives, including but not limited to
69 | communication on electronic mailing lists, source code control systems,
70 | and issue tracking systems that are managed by, or on behalf of, the
71 | Licensor for the purpose of discussing and improving the Work, but
72 | excluding communication that is conspicuously marked or otherwise
73 | designated in writing by the copyright owner as "Not a Contribution."
74 |
75 | "Contributor" shall mean Licensor and any individual or Legal Entity
76 | on behalf of whom a Contribution has been received by Licensor and
77 | subsequently incorporated within the Work.
78 |
79 | 2. Grant of Copyright License. Subject to the terms and conditions of
80 | this License, each Contributor hereby grants to You a perpetual,
81 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
82 | copyright license to reproduce, prepare Derivative Works of,
83 | publicly display, publicly perform, sublicense, and distribute the
84 | Work and such Derivative Works in Source or Object form.
85 |
86 | 3. Grant of Patent License. Subject to the terms and conditions of
87 | this License, each Contributor hereby grants to You a perpetual,
88 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
89 | (except as stated in this section) patent license to make, have made,
90 | use, offer to sell, sell, import, and otherwise transfer the Work,
91 | where such license applies only to those patent claims licensable
92 | by such Contributor that are necessarily infringed by their
93 | Contribution(s) alone or by combination of their Contribution(s)
94 | with the Work to which such Contribution(s) was submitted. If You
95 | institute patent litigation against any entity (including a
96 | cross-claim or counterclaim in a lawsuit) alleging that the Work
97 | or a Contribution incorporated within the Work constitutes direct
98 | or contributory patent infringement, then any patent licenses
99 | granted to You under this License for that Work shall terminate
100 | as of the date such litigation is filed.
101 |
102 | 4. Redistribution. You may reproduce and distribute copies of the
103 | Work or Derivative Works thereof in any medium, with or without
104 | modifications, and in Source or Object form, provided that You
105 | meet the following conditions:
106 |
107 | (a) You must give any other recipients of the Work or
108 | Derivative Works a copy of this License; and
109 |
110 | (b) You must cause any modified files to carry prominent notices
111 | stating that You changed the files; and
112 |
113 | (c) You must retain, in the Source form of any Derivative Works
114 | that You distribute, all copyright, patent, trademark, and
115 | attribution notices from the Source form of the Work,
116 | excluding those notices that do not pertain to any part of
117 | the Derivative Works; and
118 |
119 | (d) If the Work includes a "NOTICE" text file as part of its
120 | distribution, then any Derivative Works that You distribute must
121 | include a readable copy of the attribution notices contained
122 | within such NOTICE file, excluding those notices that do not
123 | pertain to any part of the Derivative Works, in at least one
124 | of the following places: within a NOTICE text file distributed
125 | as part of the Derivative Works; within the Source form or
126 | documentation, if provided along with the Derivative Works; or,
127 | within a display generated by the Derivative Works, if and
128 | wherever such third-party notices normally appear. The contents
129 | of the NOTICE file are for informational purposes only and
130 | do not modify the License. You may add Your own attribution
131 | notices within Derivative Works that You distribute, alongside
132 | or as an addendum to the NOTICE text from the Work, provided
133 | that such additional attribution notices cannot be construed
134 | as modifying the License.
135 |
136 | You may add Your own copyright statement to Your modifications and
137 | may provide additional or different license terms and conditions
138 | for use, reproduction, or distribution of Your modifications, or
139 | for any such Derivative Works as a whole, provided Your use,
140 | reproduction, and distribution of the Work otherwise complies with
141 | the conditions stated in this License.
142 |
143 | 5. Submission of Contributions. Unless You explicitly state otherwise,
144 | any Contribution intentionally submitted for inclusion in the Work
145 | by You to the Licensor shall be under the terms and conditions of
146 | this License, without any additional terms or conditions.
147 | Notwithstanding the above, nothing herein shall supersede or modify
148 | the terms of any separate license agreement you may have executed
149 | with Licensor regarding such Contributions.
150 |
151 | 6. Trademarks. This License does not grant permission to use the trade
152 | names, trademarks, service marks, or product names of the Licensor,
153 | except as required for reasonable and customary use in describing the
154 | origin of the Work and reproducing the content of the NOTICE file.
155 |
156 | 7. Disclaimer of Warranty. Unless required by applicable law or
157 | agreed to in writing, Licensor provides the Work (and each
158 | Contributor provides its Contributions) on an "AS IS" BASIS,
159 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
160 | implied, including, without limitation, any warranties or conditions
161 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
162 | PARTICULAR PURPOSE. You are solely responsible for determining the
163 | appropriateness of using or redistributing the Work and assume any
164 | risks associated with Your exercise of permissions under this License.
165 |
166 | 8. Limitation of Liability. In no event and under no legal theory,
167 | whether in tort (including negligence), contract, or otherwise,
168 | unless required by applicable law (such as deliberate and grossly
169 | negligent acts) or agreed to in writing, shall any Contributor be
170 | liable to You for damages, including any direct, indirect, special,
171 | incidental, or consequential damages of any character arising as a
172 | result of this License or out of the use or inability to use the
173 | Work (including but not limited to damages for loss of goodwill,
174 | work stoppage, computer failure or malfunction, or any and all
175 | other commercial damages or losses), even if such Contributor
176 | has been advised of the possibility of such damages.
177 |
178 | 9. Accepting Warranty or Additional Liability. While redistributing
179 | the Work or Derivative Works thereof, You may choose to offer,
180 | and charge a fee for, acceptance of support, warranty, indemnity,
181 | or other liability obligations and/or rights consistent with this
182 | License. However, in accepting such obligations, You may act only
183 | on Your own behalf and on Your sole responsibility, not on behalf
184 | of any other Contributor, and only if You agree to indemnify,
185 | defend, and hold each Contributor harmless for any liability
186 | incurred by, or claims asserted against, such Contributor by reason
187 | of your accepting any such warranty or additional liability.
188 |
189 | END OF TERMS AND CONDITIONS
190 |
191 | APPENDIX: How to apply the Apache License to your work.
192 |
193 | To apply the Apache License to your work, attach the following
194 | boilerplate notice, with the fields enclosed by brackets "[]"
195 | replaced with your own identifying information. (Don't include
196 | the brackets!) The text should be enclosed in the appropriate
197 | comment syntax for the file format. We also recommend that a
198 | file or class name and description of purpose be included on the
199 | same "printed page" as the copyright notice for easier
200 | identification within third-party archives.
201 |
202 | Copyright [yyyy] [name of copyright owner]
203 |
204 | Licensed under the Apache License, Version 2.0 (the "License");
205 | you may not use this file except in compliance with the License.
206 | You may obtain a copy of the License at
207 |
208 | http://www.apache.org/licenses/LICENSE-2.0
209 |
210 | Unless required by applicable law or agreed to in writing, software
211 | distributed under the License is distributed on an "AS IS" BASIS,
212 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
213 | See the License for the specific language governing permissions and
214 | limitations under the License.
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from datasets import load_dataset
8 | from transformers import PreTrainedTokenizer
9 | import csv
10 | from loguru import logger
11 | import random
12 | from utils.common_utils import load_pickle
13 | DATASET_ID_DICT = {'snli-zh':1,'sts':2,'t2-05':3,'du-10':4,'mmarco':5,'cmedqa':6}
14 | def load_text_dataset(name, pos_dir, neg_dir, file_path, neg_K, res_data, split):
15 | data = []
16 | if split == 'train':
17 | hard_neg_house = load_pickle(neg_dir)
18 |         pos_logits = load_pickle(pos_dir)
19 | with open(file_path, encoding='utf-8') as f:
20 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
21 | for id, row in enumerate(reader):
22 | text_a = row['sentence1']
23 | text_b = row['sentence2']
24 | score = row['gold_label']
25 | if score == 'entailment':
26 | if split == 'train':
27 | if len(hard_neg_house[text_a]) < neg_K:
28 | num = math.ceil(neg_K / len(hard_neg_house[text_a]))
29 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K)
30 | else:
31 | negs_logits = random.sample(hard_neg_house[text_a], neg_K)
32 | hardnegs, hardneg_logits = zip(*negs_logits)
33 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits)
34 | elif split == 'validation':
35 | hardnegs = []
36 | hardneg_logits = []
37 | pos_logits = []
38 | hardnegs = [sample[:100] for sample in hardnegs]
39 | data.append((text_a[:100], text_b[:100], pos_logits, hardnegs, hardneg_logits, 0))
40 | if split == 'train':
41 | split_data = data[:-10000]
42 | sample_num = len(split_data)
43 | elif split == 'validation':
44 | split_data = data[-10000:]
45 | sample_num = len(split_data)
46 | res_data.extend(split_data)
47 |
48 | return res_data, sample_num
49 |
50 |
51 | def load_sts_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data):
52 | data = []
53 |     pos_logits = load_pickle(pos_dir)
54 | hard_neg_house = load_pickle(neg_dir)
55 | with open(file_path, encoding='utf-8') as f:
56 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
57 | for id, row in enumerate(reader):
58 | text_a = row['sentence1']
59 | text_b = row['sentence2']
60 | if len(hard_neg_house[text_a]) < neg_K:
61 | num = math.ceil(neg_K / len(hard_neg_house[text_a]))
62 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K)
63 | else:
64 | negs_logits = random.sample(hard_neg_house[text_a], neg_K)
65 | hardnegs, hardneg_logits = zip(*negs_logits)
66 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits)
67 | hardnegs = [sample[:100] for sample in hardnegs]
68 | data.append((text_a[:100], text_b[:100], pos_logits, hardnegs, hardneg_logits, 0))
69 |
70 | sample_num = len(data)
71 | res_data.extend(data)
72 |
73 | return res_data, sample_num
74 |
75 | def load_sts_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data):
76 | data = []
77 | with open(file_path, encoding='utf-8') as f:
78 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
79 | for id, row in enumerate(reader):
80 | text_a = row['sentence1']
81 | text_b = row['sentence2']
82 | data.append((text_a[:100], text_b[:100], [], [], [], 0))
83 |
84 | sample_num = len(data)
85 | res_data.extend(data)
86 |
87 | return res_data, sample_num
88 |
89 | def load_t2_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data):
90 | data = []
91 |     pos_logits = load_pickle(pos_dir)
92 | hard_neg_house = load_pickle(neg_dir)
93 | with open(file_path, encoding='utf-8') as f:
94 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
95 | for id, row in enumerate(reader):
96 | text_a = row['sentence1']
97 | text_b = row['sentence2']
98 | if len(hard_neg_house[text_a]) < neg_K:
99 | num = math.ceil(neg_K / len(hard_neg_house[text_a]))
100 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K)
101 | else:
102 | negs_logits = random.sample(hard_neg_house[text_a], neg_K)
103 | hardnegs, hardneg_logits = zip(*negs_logits)
104 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits)
105 | hardnegs = [sample[:320] for sample in hardnegs]
106 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1))
107 |
108 | sample_num = len(data)
109 | res_data.extend(data)
110 |
111 | return res_data, sample_num
112 |
113 | def load_t2_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data):
114 | data = []
115 | with open(file_path, encoding='utf-8') as f:
116 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
117 | for id, row in enumerate(reader):
118 | text_a = row['sentence1']
119 | text_b = row['sentence2']
120 | data.append((text_a[:50], text_b[:320], [], [], [], 1))
121 |
122 | sample_num = len(data)
123 | res_data.extend(data)
124 |
125 | return res_data, sample_num
126 |
127 |
128 | def load_du_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data):
129 | data = []
130 | pos_logits = load_pickle(pos_dir)
131 | hard_neg_house = load_pickle(neg_dir)
132 | with open(file_path, encoding='utf-8') as f:
133 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
134 | for id, row in enumerate(reader):
135 | text_a = row['sentence1']
136 | text_b = row['sentence2']
137 | if len(hard_neg_house[text_a]) < neg_K:
138 | num = math.ceil(neg_K / len(hard_neg_house[text_a]))
139 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K)
140 | else:
141 | negs_logits = random.sample(hard_neg_house[text_a], neg_K)
142 | hardnegs, hardneg_logits = zip(*negs_logits)
143 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits)
144 | hardnegs = [sample[:320] for sample in hardnegs]
145 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1))
146 |
147 | sample_num = len(data)
148 | res_data.extend(data)
149 |
150 | return res_data, sample_num
151 |
152 | def load_du_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data):
153 | data = []
154 | with open(file_path, encoding='utf-8') as f:
155 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
156 | for id, row in enumerate(reader):
157 | text_a = row['sentence1']
158 | text_b = row['sentence2']
159 | data.append((text_a[:50], text_b[:320], [], [], [], 1))
160 |
161 | sample_num = len(data)
162 | res_data.extend(data)
163 |
164 | return res_data, sample_num
165 |
166 | def load_mmarco_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data):
167 | data = []
168 |     pos_logits = load_pickle(pos_dir)
169 | hard_neg_house = load_pickle(neg_dir)
170 | with open(file_path, encoding='utf-8') as f:
171 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
172 | for id, row in enumerate(reader):
173 | text_a = row['sentence1']
174 | text_b = row['sentence2']
175 | if len(hard_neg_house[text_a]) < neg_K:
176 | num = math.ceil(neg_K / len(hard_neg_house[text_a]))
177 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K)
178 | else:
179 | negs_logits = random.sample(hard_neg_house[text_a], neg_K)
180 | hardnegs, hardneg_logits = zip(*negs_logits)
181 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits)
182 | hardnegs = [sample[:320] for sample in hardnegs]
183 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1))
184 |
185 | sample_num = len(data)
186 | res_data.extend(data)
187 |
188 | return res_data, sample_num
189 |
190 | def load_mmarco_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data):
191 | data = []
192 | with open(file_path, encoding='utf-8') as f:
193 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
194 | for id, row in enumerate(reader):
195 | text_a = row['sentence1']
196 | text_b = row['sentence2']
197 | data.append((text_a[:50], text_b[:320], [], [], [], 1))
198 |
199 | sample_num = len(data)
200 | res_data.extend(data)
201 |
202 | return res_data, sample_num
203 |
204 | def load_cmedqa_dataset_train(name, pos_dir, neg_dir, file_path, neg_K, res_data):
205 | data = []
206 | pos_logits = load_pickle(pos_dir)
207 | hard_neg_house = load_pickle(neg_dir)
208 | with open(file_path, encoding='utf-8') as f:
209 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
210 | for id, row in enumerate(reader):
211 | text_a = row['sentence1']
212 | text_b = row['sentence2']
213 | if len(hard_neg_house[text_a]) < neg_K:
214 | num = math.ceil(neg_K / len(hard_neg_house[text_a]))
215 | negs_logits = random.sample(hard_neg_house[text_a] * num, neg_K)
216 | else:
217 | negs_logits = random.sample(hard_neg_house[text_a], neg_K)
218 | hardnegs, hardneg_logits = zip(*negs_logits)
219 | hardnegs, hardneg_logits = list(hardnegs), list(hardneg_logits)
220 | hardnegs = [sample[:320] for sample in hardnegs]
221 | data.append((text_a[:50], text_b[:320], pos_logits, hardnegs, hardneg_logits, 1))
222 |
223 | sample_num = len(data)
224 | res_data.extend(data)
225 |
226 | return res_data, sample_num
227 |
228 | def load_cmedqa_dataset_val(name, pos_dir, neg_dir, file_path, neg_K, res_data):
229 | data = []
230 | with open(file_path, encoding='utf-8') as f:
231 | reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
232 | for id, row in enumerate(reader):
233 | text_a = row['sentence1']
234 | text_b = row['sentence2']
235 | data.append((text_a[:50], text_b[:320], [], [], [], 1))
236 |
237 | sample_num = len(data)
238 | res_data.extend(data)
239 |
240 | return res_data, sample_num
241 |
242 |
243 | def collate_fn(data):
244 | res_s_a = []
245 | res_s_b = []
246 | res_pos_logits = []
247 | res_neg_K = []
248 | res_neg_logits = []
249 | res_task_id = []
250 |
251 | for d in data[0]:
252 | res_s_a.append(d[0])
253 | res_s_b.append(d[1])
254 | res_pos_logits.append(d[2])
255 | res_neg_K.append(d[3])
256 | res_neg_logits.extend(d[4])
257 | res_task_id.append(int(d[5]))
258 |
259 | res_neg_K = [list(group) for group in zip(*res_neg_K)]
260 | res_neg_K = [e for l in res_neg_K for e in l]
261 |
262 |
263 | return res_s_a, res_s_b, torch.FloatTensor(res_pos_logits), res_neg_K, torch.FloatTensor(res_neg_logits), torch.LongTensor(res_task_id)
264 |
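  | # collate_fn receives a DataLoader "batch" of size 1 whose single element is an already-built
  | # batch from TrainDataset / ValDataset below, so it only flattens the fields. The zip(*res_neg_K)
  | # transpose regroups hard negatives rank-major (all first negatives, then all second ones, ...),
  | # which presumably matches the layout the model expects for its hard-negative scores.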
265 |
266 |
267 | class TrainDataset(Dataset):
268 |
269 | def __init__(self, tokenizer, pos_dir, neg_dir, datadir, names=None, batch_size=32, neg_K=8, process_index=0, num_processes=1, seed=2023):
270 | self.dataset_id_dict = DATASET_ID_DICT
271 | self.tokenizer = tokenizer
272 | self.data = []
273 | self.batch_size = batch_size
274 | self.sample_stas = dict()
275 | self.dataset_indices_range = dict()
276 | self.process_index = process_index
277 | self.num_processes = num_processes
278 | self.neg_K = neg_K
279 | self.deterministic_generator = np.random.default_rng(seed)
280 | names.sort(reverse=True)
281 | for name in names:
282 | if name in ['snli-zh']:
283 | if name == 'snli-zh':
284 | start_id = len(self.data)
285 | self.data, sample_num = load_text_dataset(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data, 'train')
286 | end_id = len(self.data)
287 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
288 | self.sample_stas[name] = sample_num
289 | elif name in ['sts']:
290 | if name == 'sts':
291 | start_id = len(self.data)
292 | self.data, sample_num = load_sts_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), datadir, self.neg_K, self.data)
293 | end_id = len(self.data)
294 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
295 | self.sample_stas[name] = sample_num
296 | elif name in ['t2','du', 'mmarco', 'cmedqa']:
297 | if name == 't2':
298 | start_id = len(self.data)
299 | self.data, sample_num = load_t2_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
300 | end_id = len(self.data)
301 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
302 | self.sample_stas[name] = sample_num
303 | if name == 'du':
304 | start_id = len(self.data)
305 | self.data, sample_num = load_du_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
306 | end_id = len(self.data)
307 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
308 | self.sample_stas[name] = sample_num
309 | if name == 'mmarco':
310 | start_id = len(self.data)
311 | self.data, sample_num = load_mmarco_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
312 | end_id = len(self.data)
313 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
314 | self.sample_stas[name] = sample_num
315 | if name == 'cmedqa':
316 | start_id = len(self.data)
317 | self.data, sample_num = load_cmedqa_dataset_train(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
318 | end_id = len(self.data)
319 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
320 | self.sample_stas[name] = sample_num
321 | else:
322 | logger.debug('Unknown dataset: {}'.format(name))
323 |
324 | self.create_epoch()
325 |
326 | def __len__(self):
327 | return self.steps_per_epoch * self.num_processes
328 |
329 | def create_epoch(self):
330 | epoch = []
331 | self.steps_per_epoch = 0
332 | for k, v in self.dataset_indices_range.items():
333 | dataset_range = np.arange(*v)
334 | num_batches, remainder = divmod(len(dataset_range), self.batch_size * self.num_processes)
335 | if remainder != 0:
336 | dataset_range = dataset_range[:num_batches * self.batch_size * self.num_processes]
337 | self.deterministic_generator.shuffle(dataset_range)
338 | batches = dataset_range.reshape(num_batches * self.num_processes, self.batch_size).tolist()
339 | epoch.extend(batches)
340 | self.steps_per_epoch += num_batches
341 | self.deterministic_generator.shuffle(epoch)
342 | self.epoch = epoch
343 | self.step = 0
344 |
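  | # create_epoch builds homogeneous batches: indices are grouped per source dataset, any remainder
  | # that does not fill batch_size * num_processes is dropped, and both the indices and the final
  | # batch order are shuffled with the seeded generator so every process sees the same schedule.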
345 |
346 | def __getitem__(self, index: int):
347 | if self.step > (self.steps_per_epoch - 1):
348 | self.step = 0
349 | batch_indices = self.epoch[self.step*self.num_processes+self.process_index]
350 | batch_data = np.array(self.data)[batch_indices].tolist()
351 | self.step += 1
352 |
353 | return batch_data
354 |
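  | # __getitem__ ignores the sampler index and serves the slice of the current step that belongs to
  | # this process, so the surrounding DataLoader must use batch_size=1 (see train.py); the effective
  | # batch size is controlled here.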
355 |
356 |
357 | class ValDataset(Dataset):
358 |
359 | def __init__(self, tokenizer, pos_dir, neg_dir, datadir, names=None, batch_size=32, neg_K=8, process_index=0, num_processes=1, seed=2023):
360 | self.dataset_id_dict = DATASET_ID_DICT
361 | self.tokenizer = tokenizer
362 | self.data = []
363 | self.batch_size = batch_size
364 | self.neg_K = neg_K
365 | self.sample_stas = dict()
366 | self.dataset_indices_range = dict()
367 | self.process_index = process_index
368 | self.num_processes = num_processes
369 | self.deterministic_generator = np.random.default_rng(seed)
370 | names.sort(reverse=True)
371 | for name in names:
372 | if name in ['snli-zh']:
373 | if name == 'snli-zh':
374 | start_id = len(self.data)
375 | self.data, sample_num = load_text_dataset(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data, 'validation')
376 | end_id = len(self.data)
377 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
378 | self.sample_stas[name] = sample_num
379 | elif name in ['sts']:
380 | if name == 'sts':
381 | start_id = len(self.data)
382 | self.data, sample_num = load_sts_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
383 | end_id = len(self.data)
384 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
385 | self.sample_stas[name] = sample_num
386 | elif name in ['t2', 'du', 'mmarco', 'cmedqa']:
387 | if name == 't2':
388 | start_id = len(self.data)
389 | self.data, sample_num = load_t2_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
390 | end_id = len(self.data)
391 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
392 | self.sample_stas[name] = sample_num
393 | if name == 'du':
394 | start_id = len(self.data)
395 | self.data, sample_num = load_du_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
396 | end_id = len(self.data)
397 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
398 | self.sample_stas[name] = sample_num
399 | if name == 'mmarco':
400 | start_id = len(self.data)
401 | self.data, sample_num = load_mmarco_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
402 | end_id = len(self.data)
403 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
404 | self.sample_stas[name] = sample_num
405 | if name == 'cmedqa':
406 | start_id = len(self.data)
407 | self.data, sample_num = load_cmedqa_dataset_val(name, os.path.join(pos_dir, 'PATH_TO_DATA'), os.path.join(neg_dir, 'PATH_TO_DATA'), os.path.join(datadir, 'PATH_TO_DATA'), self.neg_K, self.data)
408 | end_id = len(self.data)
409 | self.dataset_indices_range[self.dataset_id_dict[name]] = (start_id, end_id)
410 | self.sample_stas[name] = sample_num
411 | else:
412 | logger.debug('Unknown dataset: {}'.format(name))
413 | self.create_epoch()
414 |
415 |
416 | def __len__(self):
417 | return self.steps_per_epoch * self.num_processes
418 |
419 | def create_epoch(self):
420 | epoch = []
421 | self.steps_per_epoch = 0
422 | for k, v in self.dataset_indices_range.items():
423 | dataset_range = np.arange(*v)
424 | num_batches, remainder = divmod(len(dataset_range), self.batch_size * self.num_processes)
425 | if remainder != 0:
426 | dataset_range = dataset_range[:num_batches * self.batch_size * self.num_processes]
427 | self.deterministic_generator.shuffle(dataset_range)
428 | batches = dataset_range.reshape(num_batches * self.num_processes, self.batch_size).tolist()
429 | epoch.extend(batches)
430 | self.steps_per_epoch += num_batches
431 | self.deterministic_generator.shuffle(epoch)
432 | self.epoch = epoch
433 | self.step = 0
434 |
435 |
436 | def __getitem__(self, index: int):
437 |
438 | if self.step > self.steps_per_epoch - 1:
439 | self.step = 0
440 | batch_indices = self.epoch[self.step*self.num_processes+self.process_index]
441 | batch_data = np.array(self.data)[batch_indices].tolist()
442 | self.step += 1
443 | return batch_data
444 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import warnings
4 | import json
5 | import logging
6 | import argparse
7 | import random
8 | import time
9 | import tracemalloc
10 | from collections import defaultdict
11 | from copy import deepcopy
12 | import deepspeed
13 | import transformers
14 | import torch
15 | import torch.nn as nn
16 | import torch.distributed as dist
17 | from transformers import AutoTokenizer
18 | from peft import LoraConfig, get_peft_model, TaskType
19 | from torch.utils.data import DataLoader, Dataset, RandomSampler
20 | from torch.utils.data.distributed import DistributedSampler
21 | from tqdm import tqdm, trange
22 | from torch.utils.tensorboard import SummaryWriter
23 | from dataset.dataset import *
24 | from model.pro_model import *
25 | from utils.common_utils import *
26 | logging.getLogger().setLevel(logging.INFO)
27 | warnings.filterwarnings('ignore')
28 |
29 |
30 | def cal_loss_in_batch(args, student_logits, temperature, criterion):
31 |
32 | bs = student_logits.size(0)
33 | logits = student_logits/temperature
34 | labels = torch.arange(bs, device=logits.device)
35 | loss_bs = criterion(logits, labels)
36 |
37 | return (loss_bs.sum())/ (bs * bs)
38 |
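  | # In-batch contrastive loss: student_logits is presumably a (bs, bs) similarity matrix whose
  | # diagonal holds the positive pairs, so row i is trained toward label i with temperature-scaled
  | # cross-entropy; the summed loss is divided by bs * bs.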
39 |
40 | def cal_loss_hardneg(args, teacher_logits, student_logits, temperature_teacher, temperature, nll_criterion):
41 |
42 | loss_hardneg_weight = args.alpha
43 |
44 | def softmax(X, temp):
45 | X = (X/temp).exp()
46 | res = X / (X.sum(-1, keepdims=True)+1e-20)
47 | return res
48 |
49 | bs = teacher_logits.size(0)
50 | neg_K = teacher_logits.size(1)-1
51 | teacher_logits = softmax(teacher_logits, temperature_teacher)[:,:, 0]
52 | teacher_logits[:, 1:] = 1 - teacher_logits[:, 1:]
53 | inputs = (softmax(student_logits*teacher_logits, temperature)).log()
54 | labels = torch.zeros(bs, dtype=torch.long, device=student_logits.device)
55 | loss_bs = nll_criterion(inputs, labels)
56 |
57 |
58 | loss_bs = loss_bs * loss_hardneg_weight
59 | return loss_bs.sum() / (bs * neg_K)
60 |
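  | # Hard-negative distillation: the teacher's two-way logits are softmaxed into a probability of
  | # being a positive; column 0 (the true positive) keeps p while the negative columns are flipped
  | # to 1 - p, and these weights rescale the student logits before an NLL loss toward index 0,
  | # weighted by alpha.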
61 |
62 | def cal_loss_rd(args, teacher_logits, student_logits, teacher_temperature):
63 |
64 | loss_pearson_weight = args.beta
65 |
66 | def softmax(X, temp):
67 | X = (X/temp).exp()
68 | res = X / (X.sum(-1, keepdims=True)+1e-20)
69 | return res
70 |
71 | def pearsonr(x,y,batch_first=True):
72 | assert x.shape == y.shape
73 | if batch_first:
74 | dim = -1
75 | else:
76 | dim = 0
77 | assert x.shape[dim] > 1
78 | centered_x = x - x.mean(dim=dim, keepdim=True)
79 | centered_y = y - y.mean(dim=dim, keepdim=True)
80 | covariance = (centered_x * centered_y).sum(dim=dim, keepdim=True)
81 | bessel_corrected_covariance = covariance / (x.shape[dim] - 1)
82 | x_std = x.std(dim=dim, keepdim=True)
83 | y_std = y.std(dim=dim, keepdim=True)
84 | corr = bessel_corrected_covariance / ((x_std * y_std)+1e-8)
85 | return corr
86 |
87 |
88 |
89 | bs = student_logits.size(0)
90 | teacher_logits = softmax(teacher_logits, teacher_temperature)[:,:, 0]
91 | spearson = pearsonr(student_logits, teacher_logits).squeeze()
92 |
93 | loss_bs = 1 - spearson
94 |
95 | loss_bs = loss_bs * loss_pearson_weight
96 |
97 | return loss_bs.sum() / bs
98 |
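  | # Rank distillation: pearsonr here is a differentiable per-row Pearson correlation between the
  | # student scores and the teacher's positive probabilities; the loss 1 - corr (scaled by beta)
  | # pushes the student to preserve the teacher's ranking of the positive and its hard negatives.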
99 |
100 |
101 | def cal_loss_rd2(args, teacher_logits_pos_hardneg, teacher_logits_pos_inbatch, teacher_temperature, student_logits_pos_hardneg, student_logits_pos_inbatch, sigmoid, scale_param):
102 |
103 | loss_bpr_weight = args.gamma
104 |
105 | def softmax(X, temp):
106 | X = (X/temp).exp()
107 | res = X / (X.sum(-1, keepdims=True)+1e-20)
108 | return res
109 |
110 |
111 | teacher_logits_pos_hardneg = softmax(teacher_logits_pos_hardneg, teacher_temperature)[:,:, 0]
112 | teacher_logits_pos_inbatch = softmax(teacher_logits_pos_inbatch, teacher_temperature)[:,:, 0]
113 |
114 | bs = student_logits_pos_hardneg.size(0)
115 | neg_K = student_logits_pos_hardneg.size(1)-1
116 | inbatch = student_logits_pos_inbatch.size(1)-1
117 | student_logits_hardneg = student_logits_pos_hardneg[:, 1:]
118 | eye = torch.eye(bs, dtype=torch.bool)
119 | student_logits_inbatch = student_logits_pos_inbatch[~eye].reshape(bs, -1)
120 | loss_hardneg_inbatch = -((sigmoid(student_logits_hardneg.view(bs, neg_K, 1).expand(-1, -1, inbatch).reshape(bs, -1) - student_logits_inbatch.unsqueeze(1).expand(-1, neg_K,-1).reshape(bs, -1))+1e-8).log())
121 | weight_hardneg_inbatch = teacher_logits_pos_hardneg[:, 1:].repeat_interleave(inbatch, dim=1) - teacher_logits_pos_inbatch[~eye].reshape(bs, -1).repeat((1, neg_K))  # the original referenced undefined teacher_logits_hardneg/teacher_logits_inbatch; this mirrors the student-side slicing above
122 | weight_hardneg_inbatch = torch.clamp(weight_hardneg_inbatch, min=0) / scale_param
123 | loss_bs = (loss_hardneg_inbatch * weight_hardneg_inbatch).sum(-1)
124 | loss_bs = loss_bs * loss_bpr_weight
125 |
126 | return loss_bs.sum() / (bs * neg_K * inbatch)
127 |
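  | # BPR-style distillation: every hard-negative score is compared against every in-batch negative
  | # score through sigmoid(s_hard - s_inbatch), weighted by the (clamped to >= 0) margin by which
  | # the teacher prefers that hard negative over the in-batch one, scaled by gamma and scale_param.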
128 |
129 | def cal_feat_loss(args, teacher_feat_cos, student_feature_pos_hardneg):
130 |
131 | loss_feat_weight = args.eta
132 | neg_K = teacher_feat_cos.size(1)
133 | student_feature_pos_hardneg = student_feature_pos_hardneg.transpose(0, 1)
134 | student_feature_pos_hardneg = student_feature_pos_hardneg / student_feature_pos_hardneg.norm(dim=-1, keepdim=True)
135 | student_feat_cos = torch.matmul(student_feature_pos_hardneg, student_feature_pos_hardneg.transpose(-2, -1))
136 | loss_bs = ((teacher_feat_cos - student_feat_cos) ** 2).sum((-1,-2))
137 |
138 | loss_bs = loss_bs * loss_feat_weight
139 |
140 | return loss_bs.sum() / (neg_K * neg_K)
141 |
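  | # Feature distillation: teacher_feat_cos and the student side are cosine-similarity Gram matrices
  | # over the (positive + hard negative) representations; their squared difference is summed and
  | # scaled by eta, aligning the geometry of the student embedding space with the teacher's.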
142 |
143 | def str2bool(v):
144 | return v.lower() in ('yes', 'true', 't', '1')
145 |
146 | def main():
147 |
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument('--base_model_dir', default='/mnt/user/415350/download_models/Qwen-7B-Chat', type=str, help='Model directory')
150 | parser.add_argument('--train_data_list', nargs='+')
151 | parser.add_argument('--pos_dir', default='PATH_TO_POS_LOGITS', type=str)
152 | parser.add_argument('--neg_dir', default='PATH_TO_HARDNEG_LOGITS', type=str)
153 | parser.add_argument('--data_dir', default='', type=str)
154 | parser.add_argument('--inbatch_pkl_path_dir', default='PATH_TO_INBATCH_LOGITS_PKL')
155 | parser.add_argument('--feature_pkl_path_dir', default='PATH_TO_FEATURE_PKL')
156 | parser.add_argument('--batch_size', default=32, type=int, help='bs')
157 | parser.add_argument('--neg_K', default=8, type=int, help='num of hard negs')
158 | parser.add_argument('--num_heads', default=32, type=int, help='num_heads of pma')
159 | parser.add_argument('--hidden_dim', default=512, type=int, help='hidden dim of my mlp')
160 | parser.add_argument('--output_dim', default=1, type=int, help='output dim of my mlp')
161 | parser.add_argument('--ln', default=True, type=str2bool, help='layer norm for pma')
162 | parser.add_argument('--norm', default=False, type=str2bool, help='norm after sentence pooling')
163 | parser.add_argument('--num_epochs', default=5, type=int, help='training epochs')
164 | parser.add_argument('--padding_side', default='right', type=str, help='padding side')
165 | parser.add_argument('--max_seq_length', default=250, type=int, help='max_seq_len')
166 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
167 | parser.add_argument('--alpha', default=1, type=float, help='trade-off param')
168 | parser.add_argument('--beta', default=1, type=float, help='trade-off param')
169 | parser.add_argument('--gamma', default=0.01, type=float, help='trade-off param')
170 | parser.add_argument('--eta', default=0.001, type=float, help='trade-off param')
171 | parser.add_argument('--temperature_in_batch', default=1, type=float, help='temperature in in-batch')
172 | parser.add_argument('--temperature_hardneg', default=1, type=float, help='temperature in hardneg')
173 | parser.add_argument('--temperature_teacher_hardneg', default=1, type=float, help='temperature in teacher logits')
174 | parser.add_argument('--scale_param', default=1, type=float, help='scale param')
175 | parser.add_argument('--log_interval', default=20, type=int)
176 | parser.add_argument('--eval_interval', default=200, type=int)
177 | parser.add_argument('--tb_dir', default='PATH_TO_TENSORBOARD_PATH', type=str)
178 | parser.add_argument('--patience', default=5, type=int)
179 | parser.add_argument('--num_ckpt', default=5, type=int)
180 | parser.add_argument('--training_log', default='PATH_TO_TRAINING_LOG')
181 | parser.add_argument('--output_dir', default='PATH_TO_OUTPUT_MODEL', type=str, help='Model output directory')
182 | parser.add_argument('--weight_decay', default=0.01, type=float, help='weight decay')
183 | parser.add_argument('--gradient_clipping', default=1.0, type=float, help='max_grad_norm')
184 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='gradient accumulation steps')
185 | parser.add_argument('--seed', default=2023, type=int)
186 | parser.add_argument('--bf16', default=True, type=str2bool)
187 | parser.add_argument('--verbose', default=True, type=str2bool)
188 | parser.add_argument('--device', default='cuda', type=str)
189 | parser.add_argument('--local_rank', type=int, default=-1, help='ds')
190 | parser.add_argument('--global_rank', type=int, default=-1, help='ds')
191 | parser = deepspeed.add_config_arguments(parser)
192 | args = parser.parse_args()
193 | args.world_size = int(os.getenv('WORLD_SIZE', '0'))
194 |
195 | sigmoid = nn.Sigmoid()
196 | tanh = nn.Tanh()
197 |
198 | os.makedirs(args.output_dir, exist_ok=True)
199 | logging.basicConfig(filename=f'{args.training_log}')
200 |
201 | if args.seed is not None:
202 | set_seed(args.seed)
203 | transformers.set_seed(args.seed)
204 |
205 | micro_bs = args.batch_size
206 |
207 | model = Mymodel(model_name_or_path=args.base_model_dir,
208 | alias=None,
209 | max_seq_length=args.max_seq_length,
210 | args=args)
211 | model.plm_model.gradient_checkpointing_enable()
212 |
213 | summary_writer = SummaryWriter(log_dir=args.tb_dir)
214 |
215 | train_data_flag = False
216 | lora_config = LoraConfig(
217 | r=8,
218 | lora_alpha=8,
219 | target_modules=['c_attn', 'c_proj', 'w1', 'w2'],
220 | layers_to_transform=list(range(0, 32)),
221 | lora_dropout=0.05,
222 | bias="none",
223 | inference_mode=False,
224 | task_type=TaskType.CAUSAL_LM
225 | )
226 | model.plm_model = get_peft_model(model.plm_model, lora_config)
227 |
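  | # LoRA is applied on top of the backbone; c_attn / c_proj / w1 / w2 are presumably the Qwen-7B
  | # attention and MLP projection module names, and layers_to_transform covers all 32 blocks. Only
  | # the LoRA weights plus the extra modules added inside Mymodel remain trainable below.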
228 | update_parameters = filter(lambda p: p.requires_grad, model.parameters())
229 | param_optimizer = list([(n,p) for n,p in model.named_parameters() if p.requires_grad])
230 |
231 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
232 | optimizer_grouped_parameters = [
233 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
234 | 'lr': args.lr, 'weight_decay': args.weight_decay, 'betas': [0.8,0.999], 'eps': 1e-6, 'name':'d'},
235 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
236 | 'lr': args.lr, 'weight_decay': 0.0, 'betas': [0.8,0.999], 'eps': 1e-6, 'name':'nd'}]
237 |
238 | ds_config = {
239 | "bfloat16": {
240 | "enabled": args.bf16
241 | },
242 | "zero_optimization": {
243 | "stage": 2,
244 | "offload_optimizer": {
245 | "device": "cpu",
246 | "pin_memory": True
247 | },
248 | "allgather_partitions": True,
249 | "allgather_bucket_size": 2e8,
250 | "overlap_comm": True,
251 | "reduce_scatter": True,
252 | "reduce_bucket_size": 2e8,
253 | "contiguous_gradients": True
254 | },
255 | "gradient_accumulation_steps": args.gradient_accumulation_steps,
256 | "gradient_clipping": args.gradient_clipping,
257 | "train_batch_size": args.world_size,
258 | "train_micro_batch_size_per_gpu": 1,
259 | "steps_per_print": 1e5
260 | }
261 |
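  | # train_micro_batch_size_per_gpu is set to 1 because each Dataset item is already a full batch of
  | # args.batch_size samples (see TrainDataset.__getitem__); DeepSpeed only sees this "fake" batch
  | # size (fake_bs below), while train_batch_size = world_size keeps its bookkeeping consistent.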
262 | fake_bs = ds_config['train_micro_batch_size_per_gpu']
263 | optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(optimizer_grouped_parameters)
264 | scheduler = deepspeed.runtime.lr_schedules.WarmupLR(optimizer, warmup_min_lr=[0,0], warmup_max_lr=[args.lr,args.lr],
265 | warmup_num_steps=1000)
266 |
267 | model_engine, optimizer, _, scheduler = deepspeed.initialize(args=args, model=model, model_parameters=update_parameters, optimizer=optimizer, lr_scheduler=scheduler, config=ds_config)
268 | device = torch.device(args.local_rank)
269 | args.device = device
270 | args.global_rank = torch.distributed.get_rank()
271 |
272 | train_dataset = TrainDataset(model.tokenizer, pos_dir=args.pos_dir, neg_dir=args.neg_dir, datadir=args.data_dir, names=args.train_data_list, batch_size=micro_bs, neg_K=args.neg_K, process_index=args.global_rank, num_processes=args.world_size)
273 | val_dataset = ValDataset(model.tokenizer, pos_dir=args.pos_dir, neg_dir=args.neg_dir, datadir=args.data_dir, names=args.train_data_list, batch_size=micro_bs, neg_K=args.neg_K, process_index=args.global_rank, num_processes=args.world_size)
274 |
275 | if args.global_rank == -1:
276 | train_sampler = RandomSampler(train_dataset)
277 | val_sampler = RandomSampler(val_dataset)
278 | else:
279 | train_sampler = DistributedSampler(train_dataset)
280 | val_sampler = DistributedSampler(val_dataset)
281 | train_dataloader = DataLoader(train_dataset, batch_size=fake_bs, shuffle=False, sampler=train_sampler,collate_fn=collate_fn, num_workers=0)
282 | val_dataloader = DataLoader(val_dataset, batch_size=fake_bs, shuffle=False, sampler=val_sampler,collate_fn=collate_fn, num_workers=0)
283 | if len(train_dataset) > 0:
284 | train_data_flag = True
285 |
286 | if not train_data_flag:
287 | raise ValueError("Error: no training data was loaded; check --train_data_list and --data_dir")
288 |
289 | all_dataset_id = train_dataset.dataset_id_dict
290 | all_dataset_id_reverse = {v:k for k, v in train_dataset.dataset_id_dict.items()}
291 | rel_dataset_id = [all_dataset_id[dataset_name] for dataset_name in args.train_data_list]
292 | os.makedirs(args.output_dir, exist_ok=True)
293 |
294 | train_loader_size = len(train_dataloader)
295 | val_loader_size = len(val_dataloader)
296 |
297 | criterion = nn.CrossEntropyLoss(reduction='none')
298 | nll_criterion = nn.NLLLoss(reduction='none')
299 |
300 | global_step = 0
301 | best_eval_metric = 0
302 | trained_epochs = 0
303 | min_reduce_loss_eval = float('inf')
304 | best_epoch = 0
305 | stop = 0
306 |
307 | teacher_feature_cos_dict = load_pickle(args.feature_pkl_path_dir)
308 | teacher_inbatch = load_pickle(args.inbatch_pkl_path_dir)
309 |
310 | reduce_loss = 0
311 | reduce_loss_eval = 0
312 | reduce_loss_in_batch = 0
313 | reduce_loss_in_batch_eval = 0
314 | reduce_loss_hardneg = 0
315 | reduce_loss_rd = 0
316 | reduce_loss_rd2 = 0
317 | reduce_loss_feat = 0
318 | reduce_inbatch_sample_num = {}
319 |
320 |
321 | for current_epoch in trange(int(args.num_epochs), desc="Epoch", disable=(args.global_rank!=0), mininterval=0):
322 | if stop >= args.patience:
323 | logging.info(f'Early Stop at {current_epoch+1}-th epoch {global_step}-th step')
324 | logging.info(f'Model trained!\nThe best model at {best_epoch+1}-th epoch {best_step}-th step')
325 | break
326 | torch.cuda.empty_cache()
327 | model_engine.train()
328 |
329 | loss_epoch_eval = 0
330 |
331 | batch_iterator = tqdm(train_dataloader,
332 | desc=f"Running Epoch {current_epoch + 1} of {args.num_epochs}",
333 | disable=(args.global_rank!=0),
334 | mininterval=0)
335 | for step, batch in enumerate(batch_iterator):
336 | sentence_a, sentence_b, logits_teacher_pos, sentence_hardneg, logits_teacher_hardneg, task_id = batch
337 | sentence_all = sentence_a + sentence_b + sentence_hardneg
338 | bs = logits_teacher_pos.size(0)
339 | key = 'global_rank' + str(args.global_rank)
340 | logits_teacher_inbatch = teacher_inbatch[key][step].to(device)
341 | feature_teacher_cos = teacher_feature_cos_dict[key][step].to(device)
342 |
343 | inputs_all = model.tokenizer(sentence_all, padding='max_length', max_length=args.max_seq_length, truncation=True, return_tensors='pt')
344 | inputs_all = inputs_all.to(device)
345 | task_id = task_id.to(device)
346 | logits_student_in_batch, logits_student_hardneg, rep_student_pos_hardneg = model_engine(inputs_all, task_id, 'train')
347 |
348 | loss_in_batch = cal_loss_in_batch(args, logits_student_in_batch, args.temperature_in_batch, criterion)
349 | logits_teacher_pos = logits_teacher_pos.to(args.device)
350 | logits_teacher_hardneg = logits_teacher_hardneg.reshape(micro_bs, args.neg_K, 2).to(args.device)
351 | logits_teacher_hardneg = torch.cat([logits_teacher_pos.unsqueeze(1), logits_teacher_hardneg], dim=1)
352 | loss_hardneg = cal_loss_hardneg(args, logits_teacher_hardneg, logits_student_hardneg, args.temperature_teacher_hardneg, args.temperature_hardneg, nll_criterion)
353 |
354 | loss_rd = cal_loss_rd(args, logits_teacher_hardneg, logits_student_hardneg, args.temperature_teacher_hardneg)
355 |
356 | loss_rd2 = cal_loss_rd2(args, logits_teacher_hardneg, logits_teacher_inbatch, args.temperature_teacher_hardneg, logits_student_hardneg, logits_student_in_batch, sigmoid, args.scale_param)
357 |
358 | loss_feat = cal_feat_loss(args, feature_teacher_cos, rep_student_pos_hardneg)
359 |
360 | loss_batch = loss_in_batch + loss_hardneg + loss_rd + loss_rd2 + loss_feat
361 | if args.verbose:
362 | batch_iterator.set_description(
363 | f"Epoch: {current_epoch + 1}/{args.num_epochs}, Batch:{step}/{len(train_dataloader)}, Loss: {loss_batch:9.4f}")
364 |
365 | model_engine.backward(loss_batch)
366 | model_engine.step()
367 |
368 | if (step + 1) % args.gradient_accumulation_steps == 0:
369 | global_step += 1
370 |
371 | reduce_loss += loss_batch.detach()
372 | reduce_loss_in_batch += loss_in_batch.detach()
373 | reduce_loss_hardneg += loss_hardneg.detach()
374 | reduce_loss_rd += loss_rd.detach()
375 | reduce_loss_rd2 += loss_rd2.detach()
376 | reduce_loss_feat += loss_feat.detach()
377 |
378 | if global_step % args.log_interval == 0:
379 | dist.all_reduce(reduce_loss, op=dist.ReduceOp.SUM)
380 | dist.all_reduce(reduce_loss_in_batch, op=dist.ReduceOp.SUM)
381 | dist.all_reduce(reduce_loss_hardneg, op=dist.ReduceOp.SUM)
382 | dist.all_reduce(reduce_loss_rd, op=dist.ReduceOp.SUM)
383 | dist.all_reduce(reduce_loss_rd2, op=dist.ReduceOp.SUM)
384 | dist.all_reduce(reduce_loss_feat, op=dist.ReduceOp.SUM)
385 |
386 | reduce_loss = reduce_loss.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
387 | reduce_loss_in_batch = reduce_loss_in_batch.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
388 | reduce_loss_hardneg = reduce_loss_hardneg.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
389 | reduce_loss_rd = reduce_loss_rd.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
390 | reduce_loss_rd2 = reduce_loss_rd2.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
391 | reduce_loss_feat = reduce_loss_feat.item() / (args.gradient_accumulation_steps * args.log_interval * args.world_size)
392 |
393 | if args.global_rank == 0:
394 | train_log_dict = {}
395 | train_log_dict['loss_overall'] = reduce_loss
396 | train_log_dict['loss_inbatch'] = reduce_loss_in_batch
397 | train_log_dict['loss_hardneg'] = reduce_loss_hardneg
398 | train_log_dict['loss_rd'] = reduce_loss_rd
399 | train_log_dict['loss_rd2'] = reduce_loss_rd2
400 | train_log_dict['loss_feat'] = reduce_loss_feat
401 | write_tensorboard(summary_writer, train_log_dict, global_step)
402 |
403 | reduce_loss = 0
404 | reduce_loss_hardneg = 0
405 | reduce_loss_rd = 0
406 | reduce_loss_rd2 = 0
407 | reduce_loss_feat = 0
408 | reduce_loss_in_batch = 0
409 |
410 | if global_step % args.eval_interval == 0:
411 | model_engine.eval()
412 | batch_iterator_eval = tqdm(val_dataloader,
413 | disable=(args.global_rank!=0),
414 | mininterval=0)
415 |
416 | with torch.no_grad():
417 | for step, batch in enumerate(batch_iterator_eval):
418 | sentence_a, sentence_b, _, _, _, task_id = batch
419 | sentence_all = sentence_a + sentence_b
420 | bs = task_id.size(0)
421 |
422 | key = 'global_rank' + str(args.global_rank)
423 |
424 | inputs_all = model.tokenizer(sentence_all, padding='max_length', max_length=args.max_seq_length, truncation=True, return_tensors='pt')
425 |
426 | inputs_all = inputs_all.to(device)
427 | task_id = task_id.to(device)
428 | logits_student_in_batch_eval, _, _ = model_engine(inputs_all, task_id, 'eval')
429 |
430 | loss_in_batch_eval = cal_loss_in_batch(args, logits_student_in_batch_eval, args.temperature_in_batch, criterion)
431 |
432 | loss_batch_eval = loss_in_batch_eval.detach()
433 | if args.verbose:
434 | batch_iterator_eval.set_description(
435 | f"Epoch: {current_epoch + 1}/{args.num_epochs}, Batch:{step}/{len(val_dataloader)}, Loss: {loss_batch_eval:9.4f}")
436 |
437 |
438 | reduce_loss_eval += loss_batch_eval
439 |
440 | dist.all_reduce(reduce_loss_eval, op=dist.ReduceOp.SUM)
441 | reduce_loss_eval = reduce_loss_eval.item() / (val_loader_size * args.world_size)
442 |
443 | if args.global_rank == 0:
444 | eval_log_dict = {'loss_eval':reduce_loss_eval}
445 | write_tensorboard(summary_writer, eval_log_dict, global_step)
446 |
447 | save_flag = False
448 |
449 | if stop >= args.patience:
450 | break
451 |
452 | if reduce_loss_eval <= min_reduce_loss_eval:
453 | min_reduce_loss_eval = reduce_loss_eval
454 | best_epoch = current_epoch
455 | best_step = global_step
456 | stop = 0
457 |
458 | path = args.output_dir
459 | start_name = 'checkpoint'
460 | current_step_num = global_step
461 | max_save_num = 2
462 | if args.global_rank == 0:
463 | print('removing')
464 | try:
465 | remove_earlier_ckpt(path, start_name, current_step_num, max_save_num)
466 | except Exception:
467 | print('No ckpt to remove.')
468 | else:
469 | stop += 1
470 |
471 | if stop < args.num_ckpt:
472 | save_flag = True
473 |
474 |
475 | if save_flag:
476 | output_dir_current = os.path.join(args.output_dir, "checkpoint-{}-epoch-{}".format(global_step, current_epoch+1))  # args defines no 'mark' attribute, so the extra tag is dropped from the checkpoint name
477 | client_sd = dict()
478 |
479 | save_model(model_engine, output_dir_current, client_state=client_sd)
480 |
481 | reduce_loss_eval = 0
482 | model_engine.train()
483 |
484 |
485 | if __name__ == '__main__':
486 | main()
487 |
--------------------------------------------------------------------------------