├── .DS_Store ├── assets └── overview.jpg ├── scripts ├── gen_advantageous_pos_for_R1.sh ├── eval_data.sh └── train_nq.sh ├── metrics.py ├── utils.py ├── create_datasets.py ├── gen_advantageous_pos_for_R1.py ├── requirements.txt ├── README.md ├── eval_data.py └── kd ├── customized_kd_trainer.py ├── train_kd.py ├── customized_sft_dataset.py └── loss_design.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/Pos2Distill/HEAD/.DS_Store -------------------------------------------------------------------------------- /assets/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AMAP-ML/Pos2Distill/HEAD/assets/overview.jpg -------------------------------------------------------------------------------- /scripts/gen_advantageous_pos_for_R1.sh: -------------------------------------------------------------------------------- 1 | total_doc=20 2 | num_gpus=8 3 | # export CUDA_VISIBLE_DEVICES=3 4 | dataset_name="nq" 5 | 6 | INPUT_PATH="nq-open-${total_doc}_total_documents.jsonl.gz" 7 | model_name=Mistral-7B-Instruct-v0.3 8 | dataset_name=nq 9 | python gen_advantageous_pos_for_R1.py \ 10 | --input-path "$INPUT_PATH" \ 11 | --model $model_name \ 12 | --output-path raw_data/$dataset_name \ 13 | --max-prompt-length 32768 \ 14 | --max-new-tokens 100 \ 15 | --num_gpus "$num_gpus" \ 16 | --total_doc "$total_doc" \ 17 | --cache_dir $cache_dir \ 18 | --gold_doc 0 19 | 20 | 21 | model_name=Mistral-7B-Instruct-v0.3 22 | dataset_name=nq 23 | python create_datasets.py \ 24 | --model_name $model_name \ 25 | --position_sample_num 4 \ 26 | --example_num 400 \ 27 | --dataset_name $dataset_name -------------------------------------------------------------------------------- /scripts/eval_data.sh: -------------------------------------------------------------------------------- 1 | total_docs=20 2 | num_gpus=8 3 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | # export CUDA_VISIBLE_DEVICES=4,5,6,7 5 | 6 | random_num=4 7 | strengthen=1 8 | INPUT_PATH="nq-open-${total_docs}_total_documents.jsonl.gz" 9 | # model_name=Mistral-7B-Instruct-v0.3_20total_docs_filter_4random_1strengthen_300_kd0.0_lm0.0_rank0.0_adaptive1.0_1.0 10 | # model_name=Mistral-7B-Instruct-v0.3_20total_docs_filter_4random_1strengthen_400_kd0.0_lm0.0_rank0.0_adaptive1.0_1.0 11 | # /mnt/workspace/wangyifei/miniconda3/envs/openrlhf/bin/python eval_data.py \ 12 | # --input-path "$INPUT_PATH" \ 13 | # --model $model_name \ 14 | # --output-path evaluate \ 15 | # --sample_num 500 \ 16 | # --max-prompt-length 32768 \ 17 | # --max-new-tokens 100 \ 18 | # --num_gpus "$num_gpus" \ 19 | # --total_doc "$total_docs" \ 20 | 21 | model_name=Mistral-7B-Instruct-v0.3 22 | INPUT_PATH="webq_dev.jsonl.gz" 23 | total_docs=20 24 | num_gpus=8 25 | python eval_data.py \ 26 | --input-path "$INPUT_PATH" \ 27 | --model $model_name \ 28 | --output-path evaluate \ 29 | --sample_num 500 \ 30 | --max-prompt-length 32768 \ 31 | --max-new-tokens 100 \ 32 | --num_gpus "$num_gpus" \ 33 | --total_doc "$total_docs" \ 34 | 35 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # Code for metrics comes from beam_retriever/blob/main/gpt_turbo_exp.py 2 | 3 | import re 4 | import string 5 | import collections 6 | 7 | def normalize_answer(s): 8 | 9 | def remove_articles(text): 10 | return re.sub(r'\b(a|an|the)\b', ' ', text) 11 | 12 | def white_space_fix(text): 13 | return ' '.join(text.split()) 14 | 15 | def remove_punc(text): 16 | exclude = set(string.punctuation) 17 | return ''.join(ch for ch in text if ch not in exclude) 18 | 19 | def lower(text): 20 | return text.lower() 21 | 22 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 23 | 24 | def get_tokens(s): 25 | if not s: 26 | return [] 27 | return normalize_answer(s).split() 28 | 29 | def compute_exact(a_gold, a_pred): 30 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 31 | 32 | def compute_subem(a_gold, a_pred): 33 | return int(normalize_answer(a_gold) in normalize_answer(a_pred)) 34 | 35 | def compute_f1(a_gold, a_pred): 36 | gold_toks = get_tokens(a_gold) 37 | pred_toks = get_tokens(a_pred) 38 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 39 | num_same = sum(common.values()) 40 | if len(gold_toks) == 0 or len(pred_toks) == 0: 41 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 42 | return int(gold_toks == pred_toks) 43 | if num_same == 0: 44 | return 0 45 | precision = 1.0 * num_same / len(pred_toks) 46 | recall = 1.0 * num_same / len(gold_toks) 47 | f1 = (2 * precision * recall) / (precision + recall) 48 | return f1 49 | 50 | # metrics for attribution 51 | 52 | def compute_attr_metrics(response, ground_truth): 53 | attr_matches = re.findall(r'\[(\d+)\]', response) 54 | attr_predicted = [int(match) for match in attr_matches] 55 | 56 | predicted_set = set(attr_predicted) 57 | ground_truth_set = set(ground_truth) 58 | 59 | true_positives = len(predicted_set & ground_truth_set) 60 | precision = true_positives / len(predicted_set) if predicted_set else 0 61 | recall = true_positives / len(ground_truth_set) if ground_truth_set else 0 62 | f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 63 | 64 | return precision, recall, f1 65 | 66 | def best_subspan_em(prediction, ground_truths) -> float: 67 | normalized_prediction = normalize_answer(prediction) 68 | 69 | for ground_truth in ground_truths: 70 | normalized_ground_truth = normalize_answer(ground_truth) 71 | if normalized_ground_truth.lower() in normalized_prediction.lower(): 72 | return 1.0 73 | return 0.0 -------------------------------------------------------------------------------- /scripts/train_nq.sh: -------------------------------------------------------------------------------- 1 | ############################ RUNNING CONFIG ############################ 2 | # export LD_PRELOAD="/usr/lib64/libjemalloc.so.2" 3 | # sudo mount -o size=20480M -o nr_inodes=1000000 -o noatime,nodiratime -o remount /dev/shm 4 | # if [ -f "/opt/conda/etc/profile.d/conda.sh" ]; then 5 | # . "/opt/conda/etc/profile.d/conda.sh" 6 | # else 7 | # export PATH="/opt/conda/bin:$PATH" 8 | # fi 9 | # conda activate /mnt/workspace/wangyifei/miniconda3/envs/openrlhf 10 | # workdir="/mnt/workspace/wangyifei/projects/self_training" 11 | # cd $workdir 12 | ############################ RUNNING CONFIG ############################ 13 | 14 | 15 | 16 | 17 | export CUDA_HOME=/usr/local/cuda-12.4 18 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64 19 | export PATH=${CUDA_HOME}/bin:${PATH} 20 | export NCCL_DEBUG=ERROR 21 | export NCCL_IB_DISABLE=0 22 | export NCCL_P2P_LEVEL=PIX 23 | export HF_ENDPOINT=https://hf-mirror.com 24 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 25 | # export CUDA_VISIBLE_DEVICES=4,5,6,7 26 | #Meta-Llama-3-8B-Instruct Qwen1.5-7B-Chat 0.1 0.2 0.3 0.4 0.0 0.6 0.7 0.8 0.9 1.0 27 | # /mnt/workspace/wangyifei/projects/self_training/kd/eval_kl.sh 28 | 29 | 30 | 31 | strengthen=1 32 | total_docs=20 33 | dataset_name="nq" 34 | for i in 0 ; do 35 | for model_name in Mistral-7B-Instruct-v0.3; do 36 | for num in 400; do 37 | for kd_coef in "0.0,0.0,0.0,1.0" ; do 38 | for K in 4 ; do 39 | deepspeed --master_port 6666 \ 40 | kd/train_kd.py \ 41 | --max_len 32000 \ 42 | --dataset ${model_name}_${total_docs}total_docs_filter_${K}random_${strengthen}strengthen_${num}\ 43 | --input_key question \ 44 | --output_key oracle_answer \ 45 | --train_batch_size 32 \ 46 | --micro_train_batch_size 4 \ 47 | --max_samples 500000 \ 48 | --pretrain $model_name \ 49 | --teacher_model $model_name \ 50 | --save_path checkpoints/${dataset_name}\ 51 | --save_steps -1 \ 52 | --logging_steps 1 \ 53 | --eval_steps -1 \ 54 | --zero_stage 3 \ 55 | --max_epochs 1 \ 56 | --l2 0.01 \ 57 | --bf16 \ 58 | --flash_attn \ 59 | --kd_coef $kd_coef \ 60 | --learning_rate 3e-6 \ 61 | --teacher_offload \ 62 | --apply_chat_template \ 63 | --gradient_checkpointing \ 64 | --perserve 1.0 \ 65 | --dataset_name nq \ 66 | --use_tensorboard tensorBoard 67 | # --use_wandb "c0c629e7ba14b453e7da5da4ff86f33816c5cc6a" # > training_kl2.log 2>&1 68 | done 69 | done 70 | done 71 | done 72 | done 73 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import string 3 | import re 4 | from metrics import compute_exact,best_subspan_em 5 | from typing import List 6 | import regex 7 | import argparse 8 | import json 9 | import logging 10 | import statistics 11 | import sys 12 | from copy import deepcopy 13 | import os 14 | from tqdm import tqdm 15 | from xopen import xopen 16 | from pydantic.dataclasses import dataclass 17 | from typing import List, Optional, Tuple, Type, TypeVar 18 | logger = logging.getLogger(__name__) 19 | def read_xopen(file_path): 20 | data = [] 21 | with xopen(file_path) as fin: 22 | for line in fin: 23 | example = json.loads(line) 24 | data.append(example) 25 | return data 26 | def write_xopen(file_path,data): 27 | with xopen(file_path, "w") as f: 28 | for d in data: 29 | f.write(json.dumps(d) + "\n") 30 | 31 | def get_template_prompt(ins,model_name,tokenizer): 32 | # assert tokenizer.chat_template is not None 33 | context = ''.join([f"- Title: {doc['title']}\n{doc['text']}\n" for doc in ins["documents"]]) 34 | task_instruction = "Please write a high-quantify answer for the given question using only the provided search documents (some of which might be irrelevant)." 35 | prompt_message = f"{task_instruction}\n{context}\nQuestion: {ins['question']}\n" 36 | system_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." 37 | messages = [ 38 | {"role": "user", "content": prompt_message}, 39 | ] 40 | 41 | if tokenizer.chat_template is not None: 42 | prompt = tokenizer.apply_chat_template( 43 | messages, 44 | tokenize=False, 45 | add_generation_prompt=True 46 | ) 47 | ins["prompt"] = prompt 48 | return ins 49 | 50 | 51 | def get_metrics_for_example(example,METRICS): 52 | gold_answers = example["answers"] 53 | model_answer = example["current_answer"] 54 | # NOTE: we take everything up to the first newline, since otherwise models could hack 55 | model_answer = model_answer.strip("\n").split("\n")[0].strip() 56 | example_metrics = {} 57 | for (metric, metric_name) in METRICS: 58 | example_metrics[metric_name] = metric(prediction=model_answer, ground_truths=gold_answers) 59 | return (example_metrics, example) 60 | 61 | 62 | 63 | def evaluate_qa_data(input_path,output_path=None,sample_num = None): 64 | METRICS = [(best_subspan_em, "best_subspan_em"),] 65 | all_examples = [] 66 | with xopen(input_path) as fin: 67 | for line in tqdm(fin): 68 | input_example = json.loads(line) 69 | all_examples.append(input_example) 70 | if sample_num: 71 | all_examples = all_examples[:sample_num] 72 | # Compute normal metrics in parallel, if applicable 73 | logger.info("Computing metrics") 74 | all_example_metrics = [] 75 | for example in tqdm(all_examples): 76 | all_example_metrics.append(get_metrics_for_example(example,METRICS)) 77 | # Average metrics across examples 78 | for (_, metric_name) in METRICS: 79 | average_metric_value = statistics.mean( 80 | example_metrics[metric_name] for (example_metrics, _) in all_example_metrics 81 | ) 82 | print(f"{metric_name}: {average_metric_value}") 83 | logger.info(f"{metric_name}: {average_metric_value}") 84 | 85 | # summary_path = os.path.join(os.path.dirname(input_path),"A_metrics_summary.txt") 86 | # with xopen(summary_path,"a") as f: 87 | # f.write(f"{input_path.split('/')[-1].split('.jsonl.gz')[0]}\n{metric_name}: {average_metric_value}\n\n") 88 | if output_path: 89 | with xopen(output_path, "w") as f: 90 | for (example_metrics, example) in all_example_metrics: 91 | example_with_metrics = deepcopy(example) 92 | for metric_name, metric_value in example_metrics.items(): 93 | example_with_metrics[f"metric_{metric_name}"] = metric_value 94 | f.write(json.dumps(example_with_metrics) + "\n") 95 | return average_metric_value -------------------------------------------------------------------------------- /create_datasets.py: -------------------------------------------------------------------------------- 1 | from utils import best_subspan_em,read_xopen,write_xopen 2 | from xopen import xopen 3 | import json 4 | from random import sample 5 | import random 6 | from copy import deepcopy 7 | from collections import defaultdict 8 | from datasets import Dataset 9 | from collections import Counter 10 | import itertools 11 | import os 12 | import numpy as np 13 | import re 14 | from metrics import compute_exact 15 | from argparse import ArgumentParser 16 | random.seed(42) 17 | def best_subspan_em_musique(r,answers): 18 | match = re.search(r'the answer is[::]?\s*(.*)', r.lower()) 19 | if match is None: 20 | return int(0) 21 | a_pred = match.group(1) 22 | qa_em_score = 0 23 | for a_gold in answers: 24 | qa_em_score = max(qa_em_score, compute_exact(a_gold,a_pred)) 25 | return int(qa_em_score) 26 | 27 | def chunk_list(lst, chunk_size): 28 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 29 | 30 | 31 | 32 | def get_chunked_answer(model_answer): 33 | # chunked_answer = model_answer.strip("\n").strip(" ").split(".")[0] + "." 34 | chunked_answer = model_answer 35 | print(chunked_answer) 36 | return chunked_answer 37 | 38 | def process_example(example,position_sample_num,strengthen=2,total_doc=20,train_counter=Counter()): 39 | sample_examples = [] 40 | gold_doc = example["gold_document"] 41 | sample_pool = list(range(1,total_doc)) if strengthen!=0 else list(range(total_doc+1)) 42 | random_pos = sample(sample_pool,position_sample_num) 43 | print(f"iteration1: {random_pos}") 44 | 45 | 46 | 47 | 48 | # 加上位置 0 的加强数据 49 | sample_positions = random_pos + [0] * strengthen 50 | for gold_idx in sample_positions: 51 | new_example = deepcopy(example) 52 | new_example["distractors"] = random.sample(example["distractors"],len(example["distractors"])) 53 | new_example["distractors"].insert(gold_idx,gold_doc) 54 | new_example["documents"]=deepcopy(new_example["distractors"]) 55 | new_example["distractors"]=example["distractors"] 56 | new_example["gold_idx"] = gold_idx 57 | assert len(new_example["documents"])==total_doc 58 | assert new_example["documents"] != new_example["distractors"] 59 | new_example["oracle_answer"]=get_chunked_answer(example["oracle_answer"]) 60 | sample_examples.append(new_example) 61 | train_counter[gold_idx] += 1 62 | 63 | return sample_examples,train_counter 64 | 65 | 66 | # dict_keys(['question', 'answers', 'ctxs', 'nq_annotated_gold', 'gold_document', 'distractors', 'documents', 'prompt', 'model_answer']) 67 | def generate_dataset(model_name,position_sample_num,total_doc=20,strengthen=2,train_size=200,position=0,filter=True,dataset_name="nq"): 68 | datasets_path = f"datasets/{dataset_name}" 69 | file_path = f"raw_data/{dataset_name}" 70 | file_name = f"{model_name}_{total_doc}docs_{position}.jsonl.gz" 71 | file_path = os.path.join(file_path,file_name) 72 | position_statistics = [] 73 | examples = read_xopen(file_path) 74 | train_datasets = defaultdict(list) 75 | print(f"example nums: {len(examples)}") 76 | 77 | 78 | # process train_examples 79 | cnt = 0 80 | train_counter = Counter({i: 0 for i in range(0,total_doc+1)}) 81 | 82 | for idx,example in enumerate(examples): 83 | if cnt == train_size: 84 | break 85 | is_correct = best_subspan_em(get_chunked_answer(example["oracle_answer"]),example["answers"]) 86 | # is_correct = best_subspan_em_musique(get_chunked_answer(example["oracle_answer"]),example["answers"]) 87 | print(f"is_correct: {is_correct}") 88 | if filter: 89 | if int(is_correct)==0: 90 | continue 91 | if "predicted_label" not in example: 92 | example["predicted_label"] = "correct" 93 | if "predicted" not in example: 94 | example["predicted"] = example["answers"][0] 95 | if example["predicted_label"] == "incorrect": 96 | continue 97 | if not example["predicted"]: 98 | continue 99 | cnt+=1 100 | sample_examples,train_counter = process_example(example,position_sample_num,strengthen,total_doc,train_counter) 101 | assert len(sample_examples) == (position_sample_num + strengthen) 102 | train_datasets["data"].extend(sample_examples) 103 | # train_datasets["position_statistics"].append(random_pos) 104 | # train_positions = list(itertools.chain.from_iterable(train_datasets["position_statistics"])) 105 | 106 | print(f"Train filter generate datapoints number: {len(train_datasets['data'])}") 107 | train_last_idx = idx 108 | assert len(train_datasets["data"]) == train_size*(position_sample_num + strengthen) 109 | filter_name = "filter" if filter else "unfilter" 110 | train_datasets_path = os.path.join(datasets_path,f"{model_name}_{total_doc}total_docs_{filter_name}_{position_sample_num}random_{strengthen}strengthen_{train_size}") 111 | train = Dataset.from_list(train_datasets["data"]) 112 | train.save_to_disk(train_datasets_path) 113 | # save position_statistics 114 | print(f"{model_name}_{file_name}_{position_sample_num}_{strengthen}: {train_counter}") 115 | 116 | 117 | position_statistics = {f"train_{file_name}":train_counter} 118 | os.makedirs(datasets_path+"/position_statistics",exist_ok=True) 119 | with open(datasets_path+f"/position_statistics/{model_name}.json", "w") as f: 120 | json.dump(position_statistics, f,indent=1) 121 | return 122 | 123 | 124 | if __name__=="__main__": 125 | parser = ArgumentParser() 126 | parser.add_argument("--model_name", type=str, default="Mistral-7B-Instruct-v0.3") 127 | parser.add_argument("--position_sample_num", type=int, default=4) 128 | parser.add_argument("--example_num", type=int, default=250) 129 | parser.add_argument("--dataset_name", type=str, default="nq") 130 | args = parser.parse_args() 131 | generate_dataset(args.model_name,position_sample_num=args.position_sample_num,total_doc=20,strengthen=1,train_size=args.example_num,position=0,filter=True,dataset_name=args.dataset_name) 132 | 133 | -------------------------------------------------------------------------------- /gen_advantageous_pos_for_R1.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoTokenizer, AutoModelForCausalLM 3 | import argparse 4 | import json 5 | import numpy as np 6 | import logging 7 | import pathlib 8 | import random 9 | import sys 10 | from copy import deepcopy 11 | import torch 12 | from tqdm import tqdm 13 | from transformers import AutoTokenizer,LlamaTokenizer 14 | from xopen import xopen 15 | from vllm import LLM, SamplingParams 16 | import pandas as pd 17 | import random 18 | import itertools 19 | from utils import * 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | 24 | 25 | def seed_everything(seed): 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | np.random.seed(seed) 29 | random.seed(seed) 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.deterministic = True 32 | torch.cuda.manual_seed_all(seed) 33 | 34 | seed_everything(42) 35 | def get_prompt_at_gold_index(input_path,gold_doc,model_name,total_doc,tokenizer=None): 36 | examples = [] 37 | prompts = [] 38 | with xopen(input_path) as fin: 39 | for line in tqdm(fin): 40 | input_example = json.loads(line) 41 | gold_document = input_example["gold_document"] 42 | distractors = input_example["distractors"] 43 | random.shuffle(distractors) 44 | # Get the prediction for the input example 45 | 46 | shuffled_distractors = distractors[:] 47 | shuffled_distractors.insert(gold_doc, gold_document) 48 | 49 | ins = deepcopy(input_example) 50 | ins["documents"] = shuffled_distractors if total_doc >0 else [] 51 | assert len(ins["documents"])==total_doc 52 | prompt = get_template_prompt(ins,model_name,tokenizer)["prompt"] 53 | prompts.append(prompt) 54 | examples.append(ins) 55 | return prompts,examples 56 | 57 | def write_responses(output_path,examples, responses,prompts): 58 | with xopen(output_path, "w") as f: 59 | for (example,response,prompt) in zip(examples,responses,prompts): 60 | example["oracle_answer"] = response 61 | example["oracle_prompt"] = prompt 62 | f.write(json.dumps(example) + "\n") 63 | average_metric_value = evaluate_qa_data(input_path=output_path) 64 | print(f"acc when gold doc at 0: {average_metric_value}") 65 | return average_metric_value 66 | 67 | def main( 68 | input_path, 69 | model_name, 70 | temperature, 71 | top_p, 72 | num_gpus, 73 | max_new_tokens, 74 | max_prompt_length, 75 | output_path, 76 | gold_doc, 77 | total_doc, 78 | cache_dir 79 | ): 80 | os.makedirs(output_path, exist_ok=True) 81 | output_path_gold_index = os.path.join(output_path,f"{model_name}_{total_doc}docs_{gold_doc}.jsonl.gz") 82 | model_path = os.path.join(cache_dir,model_name) 83 | model = LLM( 84 | model=os.path.join(cache_dir,model_name), 85 | tensor_parallel_size=num_gpus, 86 | # load_format="safetensors", 87 | # load_format="pt", 88 | max_num_batched_tokens=max_prompt_length, 89 | ) 90 | sampling_params = SamplingParams(temperature=0.7, max_tokens=max_new_tokens,top_p=top_p) 91 | output_path_gold_index = os.path.join(output_path,f"{model_name}_{total_doc}docs_{gold_doc}.jsonl.gz") 92 | if "mistral" in model_name.lower(): 93 | prompts,examples = get_prompt_at_gold_index(input_path,total_doc,model_name,total_doc,model.get_tokenizer()) 94 | else: 95 | prompts,examples = get_prompt_at_gold_index(input_path,0,model_name,total_doc,model.get_tokenizer()) 96 | raw_responses = model.generate(prompts, sampling_params) 97 | responses = [output.outputs[0].text.strip() for output in raw_responses] 98 | write_responses(output_path_gold_index,examples,responses,prompts) 99 | return 100 | 101 | 102 | 103 | 104 | if __name__ == "__main__": 105 | logging.basicConfig(format="%(asctime)s - %(module)s - %(levelname)s - %(message)s", level=logging.INFO) 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument("--input-path", help="Path to data with questions and documents to use.", required=True) 108 | parser.add_argument( 109 | "--model", 110 | help="Model to use in generating responses", 111 | required=True, 112 | choices=[ 113 | "Mistral-7B-Instruct-v0.3", 114 | "Qwen1.5-7B-Chat", 115 | "Meta-Llama-3-8B-Instruct", 116 | "Qwen2.5-7B-Instruct" 117 | ] 118 | ) 119 | parser.add_argument("--temperature", help="Temperature to use in generation", type=float, default=0.6) 120 | parser.add_argument("--num-gpus", help="Number of GPUs to use", type=int, default=1) 121 | parser.add_argument("--cache_dir", help="Path to huggingface cache to use.") 122 | parser.add_argument("--output-path", help="Path to write output file of generated responses", required=True) 123 | parser.add_argument( 124 | "--max-new-tokens", 125 | help="Maximum number of new tokens to generate", 126 | type=int, 127 | default=100, 128 | ) 129 | parser.add_argument( 130 | "--max-prompt-length", 131 | help="Maximum number of tokens in the prompt. Longer prompts will be skipped.", 132 | type=int, 133 | default=100, 134 | ) 135 | parser.add_argument( 136 | "--top_p", 137 | help="top_p", 138 | type=float, 139 | default=0.95, 140 | ) 141 | parser.add_argument( 142 | "--gold_doc", 143 | type=int, 144 | default=0, 145 | ) 146 | parser.add_argument( 147 | "--num_gpus", 148 | type=int, 149 | default=0, 150 | ) 151 | parser.add_argument( 152 | "--total_doc", 153 | type=int, 154 | ) 155 | 156 | args = parser.parse_args() 157 | 158 | logger.info("running %s", " ".join(sys.argv)) 159 | main( 160 | args.input_path, 161 | args.model, 162 | args.temperature, 163 | args.top_p, 164 | args.num_gpus, 165 | args.max_new_tokens, 166 | args.max_prompt_length, 167 | args.output_path, 168 | args.gold_doc, 169 | args.total_doc, 170 | args.cache_dir 171 | ) 172 | logger.info("finished running %s", sys.argv[0]) 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==1.5.1 3 | aclpubcheck @ git+https://github.com/acl-org/aclpubcheck@a340fc0a1a7d1c9808f08ab1dab1d228f63af405 4 | aiohappyeyeballs==2.6.1 5 | aiohttp==3.11.13 6 | aiohttp-cors==0.7.0 7 | aiosignal==1.3.2 8 | airportsdata==20250224 9 | annotated-types==0.7.0 10 | anyio==4.9.0 11 | argon2-cffi==23.1.0 12 | argon2-cffi-bindings==21.2.0 13 | arrow==1.3.0 14 | astor==0.8.1 15 | asttokens==3.0.0 16 | async-lru==2.0.5 17 | async-timeout==5.0.1 18 | attrs==25.3.0 19 | babel==2.17.0 20 | beautifulsoup4==4.13.3 21 | bibtexparser==1.4.3 22 | bitsandbytes==0.45.3 23 | black==25.1.0 24 | blake3==1.0.4 25 | bleach==6.2.0 26 | cachetools==5.5.2 27 | certifi==2025.1.31 28 | cffi==1.17.1 29 | charset-normalizer==3.4.1 30 | click==8.1.8 31 | cloudpickle==3.1.1 32 | colorful==0.5.6 33 | comm==0.2.2 34 | compressed-tensors==0.9.1 35 | contourpy==1.3.1 36 | cryptography==45.0.3 37 | cycler==0.12.1 38 | datasets==3.3.2 39 | debugpy==1.8.13 40 | decorator==5.2.1 41 | deepspeed==0.16.4 42 | defusedxml==0.7.1 43 | depyf==0.18.0 44 | dill==0.3.8 45 | diskcache==5.6.3 46 | distlib==0.3.9 47 | distro==1.9.0 48 | docker-pycreds==0.4.0 49 | einops==0.8.1 50 | et_xmlfile==2.0.0 51 | exceptiongroup==1.2.2 52 | executing==2.2.0 53 | fastapi==0.115.12 54 | fastchat==0.1.0 55 | fastjsonschema==2.21.1 56 | filelock==3.18.0 57 | flake8==7.2.0 58 | flash-attn==2.7.0.post2 59 | fonttools==4.57.0 60 | fqdn==1.5.1 61 | frozenlist==1.5.0 62 | fschat==0.2.36 63 | fsspec==2024.6.1 64 | gguf==0.10.0 65 | gitdb==4.0.12 66 | GitPython==3.1.44 67 | google-api-core==2.24.2 68 | google-auth==2.38.0 69 | googleapis-common-protos==1.69.1 70 | grpcio==1.71.0 71 | h11==0.14.0 72 | hf_transfer==0.1.9 73 | hjson==3.1.0 74 | httpcore==1.0.7 75 | httptools==0.6.4 76 | httpx==0.28.1 77 | huggingface-hub==0.29.3 78 | idna==3.10 79 | importlib_metadata==8.6.1 80 | iniconfig==2.1.0 81 | interegular==0.3.3 82 | ipykernel==6.29.5 83 | ipython==8.34.0 84 | ipywidgets==8.1.5 85 | isal==1.7.2 86 | isoduration==20.11.0 87 | isort==6.0.1 88 | jedi==0.19.2 89 | Jinja2==3.1.4 90 | jiter==0.9.0 91 | json5==0.12.0 92 | jsonlines==4.0.0 93 | jsonpointer==3.0.0 94 | jsonschema==4.23.0 95 | jsonschema-specifications==2024.10.1 96 | jupyter==1.1.1 97 | jupyter-console==6.6.3 98 | jupyter-events==0.12.0 99 | jupyter-lsp==2.2.5 100 | jupyter_client==8.6.3 101 | jupyter_core==5.7.2 102 | jupyter_server==2.15.0 103 | jupyter_server_terminals==0.5.3 104 | jupyterlab==4.3.6 105 | jupyterlab_pygments==0.3.0 106 | jupyterlab_server==2.27.3 107 | jupyterlab_widgets==3.0.13 108 | kagglehub==0.3.12 109 | kiwisolver==1.4.8 110 | lark==1.2.2 111 | latexcodec==3.0.0 112 | lightning-utilities==0.14.0 113 | lm-format-enforcer==0.10.11 114 | loralib==0.1.2 115 | Markdown==3.7 116 | markdown-it-py==3.0.0 117 | MarkupSafe==2.1.5 118 | matplotlib==3.10.1 119 | matplotlib-inline==0.1.7 120 | mccabe==0.7.0 121 | mistral_common==1.5.4 122 | mistune==3.1.3 123 | modelscope==1.25.0 124 | mpmath==1.3.0 125 | msgpack==1.1.0 126 | msgspec==0.19.0 127 | multidict==6.1.0 128 | multiprocess==0.70.16 129 | mypy_extensions==1.1.0 130 | nbclient==0.10.2 131 | nbconvert==7.16.6 132 | nbformat==5.10.4 133 | nest-asyncio==1.6.0 134 | networkx==3.3 135 | nh3==0.2.21 136 | ninja==1.11.1.3 137 | notebook==7.3.3 138 | notebook_shim==0.2.4 139 | numpy==1.26.4 140 | nvidia-cublas-cu12==12.4.5.8 141 | nvidia-cuda-cupti-cu12==12.4.127 142 | nvidia-cuda-nvrtc-cu12==12.4.127 143 | nvidia-cuda-runtime-cu12==12.4.127 144 | nvidia-cudnn-cu12==9.1.0.70 145 | nvidia-cufft-cu12==11.2.1.3 146 | nvidia-curand-cu12==10.3.5.147 147 | nvidia-cusolver-cu12==11.6.1.9 148 | nvidia-cusparse-cu12==12.3.1.170 149 | nvidia-ml-py==12.570.86 150 | nvidia-nccl-cu12==2.21.5 151 | nvidia-nvjitlink-cu12==12.4.127 152 | nvidia-nvtx-cu12==12.4.127 153 | nvitop==1.4.2 154 | openai==1.75.0 155 | opencensus==0.11.4 156 | opencensus-context==0.1.3 157 | opencv-python-headless==4.11.0.86 158 | openpyxl==3.1.5 159 | openrlhf==0.6.1.post1 160 | optimum==1.24.0 161 | outlines==0.1.11 162 | outlines_core==0.1.26 163 | overrides==7.7.0 164 | packaging==24.2 165 | pandas==2.2.3 166 | pandocfilters==1.5.1 167 | parso==0.8.4 168 | partial-json-parser==0.2.1.1.post5 169 | pathspec==0.12.1 170 | pdfminer.six==20250327 171 | pdfplumber==0.11.6 172 | peft==0.14.0 173 | pexpect==4.9.0 174 | pillow==11.0.0 175 | platformdirs==4.3.6 176 | pluggy==1.5.0 177 | prometheus-fastapi-instrumentator==7.1.0 178 | prometheus_client==0.21.1 179 | prompt_toolkit==3.0.50 180 | propcache==0.3.0 181 | proto-plus==1.26.1 182 | protobuf==5.29.3 183 | psutil==7.1.2 184 | ptyprocess==0.7.0 185 | pure_eval==0.2.3 186 | py-cpuinfo==9.0.0 187 | py-spy==0.4.0 188 | pyarrow==19.0.1 189 | pyasn1==0.6.1 190 | pyasn1_modules==0.4.1 191 | pybtex==0.24.0 192 | pycodestyle==2.13.0 193 | pycountry==24.6.1 194 | pycparser==2.22 195 | pydantic==2.10.6 196 | pydantic_core==2.27.2 197 | pyflakes==3.3.2 198 | Pygments==2.19.1 199 | pylatexenc==2.10 200 | pynvml==12.0.0 201 | pyparsing==3.2.3 202 | pypdfium2==4.30.1 203 | pytest==8.3.5 204 | python-dateutil==2.9.0.post0 205 | python-dotenv==1.1.0 206 | python-json-logger==3.3.0 207 | pytz==2025.1 208 | PyYAML==6.0.2 209 | pyzmq==26.4.0 210 | ray==2.42.0 211 | rebiber==1.1.3 212 | referencing==0.36.2 213 | regex==2024.11.6 214 | requests==2.32.3 215 | rewardbench==0.1.3 216 | rfc3339-validator==0.1.4 217 | rfc3986-validator==0.1.1 218 | rich==14.0.0 219 | rpds-py==0.23.1 220 | rsa==4.9 221 | safetensors==0.5.3 222 | scipy==1.15.3 223 | seaborn==0.13.2 224 | Send2Trash==1.8.3 225 | sentencepiece==0.2.0 226 | sentry-sdk==2.22.0 227 | setproctitle==1.3.5 228 | six==1.17.0 229 | smart-open==7.1.0 230 | smmap==5.0.2 231 | sniffio==1.3.1 232 | soupsieve==2.6 233 | stack-data==0.6.3 234 | starlette==0.46.2 235 | sympy==1.13.1 236 | tabulate==0.9.0 237 | tensorboard==2.19.0 238 | tensorboard-data-server==0.7.2 239 | termcolor==3.1.0 240 | terminado==0.18.1 241 | tiktoken==0.6.0 242 | tinycss2==1.4.0 243 | tokenizers==0.21.1 244 | tomli==2.2.1 245 | torch==2.5.1+cu124 246 | torchaudio==2.5.1+cu124 247 | torchmetrics==1.6.3 248 | torchvision==0.20.1+cu124 249 | tornado==6.4.2 250 | tqdm==4.67.1 251 | traitlets==5.14.3 252 | transformers==4.48.3 253 | transformers-stream-generator==0.0.5 254 | triton==3.1.0 255 | trl==0.12.2 256 | tsv==1.2 257 | types-python-dateutil==2.9.0.20241206 258 | typing_extensions==4.12.2 259 | tzdata==2025.1 260 | Unidecode==1.4.0 261 | uri-template==1.3.0 262 | urllib3==2.3.0 263 | uvicorn==0.34.2 264 | uvloop==0.21.0 265 | virtualenv==20.29.3 266 | vllm==0.7.2 267 | wandb==0.19.8 268 | watchfiles==1.0.5 269 | wcwidth==0.2.13 270 | webcolors==24.11.1 271 | webencodings==0.5.1 272 | websocket-client==1.8.0 273 | websockets==15.0.1 274 | Werkzeug==3.1.3 275 | widgetsnbextension==4.0.13 276 | wrapt==1.17.2 277 | xformers==0.0.28.post3 278 | xgrammar==0.1.18 279 | xopen==2.0.2 280 | xxhash==3.5.0 281 | yarl==1.18.3 282 | zipp==3.21.0 283 | zlib-ng==0.5.1 284 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

POSITION BIAS MITIGATES POSITION BIAS: Mitigate Position Bias Through Inter-Position Knowledge Distillation

3 |

4 | 5 | Paper PDF 6 | 7 |

8 |

9 | 10 | 11 | 12 | We propose **Pos2Distill**, a novel position to position knowledge distillation framework, transferring knowledge from advantageous positions to rectify responses at unfavorable ones, therefore mitigating position bias naturally. 13 | ![](./assets/overview.jpg) 14 |
15 | Tabel of Contents 16 |
    17 |
  1. 18 | Installation 19 |
  2. 20 |
  3. 21 | Data Preparation 22 |
  4. 23 |
  5. 24 | Training 25 |
  6. 26 |
  7. 27 | Evaluation 28 |
  8. 29 |
  9. 30 | Citation 31 |
  10. 32 |
  11. 33 | Acknowledgement 34 |
  12. 35 |
36 |
37 | 38 | ## News 39 | - [2025-11-03] We have released our test code, data, and model weights for Pos2Distill. 40 | 41 | ## Installation 42 | Create a conda environment and install the required packages: 43 | ```bash 44 | conda create -n pos2distill python=3.10 45 | conda activate pos2distill 46 | 47 | git clone https://github.com/AMAP-ML/Pos2Distill.git 48 | cd Pos2Distill 49 | pip install -r requirements.txt 50 | ``` 51 | Our knowledge distillation freamwork is based on [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF), an open-source RLHF framework. 52 | ```bash 53 | pip install openrlhf==0.6.1.post1 54 | ``` 55 | ## Data Preparation 56 | We provide a NaturalQuestions (20docs) in [google drive](): 57 | - Firstly, we need to generate high quality responses from advantagous positions through **gen_advantageous_pos_for_R1.sh**, the resulting responses are saved in raw_data/$dataset_name directory. 58 | - Second, we need construct advantagous-trivia pairs for knowledge distillation in the training stage: 59 | - example_num: how many different datapoints are used; 60 | - position_sample_num: how many trivial positions are sampled. 61 | ```bash 62 | model_name=Mistral-7B-Instruct-v0.3 63 | dataset_name=nq 64 | python create_datasets.py \ 65 | --model_name $model_name \ 66 | --position_sample_num 4 \ 67 | --example_num 400 \ 68 | --dataset_name $dataset_name 69 | ``` 70 | 71 | ## Training 72 | We train our model through the script train_nq.sh: kd_coef controls which training approach you choose, "0.0,0.0,0.0,1.0" means you train models through Pos2Distill. 73 | ```bash 74 | kd_coef="0.0,0.0,0.0,1.0" 75 | deepspeed --master_port 6666 \ 76 | kd/train_kd.py \ 77 | --max_len 32000 \ 78 | --dataset ${model_name}_${total_docs}total_docs_filter_${K}random_${strengthen}strengthen_${num}\ 79 | --input_key question \ 80 | --output_key oracle_answer \ 81 | --train_batch_size 32 \ 82 | --micro_train_batch_size 4 \ 83 | --max_samples 500000 \ 84 | --pretrain $model_name \ 85 | --teacher_model $model_name \ 86 | --save_path checkpoints/${dataset_name}\ 87 | --save_steps -1 \ 88 | --logging_steps 1 \ 89 | --eval_steps -1 \ 90 | --zero_stage 3 \ 91 | --max_epochs 1 \ 92 | --l2 0.01 \ 93 | --bf16 \ 94 | --flash_attn \ 95 | --kd_coef $kd_coef \ 96 | --learning_rate 3e-6 \ 97 | --teacher_offload \ 98 | --apply_chat_template \ 99 | --gradient_checkpointing \ 100 | --perserve 1.0 \ 101 | --dataset_name $dataset_name \ 102 | --use_tensorboard tensorBoard 103 | ``` 104 | ## Evaluation 105 | 106 | 1. Currently, the model weights are saved in checkpoints directory. You can run the test code (eval_data.sh) using the command below to evaluate the 107 | performance of Pos2Distill on NaturalQuestion datasets. 108 | ```bash 109 | total_docs=20 110 | num_gpus=8 111 | INPUT_PATH="nq-open-${total_docs}_total_documents.jsonl.gz" 112 | random_num=4 113 | strengthen=1 114 | model_name=Mistral-7B-Instruct-v0.3_20total_docs_filter_4random_1strengthen_400_kd0.0_lm0.0_rank0.0_adaptive1.0_1.0 115 | python eval_data.py \ 116 | --input-path "$INPUT_PATH" \ 117 | --model $model_name \ 118 | --output-path evaluate \ 119 | --sample_num 500 \ 120 | --max-prompt-length 32768 \ 121 | --max-new-tokens 100 \ 122 | --num_gpus "$num_gpus" \ 123 | --total_doc "$total_docs" 124 | ``` 125 | 2. If you want evaluate Pos2Distill on webq datasets (20 docs), run code below: 126 | ```bash 127 | INPUT_PATH="webq_dev.jsonl.gz" 128 | python eval_data.py \ 129 | --input-path "$INPUT_PATH" \ 130 | --model $model_name \ 131 | --output-path evaluate \ 132 | --sample_num 500 \ 133 | --max-prompt-length 32768 \ 134 | --max-new-tokens 100 \ 135 | --num_gpus "$num_gpus" \ 136 | --total_doc "$total_docs" 137 | ``` 138 | 3. If you want evaluate Pos2Distill on tqa datasets (20 docs), run code below: 139 | ```bash 140 | INPUT_PATH="tqa_dev.jsonl.gz" 141 | python eval_data.py \ 142 | --input-path "$INPUT_PATH" \ 143 | --model $model_name \ 144 | --output-path evaluate \ 145 | --sample_num 500 \ 146 | --max-prompt-length 32768 \ 147 | --max-new-tokens 100 \ 148 | --num_gpus "$num_gpus" \ 149 | --total_doc "$total_docs" 150 | ``` 151 | 152 | ## Acknowledgement 153 | We would like to thank the following works for their code and models: 154 | - Training: [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF), [TRL](https://huggingface.co/docs/trl/sft_trainer) 155 | - Datasets: [NaturalQuestions](https://github.com/nelson-liu/lost-in-the-middle), [TQA](https://huggingface.co/datasets/vsearch/tqa), [WEBQA](https://huggingface.co/datasets/vsearch/webq) 156 | 157 | We are extremely grateful to **Linjing Li**, **Yong Wang**, **Xiangxiang Chu** and many other friends in our Machine Learning Group for their helpful feedback and insightful discussions. 158 | 159 | ## Citation 160 | If you feel this project is helpful, please consider cite our report :blush:. If you have any question about us, please contact the email: [Yifei Wang](https://scholar.google.com/citations?hl=zh-CN&user=AHD4c24AAAAJ). 161 | ```bibtex 162 | @article{wang2025position, 163 | title={Position Bias Mitigates Position Bias: Mitigate Position Bias Through Inter-Position Knowledge Distillation}, 164 | author={Wang, Yifei and Xiong, Feng and Wang, Yong and Li, Linjing and Chu, Xiangxiang and Zeng, Daniel Dajun}, 165 | journal={arXiv preprint arXiv:2508.15709}, 166 | year={2025} 167 | } 168 | ``` 169 | -------------------------------------------------------------------------------- /eval_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import logging 5 | import pathlib 6 | import random 7 | import sys 8 | from copy import deepcopy 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer,LlamaTokenizer 12 | from xopen import xopen 13 | from vllm import LLM, SamplingParams 14 | import pandas as pd 15 | import random 16 | import numpy as np 17 | from utils import * 18 | from transformers import AutoModelForCausalLM, AutoTokenizer 19 | from gen_advantageous_pos_for_R1 import seed_everything,get_prompt_at_gold_index 20 | cache_dir="/mnt/workspace/wangyifei/projects/huggingcache" # use your cache dir of model weights 21 | seed = random.randint(0, 2**32 - 1) 22 | seed_everything(seed) 23 | def write_responses(output_path,examples, responses,prompts): 24 | with xopen(output_path, "w") as f: 25 | for (example,response,prompt) in zip(examples,responses,prompts): 26 | is_correct = int(best_subspan_em(response.split("\n")[0],example["answers"])) 27 | example["current_answer"] = response 28 | example["current_prompt"] = prompt 29 | f.write(json.dumps(example) + "\n") 30 | average_metric_value = evaluate_qa_data(input_path=output_path) 31 | print(f"average_metric_value: {average_metric_value}") 32 | return average_metric_value 33 | 34 | 35 | def main( 36 | input_path, 37 | model_name, 38 | temperature, 39 | top_p, 40 | num_gpus, 41 | max_new_tokens, 42 | max_prompt_length, 43 | output_path, 44 | sample_num, 45 | gold_doc, 46 | total_doc, 47 | dataset_name 48 | 49 | ): 50 | 51 | # os.makedirs(f"{output_path}/{total_doc}",exist_ok=True) 52 | os.makedirs(output_path,exist_ok=True) 53 | logger.info(f"model name {model_name}") 54 | if model_name.lower().endswith("instruct") or model_name.lower().endswith("chat") or model_name.lower().endswith("v0.3") or model_name.lower().endswith("7b") or model_name.lower().endswith("lite") or model_name.lower().endswith("hf"): 55 | model_path = os.path.join(cache_dir,model_name) 56 | else: 57 | model_path = os.path.join(f"checkpoints/nq",model_name) 58 | model = LLM( 59 | model=model_path, 60 | tensor_parallel_size=num_gpus, 61 | # load_format="safetensors", 62 | # load_format="pt", 63 | trust_remote_code=True, 64 | max_num_batched_tokens=max_prompt_length, 65 | ) 66 | sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens,top_p=top_p,seed=seed) 67 | 68 | results_average = {"Model_name": model_name} 69 | results_std = {"Model_name": model_name} 70 | for i in range(0,total_doc+1,5): 71 | output_path_gold_index = os.path.join(output_path, f"{model_name}_{i}.jsonl.gz") 72 | average_iterations = 1 73 | average_metric_values = [] 74 | for j in range(average_iterations): 75 | prompts, examples = get_prompt_at_gold_index(input_path, i, model_name, total_doc, model.get_tokenizer()) 76 | prompts = prompts[-sample_num:] 77 | examples = examples[-sample_num:] 78 | if i==0: 79 | # compute average tokens 80 | stats = [] 81 | for prompt in prompts: 82 | tok=model.get_tokenizer() 83 | tok.pad_token = tok.eos_token 84 | prompt_len = tok(prompt,padding=True,add_special_tokens=False,return_tensors="pt")["input_ids"].size()[1] 85 | stats.append(prompt_len) 86 | logger.info(f"tokens stats: max={max(stats)} min={min(stats)} avg={np.mean(stats)}") 87 | sampling_params.seed+=1 88 | raw_responses = model.generate(prompts, sampling_params) 89 | responses = [output.outputs[0].text.strip() for output in raw_responses] 90 | 91 | 92 | average_metric_value = write_responses(output_path_gold_index, examples, responses,prompts) 93 | average_metric_values.append(average_metric_value) 94 | # 记录均值和标准差 95 | results_average[i] = np.mean(average_metric_values) 96 | results_std[i] = np.std(average_metric_values) 97 | # 转换为 DataFrame 98 | df_avg = pd.DataFrame([results_average]) 99 | file_path = f"{dataset_name}_{total_doc}.xlsx" 100 | if os.path.exists(file_path): 101 | existing_df = pd.read_excel(file_path, engine='openpyxl') 102 | combined_df = pd.concat([existing_df, df_avg], ignore_index=True) 103 | else: 104 | combined_df = pd.concat([df_avg], ignore_index=True) 105 | combined_df.to_excel(file_path, index=False, engine='openpyxl') 106 | print(f"Results saved to {file_path}") 107 | return 108 | 109 | 110 | 111 | 112 | if __name__ == "__main__": 113 | logging.basicConfig(format="%(asctime)s - %(module)s - %(levelname)s - %(message)s", level=logging.INFO) 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--input-path", help="Path to data with questions and documents to use.", required=True) 116 | parser.add_argument( 117 | "--model", 118 | help="Model to use in generating responses", 119 | required=True, 120 | ) 121 | parser.add_argument("--temperature", help="Temperature to use in generation", type=float, default=0.6) 122 | parser.add_argument("--top-p", help="Top-p to use in generation", type=float, default=1.0) 123 | parser.add_argument("--num-gpus", help="Number of GPUs to use", type=int, default=1) 124 | parser.add_argument("--output-path", help="Path to write output file of generated responses", required=True) 125 | parser.add_argument( 126 | "--max-new-tokens", 127 | help="Maximum number of new tokens to generate", 128 | type=int, 129 | default=100, 130 | ) 131 | parser.add_argument( 132 | "--max-prompt-length", 133 | help="Maximum number of tokens in the prompt. Longer prompts will be skipped.", 134 | type=int, 135 | default=4096, 136 | ) 137 | parser.add_argument( 138 | "--sample_num", 139 | help="sample size", 140 | type=int, 141 | default=500, 142 | ) 143 | parser.add_argument( 144 | "--top_p", 145 | help="top_p", 146 | type=float, 147 | default=1.0, 148 | ) 149 | parser.add_argument( 150 | "--gold_doc", 151 | type=int, 152 | default=0, 153 | ) 154 | parser.add_argument( 155 | "--num_gpus", 156 | type=int, 157 | default=0, 158 | ) 159 | parser.add_argument( 160 | "--total_doc", 161 | type=int, 162 | ) 163 | parser.add_argument( 164 | "--dataset_name", 165 | type=str, 166 | ) 167 | 168 | args = parser.parse_args() 169 | 170 | logger.info("running %s", " ".join(sys.argv)) 171 | main( 172 | args.input_path, 173 | args.model, 174 | args.temperature, 175 | args.top_p, 176 | args.num_gpus, 177 | args.max_new_tokens, 178 | args.max_prompt_length, 179 | args.output_path, 180 | args.sample_num, 181 | args.gold_doc, 182 | args.total_doc, 183 | args.dataset_name 184 | 185 | ) 186 | logger.info("finished running %s", sys.argv[0]) 187 | -------------------------------------------------------------------------------- /kd/customized_kd_trainer.py: -------------------------------------------------------------------------------- 1 | from openrlhf.trainer import KDTrainer 2 | from typing import Optional, Tuple 3 | from tqdm import tqdm 4 | import torch 5 | import torch.distributed as dist 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | from abc import ABC 10 | import torch 11 | 12 | from torch.optim import Optimizer 13 | from tqdm import tqdm 14 | from openrlhf.models import GPTLMLoss, KDLoss 15 | from openrlhf.utils.distributed_sampler import DistributedSampler 16 | import logging 17 | from typing import Optional, Tuple 18 | import torch 19 | import torch.distributed as dist 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from loss_design import AdaptiveKLWeightedKLLoss,MY_KDLoss,MY_rankLoss,MY_topkLoss,MY_GPTLMLoss,AdaptiveKLWeightedKLLoss_multi 23 | logger = logging.getLogger('customized kd trainer') 24 | 25 | class MY_KDTrainer(KDTrainer): 26 | def __init__(self, *args, **kwargs): 27 | attention_dillution_coef = kwargs.pop("attention_dillution_coef", 1.0) 28 | multi_gold_doc = kwargs.pop("multi_gold_doc", False) 29 | super().__init__(*args, **kwargs) 30 | self.kd_loss = MY_KDLoss() 31 | self.rank_loss = MY_topkLoss() 32 | self.loss_fn = MY_GPTLMLoss() 33 | self.adaptive_kd_loss = AdaptiveKLWeightedKLLoss(batch_norm=True,attention_dillution_coef = attention_dillution_coef) if not multi_gold_doc else AdaptiveKLWeightedKLLoss_multi(batch_norm=True,attention_dillution_coef = attention_dillution_coef) 34 | # self.loss_fn = MY_rankLoss() 35 | 36 | def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): 37 | # get eval and save steps 38 | if args.eval_steps == -1: 39 | args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch 40 | if args.save_steps == -1: 41 | args.save_steps = float("inf") # do not save ckpt 42 | 43 | # Restore step and start_epoch 44 | step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 45 | start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch 46 | consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) 47 | 48 | epoch_bar = tqdm( 49 | range(start_epoch, self.epochs), 50 | desc="Train epoch", 51 | disable=not self.strategy.is_rank_0(), 52 | ) 53 | loss_sum = 0 54 | for epoch in range(start_epoch, self.epochs): 55 | if isinstance(self.train_dataloader.sampler, DistributedSampler): 56 | self.train_dataloader.sampler.set_epoch( 57 | epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples 58 | ) 59 | step_bar = tqdm( 60 | range(self.train_dataloader.__len__()), 61 | desc="Train step of epoch %d" % epoch, 62 | disable=not self.strategy.is_rank_0(), 63 | ) 64 | # train 65 | self.model.train() 66 | self.teacher_model.eval() 67 | 68 | for teacher_prompt_id_lens,teacher_inputs,teacher_attention_masks, student_prompt_id_lens,student_inputs,student_attention_masks,infos in self.train_dataloader: 69 | student_inputs = student_inputs.squeeze(1).to(torch.cuda.current_device()) 70 | student_attention_mask = student_attention_masks.squeeze(1).to(torch.cuda.current_device()) 71 | student_output = self.model(student_inputs, attention_mask=student_attention_mask, return_output=True) 72 | student_labels = torch.where( 73 | student_attention_mask.bool(), 74 | student_inputs, 75 | self.loss_fn.IGNORE_INDEX, 76 | ) 77 | teacher_inputs = teacher_inputs.squeeze(1).to(torch.cuda.current_device()) 78 | teacher_attention_mask = teacher_attention_masks.squeeze(1).to(torch.cuda.current_device()) 79 | teacher_labels = torch.where( 80 | teacher_attention_mask.bool(), 81 | teacher_inputs, 82 | self.loss_fn.IGNORE_INDEX, 83 | ) 84 | if not self.pretrain_mode: 85 | for label, source_len in zip(student_labels, student_prompt_id_lens): 86 | label[:source_len] = self.loss_fn.IGNORE_INDEX 87 | 88 | for label, source_len in zip(teacher_labels, teacher_prompt_id_lens): 89 | label[:source_len] = self.loss_fn.IGNORE_INDEX 90 | if args.kd_coef[1] > 0: 91 | gpt_loss = self.loss_fn(student_output.logits, student_labels) 92 | else: 93 | gpt_loss = torch.tensor(0.0).to(student_output.logits.device) 94 | 95 | with torch.no_grad(): 96 | teacher_logits = self.teacher_model(teacher_inputs, attention_mask=teacher_attention_mask, return_output=True)[ 97 | "logits" 98 | ] 99 | if args.kd_coef[0] > 0: 100 | distil_loss = self.kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels) 101 | else: 102 | distil_loss = torch.tensor(0.0).to(teacher_logits.device) 103 | 104 | if args.kd_coef[2] > 0: 105 | rank_loss = self.rank_loss(student_output.logits, teacher_logits, student_labels,teacher_labels) 106 | else: 107 | rank_loss = torch.tensor(0.0).to(teacher_logits.device) 108 | 109 | if args.kd_coef[3] > 0: 110 | adaptive_kd_loss= self.adaptive_kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels,infos["gold_idx"]) 111 | else: 112 | adaptive_kd_loss = torch.tensor(0.0).to(teacher_logits.device) 113 | #均衡系数 114 | 115 | loss = distil_loss * self.args.kd_coef[0] + gpt_loss* args.kd_coef[1] + rank_loss * args.kd_coef[2] + adaptive_kd_loss* self.args.kd_coef[3] 116 | 117 | self.strategy.backward(loss, self.model, self.optimizer) 118 | self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) 119 | 120 | loss_sum += loss.item() 121 | logs_dict = { 122 | "gpt_loss": gpt_loss.item(), 123 | "distil_loss": distil_loss.item(), 124 | "rank_loss": rank_loss.item(), 125 | "adaptive_kd_loss": adaptive_kd_loss.item(), 126 | # "lr": self.scheduler.get_last_lr()[0], 127 | } 128 | # step bar 129 | logs_dict = self.strategy.all_reduce(logs_dict) 130 | step_bar.set_postfix(logs_dict) 131 | step_bar.update() 132 | 133 | # logs/checkpoints/evaluation 134 | if step % self.strategy.accumulated_gradient == 0: 135 | logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient 136 | loss_sum = 0 137 | global_step = step // self.strategy.accumulated_gradient 138 | client_states = {"consumed_samples": global_step * args.train_batch_size} 139 | self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) 140 | 141 | step += 1 142 | 143 | epoch_bar.update() 144 | 145 | if self._wandb is not None and self.strategy.is_rank_0(): 146 | self._wandb.finish() 147 | if self._tensorboard is not None and self.strategy.is_rank_0(): 148 | self._tensorboard.close() 149 | 150 | # logs/checkpoints/evaluation 151 | def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): 152 | if global_step % args.logging_steps == 0: 153 | # wandb 154 | if self._wandb is not None and self.strategy.is_rank_0(): 155 | logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} 156 | self._wandb.log(logs) 157 | # TensorBoard 158 | elif self._tensorboard is not None and self.strategy.is_rank_0(): 159 | for k, v in logs_dict.items(): 160 | self._tensorboard.add_scalar(f"train/{k}", v, global_step) 161 | 162 | # eval 163 | # if global_step % args.eval_steps == 0: 164 | # # do eval when len(dataloader) > 0, avoid zero division in eval. 165 | # if len(self.eval_dataloader) > 0: 166 | # self.evaluate(self.eval_dataloader, global_step) 167 | # save ckpt 168 | # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric 169 | # if global_step % args.save_steps == 0: 170 | # tag = f"global_step{global_step}" 171 | # self.strategy.save_ckpt( 172 | # self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states 173 | # ) 174 | 175 | def evaluate(self, eval_dataloader, steps=0): 176 | times = 0 177 | self.model.eval() 178 | self.teacher_model.eval() 179 | with torch.no_grad(): 180 | loss_sum = 0 181 | step_bar = tqdm( 182 | range(eval_dataloader.__len__()), 183 | desc="Eval stage of steps %d" % steps, 184 | disable=not self.strategy.is_rank_0(), 185 | ) 186 | for teacher_prompt_id_lens,teacher_inputs,teacher_attention_masks, student_prompt_id_lens,student_inputs,student_attention_masks,infos in eval_dataloader: 187 | student_inputs = student_inputs.squeeze(1).to(torch.cuda.current_device()) 188 | student_attention_mask = student_attention_masks.squeeze(1).to(torch.cuda.current_device()) 189 | student_output = self.model(student_inputs, attention_mask=student_attention_mask, return_output=True) 190 | student_labels = torch.where( 191 | student_attention_mask.bool(), 192 | student_inputs, 193 | self.loss_fn.IGNORE_INDEX, 194 | ) 195 | teacher_inputs = teacher_inputs.squeeze(1).to(torch.cuda.current_device()) 196 | teacher_attention_mask = teacher_attention_masks.squeeze(1).to(torch.cuda.current_device()) 197 | teacher_labels = torch.where( 198 | teacher_attention_mask.bool(), 199 | teacher_inputs, 200 | self.loss_fn.IGNORE_INDEX, 201 | ) 202 | if not self.pretrain_mode: 203 | for label, source_len in zip(student_labels, student_prompt_id_lens): 204 | label[:source_len] = self.loss_fn.IGNORE_INDEX 205 | 206 | for label, source_len in zip(teacher_labels, teacher_prompt_id_lens): 207 | label[:source_len] = self.loss_fn.IGNORE_INDEX 208 | if args.kd_coef[1] > 0: 209 | gpt_loss = self.loss_fn(student_output.logits, student_labels) 210 | else: 211 | gpt_loss = torch.tensor(0.0).to(student_output.logits.device) 212 | 213 | with torch.no_grad(): 214 | teacher_logits = self.teacher_model(teacher_inputs, attention_mask=teacher_attention_mask, return_output=True)[ 215 | "logits" 216 | ] 217 | if args.kd_coef[0] > 0: 218 | distil_loss = self.kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels) 219 | else: 220 | distil_loss = torch.tensor(0.0).to(teacher_logits.device) 221 | 222 | if args.kd_coef[2] > 0: 223 | rank_loss = self.rank_loss(student_output.logits, teacher_logits, student_labels,teacher_labels) 224 | else: 225 | rank_loss = torch.tensor(0.0).to(teacher_logits.device) 226 | 227 | if args.kd_coef[3] > 0: 228 | adaptive_kd_loss = self.adaptive_kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels,infos["gold_idx"]) 229 | else: 230 | adaptive_kd_loss = torch.tensor(0.0).to(teacher_logits.device) 231 | #均衡系数 232 | 233 | loss = distil_loss * self.args.kd_coef[0] + gpt_loss* args.kd_coef[1] + rank_loss * args.kd_coef[2] + adaptive_kd_loss* self.args.kd_coef[3] 234 | 235 | 236 | 237 | 238 | times += 1 239 | loss_sum += loss.item() 240 | bar_dict = {"eval loss": loss_sum / times} 241 | step_bar.update() 242 | logs = self.strategy.all_reduce(bar_dict) 243 | step_bar.set_postfix(logs) 244 | 245 | if self.strategy.is_rank_0(): 246 | if self._wandb is not None: 247 | logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} 248 | self._wandb.log(logs) 249 | elif self._tensorboard is not None: 250 | for k, v in logs.items(): 251 | self._tensorboard.add_scalar(f"eval/{k}", v, steps) 252 | 253 | self.model.train() # reset model state 254 | -------------------------------------------------------------------------------- /kd/train_kd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | from datetime import datetime 5 | import logging 6 | from transformers.trainer import get_scheduler 7 | import torch 8 | from customized_sft_dataset import SFTDataset 9 | from openrlhf.models import Actor 10 | # from openrlhf.trainer import KDTrainer 11 | from customized_kd_trainer import MY_KDTrainer 12 | from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer 13 | # from openrlhf.utils.deepspeed import DeepspeedStrategy 14 | from torch import distributed as dist 15 | # from torch.utils.data import DataLoader 16 | # from openrlhf.utils.distributed_sampler import DistributedSampler 17 | import numpy as np 18 | import random 19 | import transformers 20 | 21 | def seed_everything(seed): 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | np.random.seed(seed) 25 | random.seed(seed) 26 | torch.backends.cudnn.benchmark = False 27 | torch.backends.cudnn.deterministic = True 28 | torch.cuda.manual_seed_all(seed) 29 | 30 | seed_everything(42) 31 | 32 | 33 | 34 | 35 | logging.basicConfig( 36 | level=logging.INFO, # 设置日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL 37 | format='%(asctime)s - %(levelname)s - %(message)s', 38 | datefmt='%Y-%m-%d %H:%M:%S' 39 | ) 40 | logger = logging.getLogger('kd_training') 41 | def train(args): 42 | cache_dir = "/mnt/workspace/wangyifei/projects/huggingcache" 43 | dataset_dir = "/mnt/workspace/wangyifei/projects/self_training/datasets" 44 | if args.multi_gold_doc: 45 | assert args.dataset_name in ["musique"] 46 | else: 47 | assert args.dataset_name in ["nq","tqa","webq"] 48 | # configure strategy 49 | strategy = get_strategy(args) 50 | strategy.setup_distributed() 51 | 52 | # configure model 53 | # load huggingface model 54 | logger.info(f"Loading student model from {os.path.join(cache_dir, args.pretrain)}") 55 | model = Actor( 56 | pretrain_or_model = os.path.join(cache_dir, args.pretrain), 57 | # pretrain_or_model="/mnt/workspace/wangyifei/projects/self_training/kl_checkpoints/Qwen1.5-7B-Chat_20total_docs_filter_5random_2strengthen_400_kl_0.8", 58 | use_flash_attention_2=args.flash_attn, 59 | bf16=args.bf16, 60 | load_in_4bit=args.load_in_4bit, 61 | lora_rank=args.lora_rank, 62 | lora_alpha=args.lora_alpha, 63 | target_modules=args.target_modules, 64 | lora_dropout=args.lora_dropout, 65 | ds_config=strategy.get_ds_train_config(is_actor=True), 66 | ) 67 | logger.info(f"Loading teacher model from {os.path.join(cache_dir, args.teacher_model)}") 68 | # load teacher model for inference 69 | teacher_model = Actor( 70 | os.path.join(cache_dir, args.teacher_model), 71 | use_flash_attention_2=args.flash_attn, 72 | bf16=args.bf16, 73 | load_in_4bit=args.load_in_4bit, 74 | ds_config=strategy.get_ds_eval_config(offload=args.teacher_offload), 75 | ) 76 | if args.teacher_offload: 77 | teacher_model._offload = True 78 | 79 | 80 | # configure tokenizer 81 | logger.info(f"Loading tokenizer from {os.path.join(cache_dir, args.pretrain)}") 82 | tokenizer = get_tokenizer(os.path.join(cache_dir, args.pretrain), model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) 83 | 84 | strategy.print(model) 85 | 86 | 87 | # configure optimizer 88 | logger.info(f"Creating optimizer....") 89 | optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) 90 | 91 | logger.info(f"Prepare for data and dataset....") 92 | # prepare for data and dataset 93 | logger.info(f"args.dataset_probs: {args.dataset_probs}") 94 | logger.info(f"args.dataset: {args.dataset}") 95 | dataset_dir = os.path.join(dataset_dir, args.dataset_name) 96 | train_data = blending_datasets( 97 | os.path.join(dataset_dir, args.dataset), 98 | args.dataset_probs, 99 | strategy, 100 | args.seed, 101 | return_eval=False, 102 | max_count=args.max_samples, 103 | train_split=args.train_split, 104 | eval_split=args.eval_split, 105 | ) 106 | 107 | # eval_dataset_name = args.dataset.split("_unfilter")[0] + "_dev" 108 | # logger.info(f"Prepare for eval data and dataset....") 109 | # logger.info(f"eval dataset name: {eval_dataset_name}") 110 | # eval_data = blending_datasets( 111 | # os.path.join(dataset_dir, eval_dataset_name), 112 | # args.dataset_probs, 113 | # strategy, 114 | # args.seed, 115 | # return_eval=False, 116 | # max_count=args.max_samples, 117 | # train_split=args.train_split, 118 | # eval_split=args.eval_split, 119 | # ) 120 | 121 | train_data = train_data.select(range(min(args.max_samples, len(train_data)))) 122 | train_data = train_data.shuffle(seed=args.seed) 123 | # eval_data = eval_data.select(range(min(args.max_samples, len(eval_data)))) 124 | 125 | train_dataset = SFTDataset( 126 | train_data, 127 | tokenizer, 128 | args.max_len, 129 | strategy, 130 | pretrain_mode=args.pretrain_mode, 131 | input_template=args.input_template, 132 | multi_gold_doc=args.multi_gold_doc, 133 | num_processors=1, 134 | ) 135 | 136 | # eval_dataset = SFTDataset( 137 | # eval_data, 138 | # tokenizer, 139 | # args.max_len, 140 | # strategy, 141 | # pretrain_mode=args.pretrain_mode, 142 | # input_template=args.input_template, 143 | # ) 144 | 145 | train_dataloader = strategy.setup_dataloader( 146 | replay_buffer = train_dataset, 147 | batch_size = args.micro_train_batch_size, 148 | pin_memory = True, 149 | shuffle = True, 150 | collate_fn = train_dataset.collate_fn, 151 | # use_block_split=True , 152 | # block_size= args.micro_train_batch_size 153 | ) 154 | 155 | num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size 156 | max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) 157 | 158 | scheduler = get_scheduler( 159 | args.lr_scheduler, 160 | optim, 161 | num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), 162 | num_training_steps=max_steps, 163 | scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, 164 | ) 165 | 166 | # gradient_checkpointing 167 | if args.gradient_checkpointing: 168 | model.gradient_checkpointing_enable( 169 | gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} 170 | ) 171 | 172 | # prepare models 173 | ((model, optim, scheduler), teacher_model) = strategy.prepare((model, optim, scheduler), teacher_model) 174 | 175 | # load checkpoint 176 | args.kd_coef = [float(x) for x in args.kd_coef.split(',')] 177 | consumed_samples = 0 178 | if args.load_checkpoint and os.path.exists(args.ckpt_path): 179 | _, states = strategy.load_ckpt(model.model, args.ckpt_path) 180 | consumed_samples = states["consumed_samples"] 181 | strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") 182 | if args.output_key == "oracle_answer": 183 | checkpoint_name = args.dataset + f"_kd{str(args.kd_coef[0])}_lm{str(args.kd_coef[1])}_rank{str(args.kd_coef[2])}_adaptive{str(args.kd_coef[3])}_{args.perserve}" 184 | # args.save_path = os.path.join(args.save_path, args.dataset + f"_kd{str(args.kd_coef[0])}_lm{str(args.kd_coef[1])}_rank{str(args.kd_coef[2])}_adaptive{str(args.kd_coef[3])}") 185 | strategy.args.wandb_run_name = args.dataset + f"_kd{str(args.kd_coef[0])}_lm{str(args.kd_coef[1])}_rank{str(args.kd_coef[2])}_adaptive{str(args.kd_coef[3])}_{args.perserve}" 186 | else: 187 | checkpoint_name = args.dataset + f"_sft" 188 | # args.save_path = os.path.join(args.save_path, args.dataset + f"_sft") 189 | strategy.args.wandb_run_name = args.dataset + f"_sft" 190 | assert args.kd_coef[0] == 0.0 191 | 192 | 193 | # configure Trainer 194 | attention_dillution_coef = args.perserve 195 | trainer = MY_KDTrainer( 196 | model=model, 197 | teacher_model=teacher_model, 198 | strategy=strategy, 199 | optim=optim, 200 | train_dataloader=train_dataloader, 201 | eval_dataloader=None, 202 | scheduler=scheduler, 203 | max_norm=args.max_norm, 204 | pretrain_mode=args.pretrain_mode, 205 | batch_size=args.train_batch_size, 206 | max_epochs=args.max_epochs, 207 | tokenizer=tokenizer, 208 | multi_gold_doc=args.multi_gold_doc, 209 | attention_dillution_coef=round(attention_dillution_coef, 1), 210 | # **({"attention_dillution_coef": round(attention_dillution_coef, 1)} if args.output_key == "oracle_answer" else {}), 211 | 212 | ) 213 | 214 | trainer.fit(args, consumed_samples, num_update_steps_per_epoch) 215 | # save model checkpoint after fitting on only rank0 216 | save_path = os.path.join(args.save_path, checkpoint_name) 217 | os.makedirs(save_path, exist_ok=True) 218 | strategy.save_model(model, tokenizer, save_path) 219 | 220 | 221 | 222 | if __name__ == "__main__": 223 | parser = argparse.ArgumentParser() 224 | # Checkpoints 225 | parser.add_argument("--save_path", type=str, default="./ckpt") 226 | parser.add_argument("--save_steps", type=int, default=-1) 227 | parser.add_argument("--logging_steps", type=int, default=1) 228 | parser.add_argument("--eval_steps", type=int, default=-1) 229 | parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_kd") 230 | parser.add_argument("--max_ckpt_num", type=int, default=3) 231 | parser.add_argument("--max_ckpt_mem", type=int, default=1e8) 232 | parser.add_argument("--load_checkpoint", action="store_true", default=False) 233 | 234 | # DeepSpeed 235 | parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU") 236 | parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") 237 | parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") 238 | parser.add_argument("--gradient_checkpointing", action="store_true", default=False) 239 | parser.add_argument("--seed", type=int, default=42) 240 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") 241 | parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") 242 | parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") 243 | parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") 244 | parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") 245 | parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") 246 | parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") 247 | parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") 248 | parser.add_argument("--overlap_comm", action="store_true", default=False) 249 | parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) 250 | parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) 251 | 252 | # LoRA 253 | parser.add_argument("--load_in_4bit", action="store_true", default=False) 254 | parser.add_argument("--lora_rank", type=int, default=0) 255 | parser.add_argument("--lora_alpha", type=int, default=16) 256 | parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") 257 | parser.add_argument("--lora_dropout", type=float, default=0) 258 | 259 | # KD 260 | parser.add_argument("--pretrain", type=str, default=None) 261 | parser.add_argument("--teacher_model", type=str, default=None) 262 | parser.add_argument("--max_epochs", type=int, default=1) 263 | parser.add_argument("--kd_coef", type=str) 264 | parser.add_argument("--learning_rate", type=float, default=5e-6) 265 | parser.add_argument("--lr_warmup_ratio", type=float, default=0.05) 266 | parser.add_argument("--pretrain_mode", action="store_true", default=False, help="Use pretrain loss") 267 | parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr") 268 | parser.add_argument("--l2", type=float, default=0, help="weight decay loss") 269 | parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") 270 | parser.add_argument("--teacher_offload", action="store_true", default=False) 271 | 272 | # Custom dataset 273 | parser.add_argument("--dataset", type=str, default=None) 274 | parser.add_argument("--dataset_name", type=str, default=None) 275 | parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") 276 | parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") 277 | parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") 278 | 279 | parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") 280 | parser.add_argument("--output_key", type=str, default="output", help="JSON dataset key") 281 | parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ") 282 | parser.add_argument( 283 | "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" 284 | ) 285 | 286 | parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") 287 | parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") 288 | 289 | # wandb parameters 290 | parser.add_argument("--use_wandb", type=str, default=None) 291 | parser.add_argument("--wandb_org", type=str, default=None) 292 | parser.add_argument("--wandb_group", type=str, default=None) 293 | parser.add_argument("--wandb_project", type=str, default="openrlhf_train_sft") 294 | # parser.add_argument( 295 | # "--wandb_run_name", 296 | # type=str, 297 | # default="sft_%s" % datetime.now().strftime("%m%dT%H:%M"), 298 | # ) 299 | parser.add_argument( 300 | "--wandb_run_name", 301 | type=str, 302 | default="temp") 303 | # TensorBoard parameters 304 | parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") 305 | 306 | # ModelScope parameters 307 | parser.add_argument("--use_ms", action="store_true", default=False) 308 | # kd trainer parameters 309 | parser.add_argument("--perserve", type=float,default=1.0) 310 | parser.add_argument("--multi_gold_doc",action="store_true",default=False) 311 | 312 | args = parser.parse_args() 313 | 314 | if args.input_template and "{}" not in args.input_template: 315 | print("[Warning] {} not in args.input_template, set to None") 316 | args.input_template = None 317 | 318 | if args.input_template and "\\n" in args.input_template: 319 | print( 320 | "[Warning] input_template contains \\n chracters instead of newline. " 321 | "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 322 | ) 323 | 324 | if args.use_ms: 325 | from modelscope.utils.hf_util import patch_hub 326 | 327 | # Patch hub to download models from modelscope to speed up. 328 | # patch_hub() 329 | 330 | train(args) 331 | -------------------------------------------------------------------------------- /kd/customized_sft_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import pdb 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | from openrlhf.datasets.utils import zero_pad_sequences 7 | import logging 8 | import torch.distributed as dist 9 | logger = logging.getLogger('customized sft dataset') 10 | def get_teacher_student_prompt(data,multi_gold_doc=False): 11 | # logger.info(f"Get teacher student prompt") 12 | question = data["question"] 13 | distractors = data["distractors"] 14 | gold_doc = data["gold_document"] 15 | 16 | # key = "idx" if multi_gold_doc else "id" 17 | # distractor_indexs = [x[key] for x in distractors] 18 | # gold_indexs = [x[key] for x in gold_doc] 19 | 20 | if not multi_gold_doc: 21 | distractors.insert(0, gold_doc) 22 | else: 23 | distractors[len(distractors):len(distractors)] = gold_doc 24 | # complex_indexs = [x[key] for x in distractors] 25 | # assert complex_indexs[-len(gold_indexs):] == gold_indexs 26 | # print(f"complex_indexs: {complex_indexs}") 27 | # print(f"distractor_indexs: {distractor_indexs}") 28 | # print(f"gold_indexs: {gold_indexs}") 29 | 30 | 31 | 32 | teacher_documents = distractors 33 | student_documents = data["documents"] 34 | 35 | student_context = ''.join([f"- Title: {doc['title']}\n{doc['text']}\n" for doc in student_documents]) 36 | teacher_context = ''.join([f"- Title: {doc['title']}\n{doc['text']}\n" for doc in teacher_documents]) 37 | if not multi_gold_doc: 38 | task_instruction = "Please write a high-quantify answer for the given question using only the provided search documents (some of which might be irrelevant)." 39 | else: 40 | task_instruction = """Let’s first identify the relevant information from the long context and list it. Then, carry out step-by-step reasoning based on that information, and \ 41 | finally, provide the answer. The final answer must end with “The answer is:”.""" 42 | student_prompt = f"{task_instruction}\n{student_context}\nQuestion: {question}\n" 43 | teacher_prompt = f"{task_instruction}\n{teacher_context}\nQuestion: {question}\n" 44 | 45 | return student_prompt, teacher_prompt 46 | 47 | 48 | 49 | def preprocess_data( 50 | data, input_template=None, input_key="input", output_key=None, apply_chat_template=None, multiturn=False,multi_gold_doc=False 51 | ): 52 | if apply_chat_template: 53 | # logger.info(f"Apply chat template") 54 | if output_key: 55 | # question = data[input_key] 56 | # oracle_answer = data[output_key] 57 | student_prompt, teacher_prompt = get_teacher_student_prompt(data,multi_gold_doc) 58 | if not multi_gold_doc: 59 | response = data[output_key].split("\n")[0] 60 | else: 61 | response = data[output_key] 62 | 63 | 64 | if isinstance(student_prompt, str) and isinstance(teacher_prompt, str) and isinstance(response, str): 65 | student_prompt_message = [{"role": "user", "content": student_prompt}] 66 | teacher_prompt_message = [{"role": "user", "content": teacher_prompt}] 67 | response_message = [{"role": "assistant", "content": response}] 68 | 69 | student_template_prompt = apply_chat_template(student_prompt_message, tokenize=False, add_generation_prompt=True) 70 | teacher_template_prompt = apply_chat_template(teacher_prompt_message, tokenize=False, add_generation_prompt=True) 71 | student_response = apply_chat_template(student_prompt_message + response_message, tokenize=False)[len(student_template_prompt):] 72 | teacher_response = apply_chat_template(teacher_prompt_message + response_message, tokenize=False)[len(teacher_template_prompt):] 73 | assert student_response == teacher_response 74 | else: 75 | prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True) 76 | response = apply_chat_template(data[input_key], tokenize=False)[len(prompt) :] 77 | else: 78 | prompt = data[input_key] 79 | if input_template: 80 | prompt = input_template.format(prompt) 81 | # output_key is None for continue pretrain 82 | response = data[output_key] if output_key else "" 83 | return student_template_prompt, teacher_template_prompt, student_response 84 | 85 | 86 | class SFTDataset(Dataset): 87 | """ 88 | Dataset for SFT model 89 | 90 | Args: 91 | dataset: dataset for SFT model 92 | tokenizer: tokenizer for SFT model 93 | max_length: max length of input 94 | """ 95 | 96 | def __init__( 97 | self, 98 | dataset, 99 | tokenizer: Callable, 100 | max_length: int, 101 | strategy, 102 | input_template=None, 103 | pretrain_mode=False, 104 | num_processors= 48, # Specify the number of processors you want to use 105 | multiple_of=1, 106 | multiturn=False, 107 | multi_gold_doc=False 108 | ) -> None: 109 | super().__init__() 110 | 111 | self.tokenizer = tokenizer 112 | self.strategy = strategy 113 | self.pretrain_mode = pretrain_mode 114 | self.max_length = max_length 115 | self.multiple_of = multiple_of 116 | self.multiturn = multiturn 117 | self.multi_gold_doc = multi_gold_doc 118 | 119 | # chat template 120 | self.input_template = input_template 121 | self.input_key = getattr(self.strategy.args, "input_key", None) 122 | self.output_key = getattr(self.strategy.args, "output_key", None) 123 | self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 124 | 125 | if self.apply_chat_template: 126 | self.apply_chat_template = self.tokenizer.apply_chat_template 127 | tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) 128 | if tokenizer_chat_template: 129 | self.tokenizer.chat_template = tokenizer_chat_template 130 | 131 | # Parallel loading datasets 132 | 133 | processed_dataset = dataset.map( 134 | self.process_data, 135 | remove_columns=dataset.column_names, 136 | num_proc=num_processors, 137 | ) 138 | processed_dataset = processed_dataset.filter(lambda x: x["teacher_prompt"] is not None and x["student_prompt"] is not None and x["response"] is not None) 139 | 140 | # Store the processed data in class attributes 141 | self.student_prompts = processed_dataset["student_prompt"] 142 | self.teacher_prompts = processed_dataset["teacher_prompt"] 143 | self.responses = processed_dataset["response"] 144 | self.student_prompts_ids_lens = processed_dataset["student_prompt_ids_len"] 145 | self.teacher_prompts_ids_lens = processed_dataset["teacher_prompt_ids_len"] 146 | self.response_ranges = processed_dataset["response_ranges"] if self.multiturn else None 147 | self.gold_idx = processed_dataset["gold_idx"] 148 | 149 | def process_data(self, data): 150 | if self.multiturn and self.output_key: 151 | data[self.input_key].append(data[self.output_key]) 152 | data[self.output_key] = None 153 | 154 | if self.multiturn: 155 | assert ( 156 | not self.output_key or not data[self.output_key] 157 | ), "You should put the whole trajactory into data[input_key] and do not set output_key" 158 | input_key = self.input_key 159 | apply_chat_template = self.apply_chat_template 160 | response_ranges = [] 161 | for idx, message in enumerate(data[input_key]): 162 | if message["role"] == "assistant": 163 | prompt = apply_chat_template(data[input_key][:idx], tokenize=False, add_generation_prompt=True) 164 | response = apply_chat_template(data[input_key][: idx + 1], tokenize=False)[len(prompt) :] 165 | 166 | start_idx = ( 167 | self.tokenizer( 168 | prompt, 169 | max_length=self.max_length, 170 | padding=False, 171 | truncation=True, 172 | return_tensors="pt", 173 | add_special_tokens=False, 174 | )["attention_mask"] 175 | .int() 176 | .sum() 177 | .item() 178 | ) 179 | 180 | end_idx = ( 181 | start_idx 182 | + self.tokenizer( 183 | response, 184 | max_length=self.max_length, 185 | padding=False, 186 | truncation=True, 187 | return_tensors="pt", 188 | add_special_tokens=False, 189 | )["attention_mask"] 190 | .int() 191 | .sum() 192 | .item() 193 | - 1 194 | ) 195 | response_ranges.append((start_idx, end_idx)) # left close right open 196 | 197 | student_template_prompt, teacher_template_prompt, student_response = preprocess_data( 198 | data, 199 | None if self.pretrain_mode else self.input_template, 200 | self.input_key, 201 | self.output_key, 202 | apply_chat_template=None if self.pretrain_mode else self.apply_chat_template, 203 | multiturn=self.multiturn, 204 | multi_gold_doc=self.multi_gold_doc 205 | ) 206 | 207 | if not self.pretrain_mode: 208 | student_prompt_token = self.tokenizer( 209 | student_template_prompt, 210 | max_length=self.max_length, 211 | padding=False, 212 | truncation=True, 213 | return_tensors="pt", 214 | add_special_tokens=False, 215 | ) 216 | student_prompt_ids_len = student_prompt_token["attention_mask"].int().sum().item() 217 | 218 | teacher_prompt_token = self.tokenizer( 219 | teacher_template_prompt, 220 | max_length=self.max_length, 221 | padding=False, 222 | truncation=True, 223 | return_tensors="pt", 224 | add_special_tokens=False, 225 | ) 226 | teacher_prompt_ids_len = teacher_prompt_token["attention_mask"].int().sum().item() 227 | 228 | # filter the sample whose length is greater than max_length (2 for answer length) 229 | if not student_template_prompt or not student_response or student_prompt_ids_len >= self.max_length - 2: 230 | student_template_prompt = None 231 | else: 232 | prompt_ids_len = 0 233 | 234 | return { 235 | "teacher_prompt": teacher_template_prompt, 236 | "student_prompt":student_template_prompt, 237 | "response": student_response, 238 | "student_prompt_ids_len": student_prompt_ids_len, 239 | "teacher_prompt_ids_len": teacher_prompt_ids_len, 240 | "response_ranges": response_ranges if self.multiturn else None, 241 | "gold_idx":data["gold_idx"] 242 | } 243 | 244 | def __len__(self): 245 | length = len(self.teacher_prompts) 246 | return length 247 | 248 | def __getitem__(self, idx): 249 | teacher_prompt = self.teacher_prompts[idx] 250 | student_prompt = self.student_prompts[idx] 251 | response = self.responses[idx] 252 | teacher_prompt_ids_len = self.teacher_prompts_ids_lens[idx] 253 | student_prompt_ids_len = self.student_prompts_ids_lens[idx] 254 | 255 | if not self.pretrain_mode: 256 | teacher_text = (teacher_prompt + response).rstrip("\n") 257 | student_text = (student_prompt + response).rstrip("\n") 258 | 259 | if not teacher_text.endswith(self.tokenizer.eos_token): 260 | teacher_text += " " + self.tokenizer.eos_token 261 | if not student_text.endswith(self.tokenizer.eos_token): 262 | student_text += " " + self.tokenizer.eos_token 263 | else: 264 | text = prompt 265 | 266 | teacher_input_token = self.tokenizer( 267 | teacher_text, 268 | max_length=self.max_length, 269 | padding=False, 270 | truncation=True, 271 | return_tensors="pt", 272 | add_special_tokens=False, 273 | ) 274 | student_input_token = self.tokenizer( 275 | student_text, 276 | max_length=self.max_length, 277 | padding=False, 278 | truncation=True, 279 | return_tensors="pt", 280 | add_special_tokens=False, 281 | ) 282 | 283 | if not self.pretrain_mode: 284 | # to avoid EOS_token truncation 285 | teacher_input_token["input_ids"][0][-1] = self.tokenizer.eos_token_id 286 | teacher_input_token["attention_mask"][0][-1] = True 287 | 288 | student_input_token["input_ids"][0][-1] = self.tokenizer.eos_token_id 289 | student_input_token["attention_mask"][0][-1] = True 290 | info = { 291 | "teacher_input": teacher_prompt, 292 | "student_input": student_prompt, 293 | "output": response, 294 | "teacher_input_length": teacher_input_token["attention_mask"].int().sum().item(), 295 | "student_input_length": student_input_token["attention_mask"].int().sum().item(), 296 | "response_ranges": self.response_ranges[idx] if self.multiturn else None, 297 | "gold_idx":self.gold_idx[idx] 298 | } 299 | # print(info) 300 | return teacher_input_token["input_ids"], teacher_input_token["attention_mask"], teacher_prompt_ids_len, \ 301 | student_input_token["input_ids"], student_input_token["attention_mask"],student_prompt_ids_len, info 302 | 303 | def collate_fn(self, item_list): 304 | # logger.info(f"Collate fn ....") 305 | teacher_prompt_id_lens = [] 306 | student_prompt_id_lens = [] 307 | 308 | teacher_input_ids = [] 309 | student_input_ids = [] 310 | 311 | teacher_attention_masks = [] 312 | student_attention_masks = [] 313 | infos = {"teacher_prompt": [], "output": [],\ 314 | "student_prompt": [],"gold_idx":[]} 315 | 316 | for teacher_input_id, teacher_attention_mask, teacher_prompt_ids_len, \ 317 | student_input_id, student_attention_mask, student_prompt_ids_len, info in item_list: 318 | 319 | teacher_prompt_id_lens.append(teacher_prompt_ids_len) 320 | student_prompt_id_lens.append(student_prompt_ids_len) 321 | teacher_input_ids.append(teacher_input_id) 322 | student_input_ids.append(student_input_id) 323 | teacher_attention_masks.append(teacher_attention_mask) 324 | student_attention_masks.append(student_attention_mask) 325 | infos["teacher_prompt"].append(info["teacher_input"]) 326 | infos["student_prompt"].append(info["student_input"]) 327 | infos["output"].append(info["output"]) 328 | infos["gold_idx"].append(info["gold_idx"]) 329 | 330 | 331 | teacher_input_ids = zero_pad_sequences(teacher_input_ids, "right", self.tokenizer.pad_token_id) 332 | teacher_attention_masks = zero_pad_sequences(teacher_attention_masks, "right") 333 | 334 | student_input_ids = zero_pad_sequences(student_input_ids, "right", self.tokenizer.pad_token_id) 335 | student_attention_masks = zero_pad_sequences(student_attention_masks, "right") 336 | return teacher_prompt_id_lens,teacher_input_ids,teacher_attention_masks, \ 337 | student_prompt_id_lens,student_input_ids,student_attention_masks,infos 338 | 339 | def packing_collate_fn(self, item_list): 340 | packed_input_ids = [] 341 | packed_attention_masks = [] 342 | prompt_ids_lens = [] 343 | infos = {"input_length": [], "response_ranges": [] if self.multiturn else None} 344 | index = 1 345 | for prompt_ids_len, input_id, attention_mask, info in item_list: 346 | packed_input_ids.append(input_id.flatten()) 347 | packed_attention_masks.append(torch.full_like(input_id.flatten(), index)) 348 | prompt_ids_lens.append(prompt_ids_len) 349 | infos["input_length"].append(info["input_length"]) 350 | if self.multiturn: 351 | if len(infos["response_ranges"]) >= 1: 352 | for i in range(len(info["response_ranges"])): 353 | info["response_ranges"][i][0] += infos["response_ranges"][-1][-1][ 354 | 1 355 | ] # end_index of the last response of the last item 356 | info["response_ranges"][i][1] += infos["response_ranges"][-1][-1][1] 357 | infos["response_ranges"].append(info["response_ranges"]) 358 | index += 1 359 | 360 | packed_input_ids = torch.cat(packed_input_ids, dim=0).unsqueeze(0) 361 | packed_attention_masks = torch.cat(packed_attention_masks, dim=0).unsqueeze(0) 362 | 363 | if ( 364 | self.multiple_of > 1 and packed_input_ids.numel() % self.multiple_of != 0 365 | ): # not divisible by multiple_of; here we align for grouping 366 | padding_len = self.multiple_of - (packed_input_ids.numel() % self.multiple_of) 367 | packed_input_ids = F.pad(packed_input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) 368 | packed_attention_masks = F.pad(packed_attention_masks, (0, padding_len), value=0) 369 | 370 | return prompt_ids_lens, packed_input_ids, packed_attention_masks, infos 371 | -------------------------------------------------------------------------------- /kd/loss_design.py: -------------------------------------------------------------------------------- 1 | from openrlhf.trainer import KDTrainer 2 | from typing import Optional, Tuple,List 3 | from tqdm import tqdm 4 | import torch 5 | import torch.distributed as dist 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | from abc import ABC 10 | import torch 11 | 12 | from torch.optim import Optimizer 13 | from tqdm import tqdm 14 | from openrlhf.models import GPTLMLoss, KDLoss 15 | from openrlhf.utils.distributed_sampler import DistributedSampler 16 | import logging 17 | from typing import Optional, Tuple 18 | import torch 19 | from torch.distributed import get_rank, get_world_size, all_gather, broadcast 20 | import torch.distributed as dist 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | class AdaptiveKLWeightedKLLoss(nn.Module): 24 | """ 25 | KL Divergence Loss with Adaptive KL-Based Weighting 26 | """ 27 | def __init__(self, temperature=1.0, batch_norm=False,attention_dillution_coef=1): 28 | super().__init__() 29 | self.IGNORE_INDEX = -100 30 | self.temperature = temperature 31 | self.epsilon = 1e-4 # 防止除零 32 | self.attention_dillution_coef = attention_dillution_coef 33 | self.batch_norm = batch_norm 34 | 35 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, 36 | student_label: torch.Tensor, teacher_label: torch.Tensor,gold_idxs: List[int]) -> torch.Tensor: 37 | """ 38 | Inputs: 39 | - logits: Student logits, shape [B, L, C] 40 | - teacher_logits: Teacher logits, shape [B, L, C] 41 | - *_label: token-level masks, shape [B, L] 42 | """ 43 | gold_idxs = torch.tensor(gold_idxs, device=logits.device) 44 | 45 | teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L] 46 | student_mask = (student_label != self.IGNORE_INDEX) # [B, L] 47 | B = teacher_mask.shape[0] 48 | assert torch.all(torch.sum(teacher_mask, dim=1) == torch.sum(student_mask, dim=1)), "Mask mismatch" 49 | 50 | teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) # [B, L, C] 51 | student_log_probs = F.log_softmax(logits / self.temperature, dim=-1) # [B, L, C] 52 | kl_divs = torch.stack([ 53 | F.kl_div( 54 | student_log_probs[i][student_mask[i]], 55 | teacher_log_probs[i][teacher_mask[i]].exp(), 56 | reduction="batchmean" 57 | ) 58 | for i in range(B) 59 | ]) 60 | # mask = gold_idxs == 0 61 | # local_invaried_tensor = kl_divs[mask].sum() # [1] 62 | # local_kl_tensor = kl_divs[~mask] # [3] 63 | local_kl_tensor = kl_divs 64 | print(f"rank = {get_rank()}") 65 | print(f"local_kl_tensor = {local_kl_tensor}") 66 | print(f"gold idxs = {gold_idxs}") 67 | if self.batch_norm: 68 | world_size = dist.get_world_size() 69 | gathered_kl_list = [torch.zeros_like(local_kl_tensor) for _ in range(world_size)] 70 | gathered_gold_indexs = [torch.zeros_like(gold_idxs) for _ in range(world_size)] 71 | all_gather(gathered_kl_list, local_kl_tensor) 72 | all_gather(gathered_gold_indexs, gold_idxs) 73 | all_kl = torch.stack(gathered_kl_list, dim=0) # [8, 4] 74 | all_gold_indexs = torch.stack(gathered_gold_indexs, dim=0) #[8,4] 75 | print(f"all_kl.shape = {all_kl.shape}") 76 | print(f"all_gold_index shape = {all_gold_indexs.shape}") 77 | # 仅在Rank 0计算权重后广播,确保一致性 78 | if dist.get_rank() == 0: 79 | weights = torch.zeros_like(all_gold_indexs,dtype=torch.float32) 80 | mask = all_gold_indexs == 0 81 | weights[mask] = self.attention_dillution_coef 82 | unique_gold_indexs = torch.unique(all_gold_indexs[~mask]) 83 | # unique_gold_indexs = torch.unique(all_gold_indexs) 84 | pos_mask_list = {} 85 | pos_kl_list = [] 86 | for unique_gold_index in unique_gold_indexs: 87 | pos_mask = all_gold_indexs == unique_gold_index 88 | avg_pos_kl = torch.sum(all_kl[pos_mask])/torch.sum(pos_mask) 89 | # add frequency 90 | 91 | pos_kl_list.append(avg_pos_kl) 92 | print(f"uniaue gold indexs = {unique_gold_indexs}") 93 | print(f"pos_kl_list = {pos_kl_list}") 94 | 95 | pos_kl_tensor = torch.stack(pos_kl_list,dim=0) 96 | # pos_kl_coefs = pos_kl_tensor*pos_kl_tensor.numel()/(torch.sum(pos_kl_tensor)+self.epsilon) 97 | pos_kl_coefs =F.softmax(pos_kl_tensor,dim=0)*pos_kl_tensor.numel() 98 | # pos_kl_coefs =F.softmax(pos_kl_tensor,dim=0) 99 | 100 | 101 | print(f"pos_kl_coefs = {pos_kl_coefs}") 102 | 103 | for pos_idx,unique_gold_index in enumerate(unique_gold_indexs): 104 | pos_mask = all_gold_indexs == unique_gold_index 105 | weights[pos_mask] = pos_kl_coefs[pos_idx]/torch.sum(pos_mask)*(all_kl[pos_mask]/max(all_kl[pos_mask])) 106 | 107 | else: 108 | weights = torch.zeros_like(all_gold_indexs,dtype=torch.float32) 109 | dist.broadcast(weights, src=0) 110 | 111 | print(f"rank = {get_rank()}, weights = {weights}") 112 | print(f"rank = {get_rank()}, all_kl = {all_kl}") 113 | total_loss = torch.sum(weights[get_rank()].detach() * local_kl_tensor) 114 | else: 115 | total_loss = torch.sum(local_kl_tensor) + local_invaried_tensor 116 | # total_loss = torch.sum(torch.clamp((local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor)).detach(),min=0.3,max=self.attention_dillution_coef) * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef 117 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef 118 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) 119 | 120 | print(f"total_loss = {total_loss}") 121 | return total_loss * (self.temperature ** 2) 122 | 123 | 124 | class AdaptiveKLWeightedKLLoss_multi(nn.Module): 125 | """ 126 | KL Divergence Loss with Adaptive KL-Based Weighting 127 | """ 128 | def __init__(self, temperature=1.0, batch_norm=False,attention_dillution_coef=1): 129 | super().__init__() 130 | self.IGNORE_INDEX = -100 131 | self.temperature = temperature 132 | self.epsilon = 1e-4 # 防止除零 133 | self.attention_dillution_coef = attention_dillution_coef 134 | self.batch_norm = batch_norm 135 | 136 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, 137 | student_label: torch.Tensor, teacher_label: torch.Tensor,gold_idxs: List[int]) -> torch.Tensor: 138 | """ 139 | Inputs: 140 | - logits: Student logits, shape [B, L, C] 141 | - teacher_logits: Teacher logits, shape [B, L, C] 142 | - *_label: token-level masks, shape [B, L] 143 | """ 144 | gold_idxs = torch.tensor(gold_idxs, device=logits.device) 145 | 146 | teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L] 147 | student_mask = (student_label != self.IGNORE_INDEX) # [B, L] 148 | B = teacher_mask.shape[0] 149 | assert torch.all(torch.sum(teacher_mask, dim=1) == torch.sum(student_mask, dim=1)), "Mask mismatch" 150 | 151 | teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) # [B, L, C] 152 | student_log_probs = F.log_softmax(logits / self.temperature, dim=-1) # [B, L, C] 153 | # breakpoint() 154 | kl_divs = torch.stack([ 155 | F.kl_div( 156 | student_log_probs[i][student_mask[i]], 157 | teacher_log_probs[i][teacher_mask[i]].exp(), 158 | reduction="batchmean" 159 | ) 160 | for i in range(B) 161 | ]) 162 | # mask = gold_idxs == 0 163 | # local_invaried_tensor = kl_divs[mask].sum() # [1] 164 | # local_kl_tensor = kl_divs[~mask] # [3] 165 | local_kl_tensor = kl_divs 166 | print(f"rank = {get_rank()}") 167 | print(f"local_kl_tensor = {local_kl_tensor}") 168 | print(f"gold idxs = {gold_idxs}") 169 | print(f"observed pos: {torch.sum(teacher_mask, dim=1)}") 170 | if self.batch_norm: 171 | world_size = dist.get_world_size() 172 | gathered_kl_list = [torch.zeros_like(local_kl_tensor) for _ in range(world_size)] 173 | gathered_gold_indexs = [torch.zeros_like(gold_idxs) for _ in range(world_size)] 174 | all_gather(gathered_kl_list, local_kl_tensor) 175 | all_gather(gathered_gold_indexs, gold_idxs) 176 | all_kl = torch.stack(gathered_kl_list, dim=0) # [8, 4] 177 | all_gold_indexs = torch.stack(gathered_gold_indexs, dim=0) #[8,4,2] 178 | print(f"all_kl.shape = {all_kl.shape}") 179 | print(f"all_gold_index shape = {all_gold_indexs.shape}") 180 | # 仅在Rank 0计算权重后广播,确保一致性 181 | if dist.get_rank() == 0: 182 | m,n,_ = all_gold_indexs.size() 183 | weights = torch.zeros((m,n),dtype=torch.float32,device=all_gold_indexs.device) 184 | # weights = F.softmax(all_kl.flatten())*all_kl.numel().view(m,n) 185 | # mask = (all_gold_indexs == torch.tensor([0, 1],device=all_gold_indexs.device)).all(dim=-1) 186 | # weights[mask] = self.attention_dillution_coef 187 | # weights[~mask] = all_kl[~mask]/torch.sum(all_kl[~mask])*all_kl[~mask].numel() 188 | # weights = F.softmax(all_kl.flatten(),dim=0).view(m,n)*all_kl.numel() 189 | weights = all_kl.flatten()/torch.sum(all_kl.flatten()+self.epsilon)*all_kl.numel() 190 | 191 | 192 | 193 | # mask = all_gold_indexs == 0 194 | # weights[mask] = self.attention_dillution_coef 195 | # unique_gold_indexs = torch.unique(all_gold_indexs[~mask]) 196 | # # unique_gold_indexs = torch.unique(all_gold_indexs) 197 | # pos_mask_list = {} 198 | # pos_kl_list = [] 199 | # for unique_gold_index in unique_gold_indexs: 200 | # pos_mask = all_gold_indexs == unique_gold_index 201 | # avg_pos_kl = torch.sum(all_kl[pos_mask])/torch.sum(pos_mask) 202 | # # add frequency 203 | 204 | # pos_kl_list.append(avg_pos_kl) 205 | # print(f"unique gold indexs = {unique_gold_indexs}") 206 | # print(f"pos_kl_list = {pos_kl_list}") 207 | 208 | # pos_kl_tensor = torch.stack(pos_kl_list,dim=0) 209 | # # pos_kl_coefs = pos_kl_tensor*pos_kl_tensor.numel()/(torch.sum(pos_kl_tensor)+self.epsilon) 210 | # pos_kl_coefs =F.softmax(pos_kl_tensor,dim=0)*pos_kl_tensor.numel() 211 | # # pos_kl_coefs =F.softmax(pos_kl_tensor,dim=0) 212 | 213 | 214 | # print(f"pos_kl_coefs = {pos_kl_coefs}") 215 | 216 | # for pos_idx,unique_gold_index in enumerate(unique_gold_indexs): 217 | # pos_mask = all_gold_indexs == unique_gold_index 218 | # weights[pos_mask] = pos_kl_coefs[pos_idx]/torch.sum(pos_mask)*(all_kl[pos_mask]/max(all_kl[pos_mask])) 219 | 220 | else: 221 | m,n,_ = all_gold_indexs.size() 222 | weights = torch.zeros((m,n),dtype=torch.float32,device=all_gold_indexs.device) 223 | dist.broadcast(weights, src=0) 224 | 225 | print(f"rank = {get_rank()}, weights = {weights}") 226 | print(f"rank = {get_rank()}, all_kl = {all_kl}") 227 | total_loss = torch.sum(weights[get_rank()].detach() * local_kl_tensor) 228 | else: 229 | total_loss = torch.sum(local_kl_tensor) + local_invaried_tensor 230 | # total_loss = torch.sum(torch.clamp((local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor)).detach(),min=0.3,max=self.attention_dillution_coef) * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef 231 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef 232 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) 233 | 234 | print(f"total_loss = {total_loss}") 235 | return total_loss * (self.temperature ** 2) 236 | # class AdaptiveKLWeightedKLLoss(nn.Module): 237 | # """ 238 | # KL Divergence Loss with Adaptive KL-Based Weighting 239 | # """ 240 | # def __init__(self, temperature=1.0, batch_norm=False,attention_dillution_coef=1): 241 | # super().__init__() 242 | # self.IGNORE_INDEX = -100 243 | # self.temperature = temperature 244 | # self.epsilon = 1e-4 # 防止除零 245 | # self.attention_dillution_coef = attention_dillution_coef 246 | # self.batch_norm = batch_norm 247 | 248 | # def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, 249 | # student_label: torch.Tensor, teacher_label: torch.Tensor,gold_idxs: List[int]) -> torch.Tensor: 250 | # """ 251 | # Inputs: 252 | # - logits: Student logits, shape [B, L, C] 253 | # - teacher_logits: Teacher logits, shape [B, L, C] 254 | # - *_label: token-level masks, shape [B, L] 255 | # """ 256 | # gold_idxs = torch.tensor(gold_idxs, device=logits.device) 257 | 258 | # teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L] 259 | # student_mask = (student_label != self.IGNORE_INDEX) # [B, L] 260 | # B = teacher_mask.shape[0] 261 | # assert torch.all(torch.sum(teacher_mask, dim=1) == torch.sum(student_mask, dim=1)), "Mask mismatch" 262 | 263 | # teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) # [B, L, C] 264 | # student_log_probs = F.log_softmax(logits / self.temperature, dim=-1) # [B, L, C] 265 | # kl_divs = torch.stack([ 266 | # F.kl_div( 267 | # student_log_probs[i][student_mask[i]], 268 | # teacher_log_probs[i][teacher_mask[i]].exp(), 269 | # reduction="batchmean" 270 | # ) 271 | # for i in range(B) 272 | # ]) 273 | # # mask = gold_idxs == 0 274 | # # local_invaried_tensor = kl_divs[mask].sum() # [1] 275 | # # local_kl_tensor = kl_divs[~mask] # [3] 276 | # local_kl_tensor = kl_divs 277 | # print(f"rank = {get_rank()}") 278 | # print(f"local_kl_tensor = {local_kl_tensor}") 279 | # print(f"gold idxs = {gold_idxs}") 280 | # if self.batch_norm: 281 | # world_size = dist.get_world_size() 282 | # gathered_kl_list = [torch.zeros_like(local_kl_tensor) for _ in range(world_size)] 283 | # gathered_gold_indexs = [torch.zeros_like(gold_idxs) for _ in range(world_size)] 284 | # all_gather(gathered_kl_list, local_kl_tensor) 285 | # all_gather(gathered_gold_indexs, gold_idxs) 286 | # all_kl = torch.stack(gathered_kl_list, dim=0) # [8, 4] 287 | # all_gold_indexs = torch.stack(gathered_gold_indexs, dim=0) #[8,4] 288 | # print(f"all_kl.shape = {all_kl.shape}") 289 | # print(f"all_gold_index shape = {all_gold_indexs.shape}") 290 | # # 仅在Rank 0计算权重后广播,确保一致性 291 | # if dist.get_rank() == 0: 292 | # mask = all_gold_indexs == 0 293 | # weights = torch.zeros_like(all_gold_indexs,dtype=torch.float32) 294 | # weights[mask] = self.attention_dillution_coef 295 | # print(f"mask = {mask}") 296 | # print(f"~mask = {~mask}") 297 | # weights[~mask] = all_kl[~mask]*torch.sum(~mask)/(torch.sum(all_kl[~mask]) + self.epsilon) 298 | # else: 299 | # weights = torch.zeros_like(all_gold_indexs,dtype=torch.float32) 300 | # dist.broadcast(weights, src=0) 301 | 302 | # print(f"rank = {get_rank()}, weights = {weights}") 303 | # print(f"rank = {get_rank()}, all_kl = {all_kl}") 304 | # total_loss = torch.sum(weights[get_rank()].detach() * local_kl_tensor) 305 | # else: 306 | # total_loss = torch.sum(local_kl_tensor) + local_invaried_tensor 307 | # # total_loss = torch.sum(torch.clamp((local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor)).detach(),min=0.3,max=self.attention_dillution_coef) * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef 308 | # # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef 309 | # # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) 310 | 311 | # print(f"total_loss = {total_loss}") 312 | # return total_loss * (self.temperature ** 2) 313 | class MY_KDLoss(nn.Module): 314 | """ 315 | Language Model Knowledge Distillation Loss 316 | """ 317 | def __init__(self): 318 | super().__init__() 319 | self.IGNORE_INDEX = -100 320 | self.temperature = 1.0 321 | self.chunk_size = 64 322 | self.decay_factor = 0.99 323 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, 324 | student_label: torch.Tensor, teacher_label: torch.Tensor) -> torch.Tensor: 325 | # 计算教师和学生的有效 token mask 326 | teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L] 327 | student_mask = (student_label != self.IGNORE_INDEX) # [B, L] 328 | # 确保 mask 位置一致 329 | assert torch.sum(teacher_mask) == torch.sum(student_mask), "Teacher and student masks must have the same number of valid tokens" 330 | B = teacher_mask.shape[0] 331 | L = teacher_mask.shape[1] 332 | teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1, dtype=torch.float32) # 教师 log_softmax 333 | student_log_probs = F.log_softmax(logits / self.temperature, dim=-1, dtype=torch.float32) # [B, L, V] 334 | 335 | position_weights = torch.tensor([3*self.decay_factor ** i for i in range(L)], dtype=torch.float32, device=teacher_logits.device) 336 | # 对每个样本的有效位置计算 KL loss 337 | kl_divs = [] 338 | for i in range(B): 339 | kl_loss = F.kl_div(student_log_probs[i][student_mask[i]], teacher_log_probs[i][teacher_mask[i]].exp(),reduction="none") # seq*V 340 | # breakpoint() 341 | 342 | kl_loss = torch.mean(torch.sum(kl_loss,dim=-1)*position_weights[:torch.sum(student_mask[i])]) 343 | 344 | kl_divs.append(kl_loss) 345 | kl_divs = torch.stack(kl_divs) 346 | # 计算最终的 KL loss 347 | return kl_divs.mean() 348 | 349 | # class MY_KDLoss(nn.Module): 350 | # """ 351 | # Language Model Knowledge Distillation Loss 352 | # """ 353 | # def __init__(self): 354 | # super().__init__() 355 | # self.IGNORE_INDEX = -100 356 | # self.temperature = 1.0 357 | # def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, 358 | # student_label: torch.Tensor, teacher_label: torch.Tensor) -> torch.Tensor: 359 | # # 计算教师和学生的有效 token mask 360 | # teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L] 361 | # student_mask = (student_label != self.IGNORE_INDEX) # [B, L] 362 | # # 确保 mask 位置一致 363 | # B = teacher_mask.shape[0] 364 | # assert torch.sum(teacher_mask) == torch.sum(student_mask), "Teacher and student masks must have the same number of valid tokens" 365 | 366 | # teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1, dtype=torch.float32) # 教师 log_softmax 367 | # student_log_probs = F.log_softmax(logits / self.temperature, dim=-1, dtype=torch.float32) # 学生 log_softmax 368 | 369 | # kl_divs = torch.stack([ 370 | # F.kl_div( 371 | # student_log_probs[i][student_mask[i]], 372 | # teacher_log_probs[i][teacher_mask[i]].exp(), 373 | # reduction="batchmean" 374 | # ) 375 | # for i in range(B) 376 | # ]) 377 | # kl_loss = kl_divs.mean() 378 | # return kl_loss * (self.temperature ** 2) 379 | class MY_rankLoss(nn.Module): 380 | """ 381 | Language Model Knowledge Distillation Loss 382 | """ 383 | def __init__(self,top_k=30, margin=0.5): 384 | super().__init__() 385 | self.IGNORE_INDEX = -100 386 | self.temperature = 2.0 387 | self.top_k = top_k 388 | self.margin = margin 389 | self.ranking_loss = nn.MarginRankingLoss(margin=self.margin) 390 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, 391 | student_label: torch.Tensor, teacher_label: torch.Tensor) -> torch.Tensor: 392 | # 计算教师和学生的有效 token mask 393 | teacher_mask = (teacher_label != self.IGNORE_INDEX).unsqueeze(-1) # [B, L, 1] 394 | student_mask = (student_label != self.IGNORE_INDEX).unsqueeze(-1) # [B, L, 1] 395 | # 确保 mask 位置一致 396 | assert torch.sum(teacher_mask) == torch.sum(student_mask), "Teacher and student masks must have the same number of valid tokens" 397 | teacher_indices = teacher_mask.squeeze(-1).nonzero(as_tuple=True) 398 | student_indices = student_mask.squeeze(-1).nonzero(as_tuple=True) 399 | # 提取有效 token 的概率 400 | teacher_logits = teacher_logits[teacher_indices] # [有效Token数, C] 401 | student_logits = logits[student_indices] # [有效Token数, C] 402 | topk_values, topk_indices = torch.topk(teacher_logits, self.top_k, dim=-1) # 获取 top-k token 403 | student_topk_logits = torch.gather(student_logits, dim=-1, index=topk_indices) 404 | # 生成pairs 405 | i_indices, j_indices = torch.triu_indices(self.top_k, self.top_k, offset=1) 406 | # 计算pairwise差异 407 | teacher_i = topk_values[:, i_indices] 408 | teacher_j = topk_values[:, j_indices] 409 | student_i = student_topk_logits[:, i_indices] 410 | student_j = student_topk_logits[:, j_indices] 411 | 412 | # 生成ranking标签 413 | ranking_labels = (teacher_i > teacher_j).float() * 2 - 1 414 | 415 | # 添加权重 416 | rank_diff = torch.abs(i_indices - j_indices).float() 417 | weights = 1.0 / (rank_diff + 1) 418 | weights = weights.to(student_i.device) 419 | 420 | # 计算loss 421 | ranking_loss = (self.ranking_loss(student_i, student_j, ranking_labels) * weights).mean() 422 | 423 | return ranking_loss 424 | class MY_topkLoss(nn.Module): 425 | """ 426 | Language Model Knowledge Distillation Loss 427 | """ 428 | def __init__(self): 429 | super().__init__() 430 | self.IGNORE_INDEX = -100 431 | self.temperature = 1.0 432 | self.top_k = 20 433 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, 434 | student_label: torch.Tensor, teacher_label: torch.Tensor) -> torch.Tensor: 435 | # 计算教师和学生的有效 token mask 436 | teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L] 437 | student_mask = (student_label != self.IGNORE_INDEX) # [B, L] 438 | # 确保 mask 位置一致 439 | assert torch.sum(teacher_mask) == torch.sum(student_mask), "Teacher and student masks must have the same number of valid tokens" 440 | teacher_indices = teacher_mask.squeeze(-1).nonzero(as_tuple=True) 441 | student_indices = student_mask.squeeze(-1).nonzero(as_tuple=True) 442 | # 提取有效 token 的概率 443 | teacher_logits = teacher_logits[teacher_indices] # [有效Token数, C] 444 | student_logits = logits[student_indices] # [有效Token数, C] 445 | topk_values, topk_indices = torch.topk(teacher_logits, self.top_k, dim=-1) # 获取 top-k token 446 | student_topk_logits = torch.gather(student_logits, dim=-1, index=topk_indices) 447 | 448 | 449 | teacher_probs = F.softmax(topk_values / self.temperature, dim=-1) 450 | student_log_probs = F.log_softmax(student_topk_logits / self.temperature, dim=-1) 451 | kl_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") 452 | return kl_loss * (self.temperature ** 2) 453 | class MY_GPTLMLoss(nn.Module): 454 | """ 455 | GPT Language Model Loss 456 | """ 457 | 458 | def __init__(self, ring_attn_group=None): 459 | super().__init__() 460 | self.IGNORE_INDEX = -100 461 | self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX) 462 | 463 | self.ring_attn_group = ring_attn_group 464 | if self.ring_attn_group: 465 | self.ring_attn_rank = dist.get_rank(self.ring_attn_group) 466 | self.ring_attn_world_size = dist.get_world_size(self.ring_attn_group) 467 | 468 | def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 469 | # RingAttention 470 | 471 | if self.ring_attn_group is not None: 472 | logger.info(f" gptloss is entered: the first choice of ring_attn_group.") 473 | total_seq_len = labels.size(-1) 474 | seq_len_per_process = total_seq_len // self.ring_attn_world_size 475 | start_idx = self.ring_attn_rank * seq_len_per_process 476 | end_idx = min(start_idx + seq_len_per_process, total_seq_len) 477 | labels = labels[..., start_idx:end_idx] 478 | 479 | shift_logits = logits[..., :-1, :].contiguous() 480 | shift_labels = labels[..., 1:].contiguous() 481 | 482 | # if labels are all IGNORE_INDEX, then nn.CrossEntropyLoss will be nan 483 | if torch.all(shift_labels == self.IGNORE_INDEX): 484 | # Use mean of logits multiplied by 0 to maintain gradient flow 485 | loss = shift_logits.mean() * 0 486 | else: 487 | loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 488 | 489 | dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.ring_attn_group) 490 | loss = loss / self.ring_attn_world_size 491 | else: 492 | # breakpoint() 493 | shift_logits = logits[..., :-1, :].contiguous() 494 | shift_labels = labels[..., 1:].contiguous() 495 | # loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 496 | all_logits = [] 497 | all_labels = [] 498 | batch_size = shift_labels.shape[0] 499 | shift_mask = (shift_labels != self.IGNORE_INDEX).int().unsqueeze(-1) 500 | for i in range(batch_size): 501 | shift_nonzero_indices = shift_mask[i].nonzero()[:,0] 502 | shift_nonzero_logits = shift_logits[i,shift_nonzero_indices,:] 503 | shift_nonzero_labels = shift_labels[i,shift_nonzero_indices] 504 | all_logits.append(shift_nonzero_logits) 505 | all_labels.append(shift_nonzero_labels) 506 | 507 | # 将所有batch的数据拼接起来 508 | concatenated_logits = torch.cat(all_logits, dim=0) # shape: [total_valid_tokens, vocab_size] 509 | concatenated_labels = torch.cat(all_labels, dim=0) # shape: [total_valid_tokens] 510 | 511 | # 统一计算loss 512 | loss = self.loss(concatenated_logits, concatenated_labels) 513 | 514 | return loss 515 | --------------------------------------------------------------------------------