├── .DS_Store
├── assets
│   └── overview.jpg
├── scripts
│   ├── gen_advantageous_pos_for_R1.sh
│   ├── eval_data.sh
│   └── train_nq.sh
├── metrics.py
├── utils.py
├── create_datasets.py
├── gen_advantageous_pos_for_R1.py
├── requirements.txt
├── README.md
├── eval_data.py
└── kd
    ├── customized_kd_trainer.py
    ├── train_kd.py
    ├── customized_sft_dataset.py
    └── loss_design.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AMAP-ML/Pos2Distill/HEAD/.DS_Store
--------------------------------------------------------------------------------
/assets/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AMAP-ML/Pos2Distill/HEAD/assets/overview.jpg
--------------------------------------------------------------------------------
/scripts/gen_advantageous_pos_for_R1.sh:
--------------------------------------------------------------------------------
1 | total_doc=20
2 | num_gpus=8
3 | # export CUDA_VISIBLE_DEVICES=3
4 | dataset_name="nq"
5 |
6 | INPUT_PATH="nq-open-${total_doc}_total_documents.jsonl.gz"
7 | model_name=Mistral-7B-Instruct-v0.3
8 | cache_dir="huggingface_cache"  # assumed placeholder: set to your local model-weights cache (used by --cache_dir below)
9 | python gen_advantageous_pos_for_R1.py \
10 | --input-path "$INPUT_PATH" \
11 | --model $model_name \
12 | --output-path raw_data/$dataset_name \
13 | --max-prompt-length 32768 \
14 | --max-new-tokens 100 \
15 | --num_gpus "$num_gpus" \
16 | --total_doc "$total_doc" \
17 | --cache_dir $cache_dir \
18 | --gold_doc 0
19 |
20 |
21 | model_name=Mistral-7B-Instruct-v0.3
22 | dataset_name=nq
23 | python create_datasets.py \
24 | --model_name $model_name \
25 | --position_sample_num 4 \
26 | --example_num 400 \
27 | --dataset_name $dataset_name
--------------------------------------------------------------------------------
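
Note: the first command writes one gzipped JSONL file of oracle responses that create_datasets.py later reads back under the same name. A minimal sketch of the naming scheme (from gen_advantageous_pos_for_R1.py, with the values used above):

model_name, total_doc, gold_doc = "Mistral-7B-Instruct-v0.3", 20, 0
# --output-path is raw_data/$dataset_name, so the responses land at:
print(f"raw_data/nq/{model_name}_{total_doc}docs_{gold_doc}.jsonl.gz")
# -> raw_data/nq/Mistral-7B-Instruct-v0.3_20docs_0.jsonl.gz
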
/scripts/eval_data.sh:
--------------------------------------------------------------------------------
1 | total_docs=20
2 | num_gpus=8
3 | # export CUDA_VISIBLE_DEVICES=0,1,2,3
4 | # export CUDA_VISIBLE_DEVICES=4,5,6,7
5 |
6 | random_num=4
7 | strengthen=1
8 | INPUT_PATH="nq-open-${total_docs}_total_documents.jsonl.gz"
9 | # model_name=Mistral-7B-Instruct-v0.3_20total_docs_filter_4random_1strengthen_300_kd0.0_lm0.0_rank0.0_adaptive1.0_1.0
10 | # model_name=Mistral-7B-Instruct-v0.3_20total_docs_filter_4random_1strengthen_400_kd0.0_lm0.0_rank0.0_adaptive1.0_1.0
11 | # python eval_data.py \
12 | # --input-path "$INPUT_PATH" \
13 | # --model $model_name \
14 | # --output-path evaluate \
15 | # --sample_num 500 \
16 | # --max-prompt-length 32768 \
17 | # --max-new-tokens 100 \
18 | # --num_gpus "$num_gpus" \
19 | # --total_doc "$total_docs" \
20 |
21 | model_name=Mistral-7B-Instruct-v0.3
22 | INPUT_PATH="webq_dev.jsonl.gz"
23 | total_docs=20
24 | num_gpus=8
25 | python eval_data.py \
26 | --input-path "$INPUT_PATH" \
27 | --model $model_name \
28 | --output-path evaluate \
29 | --sample_num 500 \
30 | --max-prompt-length 32768 \
31 | --max-new-tokens 100 \
32 | --num_gpus "$num_gpus" \
33 |     --total_doc "$total_docs" \
34 |     --dataset_name webq
35 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | # Code for metrics comes from beam_retriever/blob/main/gpt_turbo_exp.py
2 |
3 | import re
4 | import string
5 | import collections
6 |
7 | def normalize_answer(s):
8 |
9 | def remove_articles(text):
10 | return re.sub(r'\b(a|an|the)\b', ' ', text)
11 |
12 | def white_space_fix(text):
13 | return ' '.join(text.split())
14 |
15 | def remove_punc(text):
16 | exclude = set(string.punctuation)
17 | return ''.join(ch for ch in text if ch not in exclude)
18 |
19 | def lower(text):
20 | return text.lower()
21 |
22 | return white_space_fix(remove_articles(remove_punc(lower(s))))
23 |
24 | def get_tokens(s):
25 | if not s:
26 | return []
27 | return normalize_answer(s).split()
28 |
29 | def compute_exact(a_gold, a_pred):
30 | return int(normalize_answer(a_gold) == normalize_answer(a_pred))
31 |
32 | def compute_subem(a_gold, a_pred):
33 | return int(normalize_answer(a_gold) in normalize_answer(a_pred))
34 |
35 | def compute_f1(a_gold, a_pred):
36 | gold_toks = get_tokens(a_gold)
37 | pred_toks = get_tokens(a_pred)
38 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
39 | num_same = sum(common.values())
40 | if len(gold_toks) == 0 or len(pred_toks) == 0:
41 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
42 | return int(gold_toks == pred_toks)
43 | if num_same == 0:
44 | return 0
45 | precision = 1.0 * num_same / len(pred_toks)
46 | recall = 1.0 * num_same / len(gold_toks)
47 | f1 = (2 * precision * recall) / (precision + recall)
48 | return f1
49 |
50 | # metrics for attribution
51 |
52 | def compute_attr_metrics(response, ground_truth):
53 | attr_matches = re.findall(r'\[(\d+)\]', response)
54 | attr_predicted = [int(match) for match in attr_matches]
55 |
56 | predicted_set = set(attr_predicted)
57 | ground_truth_set = set(ground_truth)
58 |
59 | true_positives = len(predicted_set & ground_truth_set)
60 | precision = true_positives / len(predicted_set) if predicted_set else 0
61 | recall = true_positives / len(ground_truth_set) if ground_truth_set else 0
62 | f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
63 |
64 | return precision, recall, f1
65 |
66 | def best_subspan_em(prediction, ground_truths) -> float:
67 | normalized_prediction = normalize_answer(prediction)
68 |
69 | for ground_truth in ground_truths:
70 | normalized_ground_truth = normalize_answer(ground_truth)
71 |         if normalized_ground_truth in normalized_prediction:  # both strings are already lowercased by normalize_answer
72 | return 1.0
73 | return 0.0
--------------------------------------------------------------------------------
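
A quick sanity check of these metrics (illustrative usage only, assuming metrics.py is on the import path):

from metrics import compute_exact, compute_f1, best_subspan_em

# normalize_answer lowercases, strips punctuation, and removes articles first.
assert compute_exact("The Eiffel Tower", "eiffel tower!") == 1
# compute_f1 measures token overlap between gold and predicted answers.
print(compute_f1("new york city", "new york city is big"))  # 0.75
# best_subspan_em (the repo's main QA metric) checks whether any gold answer
# appears as a substring of the normalized prediction.
assert best_subspan_em("The answer is Paris.", ["Paris"]) == 1.0
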
/scripts/train_nq.sh:
--------------------------------------------------------------------------------
1 | ############################ RUNNING CONFIG ############################
2 | # export LD_PRELOAD="/usr/lib64/libjemalloc.so.2"
3 | # sudo mount -o size=20480M -o nr_inodes=1000000 -o noatime,nodiratime -o remount /dev/shm
4 | # if [ -f "/opt/conda/etc/profile.d/conda.sh" ]; then
5 | # . "/opt/conda/etc/profile.d/conda.sh"
6 | # else
7 | # export PATH="/opt/conda/bin:$PATH"
8 | # fi
9 | # conda activate /mnt/workspace/wangyifei/miniconda3/envs/openrlhf
10 | # workdir="/mnt/workspace/wangyifei/projects/self_training"
11 | # cd $workdir
12 | ############################ RUNNING CONFIG ############################
13 |
14 |
15 |
16 |
17 | export CUDA_HOME=/usr/local/cuda-12.4
18 | export LD_LIBRARY_PATH=${CUDA_HOME}/lib64
19 | export PATH=${CUDA_HOME}/bin:${PATH}
20 | export NCCL_DEBUG=ERROR
21 | export NCCL_IB_DISABLE=0
22 | export NCCL_P2P_LEVEL=PIX
23 | export HF_ENDPOINT=https://hf-mirror.com
24 | # export CUDA_VISIBLE_DEVICES=0,1,2,3
25 | # export CUDA_VISIBLE_DEVICES=4,5,6,7
26 | #Meta-Llama-3-8B-Instruct Qwen1.5-7B-Chat 0.1 0.2 0.3 0.4 0.0 0.6 0.7 0.8 0.9 1.0
27 | # /mnt/workspace/wangyifei/projects/self_training/kd/eval_kl.sh
28 |
29 |
30 |
31 | strengthen=1
32 | total_docs=20
33 | dataset_name="nq"
34 | for i in 0 ; do
35 | for model_name in Mistral-7B-Instruct-v0.3; do
36 | for num in 400; do
37 | for kd_coef in "0.0,0.0,0.0,1.0" ; do
38 | for K in 4 ; do
39 | deepspeed --master_port 6666 \
40 | kd/train_kd.py \
41 | --max_len 32000 \
42 |     --dataset ${model_name}_${total_docs}total_docs_filter_${K}random_${strengthen}strengthen_${num} \
43 | --input_key question \
44 | --output_key oracle_answer \
45 | --train_batch_size 32 \
46 | --micro_train_batch_size 4 \
47 | --max_samples 500000 \
48 | --pretrain $model_name \
49 | --teacher_model $model_name \
50 |     --save_path checkpoints/${dataset_name} \
51 | --save_steps -1 \
52 | --logging_steps 1 \
53 | --eval_steps -1 \
54 | --zero_stage 3 \
55 | --max_epochs 1 \
56 | --l2 0.01 \
57 | --bf16 \
58 | --flash_attn \
59 | --kd_coef $kd_coef \
60 | --learning_rate 3e-6 \
61 | --teacher_offload \
62 | --apply_chat_template \
63 | --gradient_checkpointing \
64 | --perserve 1.0 \
65 | --dataset_name nq \
66 | --use_tensorboard tensorBoard
67 |     # --use_wandb "<your_wandb_api_key>" # > training_kl2.log 2>&1
68 | done
69 | done
70 | done
71 | done
72 | done
73 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import string
3 | import re
4 | from metrics import compute_exact,best_subspan_em
5 | from typing import List
6 | import regex
7 | import argparse
8 | import json
9 | import logging
10 | import statistics
11 | import sys
12 | from copy import deepcopy
13 | import os
14 | from tqdm import tqdm
15 | from xopen import xopen
16 | from pydantic.dataclasses import dataclass
17 | from typing import List, Optional, Tuple, Type, TypeVar
18 | logger = logging.getLogger(__name__)
19 | def read_xopen(file_path):
20 | data = []
21 | with xopen(file_path) as fin:
22 | for line in fin:
23 | example = json.loads(line)
24 | data.append(example)
25 | return data
26 | def write_xopen(file_path,data):
27 | with xopen(file_path, "w") as f:
28 | for d in data:
29 | f.write(json.dumps(d) + "\n")
30 |
31 | def get_template_prompt(ins,model_name,tokenizer):
32 | # assert tokenizer.chat_template is not None
33 | context = ''.join([f"- Title: {doc['title']}\n{doc['text']}\n" for doc in ins["documents"]])
34 |     task_instruction = "Please write a high-quality answer for the given question using only the provided search documents (some of which might be irrelevant)."
35 | prompt_message = f"{task_instruction}\n{context}\nQuestion: {ins['question']}\n"
36 |     messages = [
37 |         {"role": "user", "content": prompt_message},
38 |     ]
39 |     if tokenizer.chat_template is not None:
40 |         prompt = tokenizer.apply_chat_template(
41 |             messages,
42 |             tokenize=False,
43 |             add_generation_prompt=True
44 |         )
45 |     else:
46 |         prompt = prompt_message  # fall back to the raw prompt when the tokenizer has no chat template
47 |     ins["prompt"] = prompt
48 |     return ins
49 |
50 |
51 | def get_metrics_for_example(example,METRICS):
52 | gold_answers = example["answers"]
53 | model_answer = example["current_answer"]
54 |     # NOTE: we take everything up to the first newline, since otherwise models could hack the metric by appending extra text
55 | model_answer = model_answer.strip("\n").split("\n")[0].strip()
56 | example_metrics = {}
57 | for (metric, metric_name) in METRICS:
58 | example_metrics[metric_name] = metric(prediction=model_answer, ground_truths=gold_answers)
59 | return (example_metrics, example)
60 |
61 |
62 |
63 | def evaluate_qa_data(input_path,output_path=None,sample_num = None):
64 | METRICS = [(best_subspan_em, "best_subspan_em"),]
65 | all_examples = []
66 | with xopen(input_path) as fin:
67 | for line in tqdm(fin):
68 | input_example = json.loads(line)
69 | all_examples.append(input_example)
70 | if sample_num:
71 | all_examples = all_examples[:sample_num]
72 | # Compute normal metrics in parallel, if applicable
73 | logger.info("Computing metrics")
74 | all_example_metrics = []
75 | for example in tqdm(all_examples):
76 | all_example_metrics.append(get_metrics_for_example(example,METRICS))
77 | # Average metrics across examples
78 | for (_, metric_name) in METRICS:
79 | average_metric_value = statistics.mean(
80 | example_metrics[metric_name] for (example_metrics, _) in all_example_metrics
81 | )
82 | print(f"{metric_name}: {average_metric_value}")
83 | logger.info(f"{metric_name}: {average_metric_value}")
84 |
85 | # summary_path = os.path.join(os.path.dirname(input_path),"A_metrics_summary.txt")
86 | # with xopen(summary_path,"a") as f:
87 | # f.write(f"{input_path.split('/')[-1].split('.jsonl.gz')[0]}\n{metric_name}: {average_metric_value}\n\n")
88 | if output_path:
89 | with xopen(output_path, "w") as f:
90 | for (example_metrics, example) in all_example_metrics:
91 | example_with_metrics = deepcopy(example)
92 | for metric_name, metric_value in example_metrics.items():
93 | example_with_metrics[f"metric_{metric_name}"] = metric_value
94 | f.write(json.dumps(example_with_metrics) + "\n")
95 | return average_metric_value
--------------------------------------------------------------------------------
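
For reference, a minimal sketch of the prompt string that get_template_prompt assembles before the chat template is applied (the document and question here are invented for illustration):

docs = [{"title": "Eiffel Tower", "text": "The Eiffel Tower is in Paris."}]
question = "Where is the Eiffel Tower?"
context = ''.join(f"- Title: {d['title']}\n{d['text']}\n" for d in docs)
task_instruction = ("Please write a high-quality answer for the given question "
                    "using only the provided search documents (some of which might be irrelevant).")
print(f"{task_instruction}\n{context}\nQuestion: {question}\n")
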
/create_datasets.py:
--------------------------------------------------------------------------------
1 | from utils import best_subspan_em,read_xopen,write_xopen
2 | from xopen import xopen
3 | import json
4 | from random import sample
5 | import random
6 | from copy import deepcopy
7 | from collections import defaultdict
8 | from datasets import Dataset
9 | from collections import Counter
10 | import itertools
11 | import os
12 | import numpy as np
13 | import re
14 | from metrics import compute_exact
15 | from argparse import ArgumentParser
16 | random.seed(42)
17 | def best_subspan_em_musique(r,answers):
18 | match = re.search(r'the answer is[::]?\s*(.*)', r.lower())
19 | if match is None:
20 | return int(0)
21 | a_pred = match.group(1)
22 | qa_em_score = 0
23 | for a_gold in answers:
24 | qa_em_score = max(qa_em_score, compute_exact(a_gold,a_pred))
25 | return int(qa_em_score)
26 |
27 | def chunk_list(lst, chunk_size):
28 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
29 |
30 |
31 |
32 | def get_chunked_answer(model_answer):
33 | # chunked_answer = model_answer.strip("\n").strip(" ").split(".")[0] + "."
34 | chunked_answer = model_answer
35 | print(chunked_answer)
36 | return chunked_answer
37 |
38 | def process_example(example,position_sample_num,strengthen=2,total_doc=20,train_counter=Counter()):
39 | sample_examples = []
40 | gold_doc = example["gold_document"]
41 | sample_pool = list(range(1,total_doc)) if strengthen!=0 else list(range(total_doc+1))
42 | random_pos = sample(sample_pool,position_sample_num)
43 | print(f"iteration1: {random_pos}")
44 |
45 |
46 |
47 |
48 |     # add strengthened samples with the gold document at position 0
49 | sample_positions = random_pos + [0] * strengthen
50 | for gold_idx in sample_positions:
51 | new_example = deepcopy(example)
52 | new_example["distractors"] = random.sample(example["distractors"],len(example["distractors"]))
53 | new_example["distractors"].insert(gold_idx,gold_doc)
54 | new_example["documents"]=deepcopy(new_example["distractors"])
55 | new_example["distractors"]=example["distractors"]
56 | new_example["gold_idx"] = gold_idx
57 | assert len(new_example["documents"])==total_doc
58 | assert new_example["documents"] != new_example["distractors"]
59 | new_example["oracle_answer"]=get_chunked_answer(example["oracle_answer"])
60 | sample_examples.append(new_example)
61 | train_counter[gold_idx] += 1
62 |
63 | return sample_examples,train_counter
64 |
65 |
66 | # dict_keys(['question', 'answers', 'ctxs', 'nq_annotated_gold', 'gold_document', 'distractors', 'documents', 'prompt', 'model_answer'])
67 | def generate_dataset(model_name,position_sample_num,total_doc=20,strengthen=2,train_size=200,position=0,filter=True,dataset_name="nq"):
68 | datasets_path = f"datasets/{dataset_name}"
69 | file_path = f"raw_data/{dataset_name}"
70 | file_name = f"{model_name}_{total_doc}docs_{position}.jsonl.gz"
71 | file_path = os.path.join(file_path,file_name)
72 | position_statistics = []
73 | examples = read_xopen(file_path)
74 | train_datasets = defaultdict(list)
75 | print(f"example nums: {len(examples)}")
76 |
77 |
78 | # process train_examples
79 | cnt = 0
80 | train_counter = Counter({i: 0 for i in range(0,total_doc+1)})
81 |
82 | for idx,example in enumerate(examples):
83 | if cnt == train_size:
84 | break
85 | is_correct = best_subspan_em(get_chunked_answer(example["oracle_answer"]),example["answers"])
86 | # is_correct = best_subspan_em_musique(get_chunked_answer(example["oracle_answer"]),example["answers"])
87 | print(f"is_correct: {is_correct}")
88 | if filter:
89 | if int(is_correct)==0:
90 | continue
91 | if "predicted_label" not in example:
92 | example["predicted_label"] = "correct"
93 | if "predicted" not in example:
94 | example["predicted"] = example["answers"][0]
95 | if example["predicted_label"] == "incorrect":
96 | continue
97 | if not example["predicted"]:
98 | continue
99 | cnt+=1
100 | sample_examples,train_counter = process_example(example,position_sample_num,strengthen,total_doc,train_counter)
101 | assert len(sample_examples) == (position_sample_num + strengthen)
102 | train_datasets["data"].extend(sample_examples)
103 | # train_datasets["position_statistics"].append(random_pos)
104 | # train_positions = list(itertools.chain.from_iterable(train_datasets["position_statistics"]))
105 |
106 | print(f"Train filter generate datapoints number: {len(train_datasets['data'])}")
107 | train_last_idx = idx
108 | assert len(train_datasets["data"]) == train_size*(position_sample_num + strengthen)
109 | filter_name = "filter" if filter else "unfilter"
110 | train_datasets_path = os.path.join(datasets_path,f"{model_name}_{total_doc}total_docs_{filter_name}_{position_sample_num}random_{strengthen}strengthen_{train_size}")
111 | train = Dataset.from_list(train_datasets["data"])
112 | train.save_to_disk(train_datasets_path)
113 | # save position_statistics
114 | print(f"{model_name}_{file_name}_{position_sample_num}_{strengthen}: {train_counter}")
115 |
116 |
117 | position_statistics = {f"train_{file_name}":train_counter}
118 | os.makedirs(datasets_path+"/position_statistics",exist_ok=True)
119 | with open(datasets_path+f"/position_statistics/{model_name}.json", "w") as f:
120 | json.dump(position_statistics, f,indent=1)
121 | return
122 |
123 |
124 | if __name__=="__main__":
125 | parser = ArgumentParser()
126 | parser.add_argument("--model_name", type=str, default="Mistral-7B-Instruct-v0.3")
127 | parser.add_argument("--position_sample_num", type=int, default=4)
128 | parser.add_argument("--example_num", type=int, default=250)
129 | parser.add_argument("--dataset_name", type=str, default="nq")
130 | args = parser.parse_args()
131 | generate_dataset(args.model_name,position_sample_num=args.position_sample_num,total_doc=20,strengthen=1,train_size=args.example_num,position=0,filter=True,dataset_name=args.dataset_name)
132 |
133 |
--------------------------------------------------------------------------------
/gen_advantageous_pos_for_R1.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | import argparse
4 | import json
5 | import numpy as np
6 | import logging
7 | import pathlib
8 | import random
9 | import sys
10 | from copy import deepcopy
11 | import torch
12 | from tqdm import tqdm
13 | from transformers import AutoTokenizer,LlamaTokenizer
14 | from xopen import xopen
15 | from vllm import LLM, SamplingParams
16 | import pandas as pd
17 | import random
18 | import itertools
19 | from utils import *
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 |
24 |
25 | def seed_everything(seed):
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | np.random.seed(seed)
29 | random.seed(seed)
30 | torch.backends.cudnn.benchmark = False
31 | torch.backends.cudnn.deterministic = True
32 | torch.cuda.manual_seed_all(seed)
33 |
34 | seed_everything(42)
35 | def get_prompt_at_gold_index(input_path,gold_doc,model_name,total_doc,tokenizer=None):
36 | examples = []
37 | prompts = []
38 | with xopen(input_path) as fin:
39 | for line in tqdm(fin):
40 | input_example = json.loads(line)
41 | gold_document = input_example["gold_document"]
42 | distractors = input_example["distractors"]
43 | random.shuffle(distractors)
44 | # Get the prediction for the input example
45 |
46 | shuffled_distractors = distractors[:]
47 | shuffled_distractors.insert(gold_doc, gold_document)
48 |
49 | ins = deepcopy(input_example)
50 | ins["documents"] = shuffled_distractors if total_doc >0 else []
51 | assert len(ins["documents"])==total_doc
52 | prompt = get_template_prompt(ins,model_name,tokenizer)["prompt"]
53 | prompts.append(prompt)
54 | examples.append(ins)
55 | return prompts,examples
56 |
57 | def write_responses(output_path,examples, responses,prompts):
58 | with xopen(output_path, "w") as f:
59 | for (example,response,prompt) in zip(examples,responses,prompts):
60 | example["oracle_answer"] = response
61 | example["oracle_prompt"] = prompt
62 | f.write(json.dumps(example) + "\n")
63 | average_metric_value = evaluate_qa_data(input_path=output_path)
64 | print(f"acc when gold doc at 0: {average_metric_value}")
65 | return average_metric_value
66 |
67 | def main(
68 | input_path,
69 | model_name,
70 | temperature,
71 | top_p,
72 | num_gpus,
73 | max_new_tokens,
74 | max_prompt_length,
75 | output_path,
76 | gold_doc,
77 | total_doc,
78 | cache_dir
79 | ):
80 | os.makedirs(output_path, exist_ok=True)
81 | output_path_gold_index = os.path.join(output_path,f"{model_name}_{total_doc}docs_{gold_doc}.jsonl.gz")
82 | model_path = os.path.join(cache_dir,model_name)
83 | model = LLM(
84 |         model=model_path,
85 | tensor_parallel_size=num_gpus,
86 | # load_format="safetensors",
87 | # load_format="pt",
88 | max_num_batched_tokens=max_prompt_length,
89 | )
90 |     sampling_params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens, top_p=top_p)  # use the CLI temperature rather than a hardcoded value
91 |
92 |     if "mistral" in model_name.lower():  # for Mistral the advantageous position is the end, so the gold document goes at index total_doc
93 | prompts,examples = get_prompt_at_gold_index(input_path,total_doc,model_name,total_doc,model.get_tokenizer())
94 | else:
95 | prompts,examples = get_prompt_at_gold_index(input_path,0,model_name,total_doc,model.get_tokenizer())
96 | raw_responses = model.generate(prompts, sampling_params)
97 | responses = [output.outputs[0].text.strip() for output in raw_responses]
98 | write_responses(output_path_gold_index,examples,responses,prompts)
99 | return
100 |
101 |
102 |
103 |
104 | if __name__ == "__main__":
105 | logging.basicConfig(format="%(asctime)s - %(module)s - %(levelname)s - %(message)s", level=logging.INFO)
106 | parser = argparse.ArgumentParser()
107 | parser.add_argument("--input-path", help="Path to data with questions and documents to use.", required=True)
108 | parser.add_argument(
109 | "--model",
110 | help="Model to use in generating responses",
111 | required=True,
112 | choices=[
113 | "Mistral-7B-Instruct-v0.3",
114 | "Qwen1.5-7B-Chat",
115 | "Meta-Llama-3-8B-Instruct",
116 | "Qwen2.5-7B-Instruct"
117 | ]
118 | )
119 | parser.add_argument("--temperature", help="Temperature to use in generation", type=float, default=0.6)
120 | parser.add_argument("--num-gpus", help="Number of GPUs to use", type=int, default=1)
121 | parser.add_argument("--cache_dir", help="Path to huggingface cache to use.")
122 | parser.add_argument("--output-path", help="Path to write output file of generated responses", required=True)
123 | parser.add_argument(
124 | "--max-new-tokens",
125 | help="Maximum number of new tokens to generate",
126 | type=int,
127 | default=100,
128 | )
129 | parser.add_argument(
130 | "--max-prompt-length",
131 | help="Maximum number of tokens in the prompt. Longer prompts will be skipped.",
132 | type=int,
133 | default=100,
134 | )
135 | parser.add_argument(
136 | "--top_p",
137 | help="top_p",
138 | type=float,
139 | default=0.95,
140 | )
141 | parser.add_argument(
142 | "--gold_doc",
143 | type=int,
144 | default=0,
145 | )
146 | parser.add_argument(
147 |     "--num_gpus",  # underscore alias of --num-gpus (both map to dest num_gpus); the shell scripts pass this form
148 | type=int,
149 | default=0,
150 | )
151 | parser.add_argument(
152 | "--total_doc",
153 | type=int,
154 | )
155 |
156 | args = parser.parse_args()
157 |
158 | logger.info("running %s", " ".join(sys.argv))
159 | main(
160 | args.input_path,
161 | args.model,
162 | args.temperature,
163 | args.top_p,
164 | args.num_gpus,
165 | args.max_new_tokens,
166 | args.max_prompt_length,
167 | args.output_path,
168 | args.gold_doc,
169 | args.total_doc,
170 | args.cache_dir
171 | )
172 | logger.info("finished running %s", sys.argv[0])
173 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==1.5.1
3 | aclpubcheck @ git+https://github.com/acl-org/aclpubcheck@a340fc0a1a7d1c9808f08ab1dab1d228f63af405
4 | aiohappyeyeballs==2.6.1
5 | aiohttp==3.11.13
6 | aiohttp-cors==0.7.0
7 | aiosignal==1.3.2
8 | airportsdata==20250224
9 | annotated-types==0.7.0
10 | anyio==4.9.0
11 | argon2-cffi==23.1.0
12 | argon2-cffi-bindings==21.2.0
13 | arrow==1.3.0
14 | astor==0.8.1
15 | asttokens==3.0.0
16 | async-lru==2.0.5
17 | async-timeout==5.0.1
18 | attrs==25.3.0
19 | babel==2.17.0
20 | beautifulsoup4==4.13.3
21 | bibtexparser==1.4.3
22 | bitsandbytes==0.45.3
23 | black==25.1.0
24 | blake3==1.0.4
25 | bleach==6.2.0
26 | cachetools==5.5.2
27 | certifi==2025.1.31
28 | cffi==1.17.1
29 | charset-normalizer==3.4.1
30 | click==8.1.8
31 | cloudpickle==3.1.1
32 | colorful==0.5.6
33 | comm==0.2.2
34 | compressed-tensors==0.9.1
35 | contourpy==1.3.1
36 | cryptography==45.0.3
37 | cycler==0.12.1
38 | datasets==3.3.2
39 | debugpy==1.8.13
40 | decorator==5.2.1
41 | deepspeed==0.16.4
42 | defusedxml==0.7.1
43 | depyf==0.18.0
44 | dill==0.3.8
45 | diskcache==5.6.3
46 | distlib==0.3.9
47 | distro==1.9.0
48 | docker-pycreds==0.4.0
49 | einops==0.8.1
50 | et_xmlfile==2.0.0
51 | exceptiongroup==1.2.2
52 | executing==2.2.0
53 | fastapi==0.115.12
54 | fastchat==0.1.0
55 | fastjsonschema==2.21.1
56 | filelock==3.18.0
57 | flake8==7.2.0
58 | flash-attn==2.7.0.post2
59 | fonttools==4.57.0
60 | fqdn==1.5.1
61 | frozenlist==1.5.0
62 | fschat==0.2.36
63 | fsspec==2024.6.1
64 | gguf==0.10.0
65 | gitdb==4.0.12
66 | GitPython==3.1.44
67 | google-api-core==2.24.2
68 | google-auth==2.38.0
69 | googleapis-common-protos==1.69.1
70 | grpcio==1.71.0
71 | h11==0.14.0
72 | hf_transfer==0.1.9
73 | hjson==3.1.0
74 | httpcore==1.0.7
75 | httptools==0.6.4
76 | httpx==0.28.1
77 | huggingface-hub==0.29.3
78 | idna==3.10
79 | importlib_metadata==8.6.1
80 | iniconfig==2.1.0
81 | interegular==0.3.3
82 | ipykernel==6.29.5
83 | ipython==8.34.0
84 | ipywidgets==8.1.5
85 | isal==1.7.2
86 | isoduration==20.11.0
87 | isort==6.0.1
88 | jedi==0.19.2
89 | Jinja2==3.1.4
90 | jiter==0.9.0
91 | json5==0.12.0
92 | jsonlines==4.0.0
93 | jsonpointer==3.0.0
94 | jsonschema==4.23.0
95 | jsonschema-specifications==2024.10.1
96 | jupyter==1.1.1
97 | jupyter-console==6.6.3
98 | jupyter-events==0.12.0
99 | jupyter-lsp==2.2.5
100 | jupyter_client==8.6.3
101 | jupyter_core==5.7.2
102 | jupyter_server==2.15.0
103 | jupyter_server_terminals==0.5.3
104 | jupyterlab==4.3.6
105 | jupyterlab_pygments==0.3.0
106 | jupyterlab_server==2.27.3
107 | jupyterlab_widgets==3.0.13
108 | kagglehub==0.3.12
109 | kiwisolver==1.4.8
110 | lark==1.2.2
111 | latexcodec==3.0.0
112 | lightning-utilities==0.14.0
113 | lm-format-enforcer==0.10.11
114 | loralib==0.1.2
115 | Markdown==3.7
116 | markdown-it-py==3.0.0
117 | MarkupSafe==2.1.5
118 | matplotlib==3.10.1
119 | matplotlib-inline==0.1.7
120 | mccabe==0.7.0
121 | mistral_common==1.5.4
122 | mistune==3.1.3
123 | modelscope==1.25.0
124 | mpmath==1.3.0
125 | msgpack==1.1.0
126 | msgspec==0.19.0
127 | multidict==6.1.0
128 | multiprocess==0.70.16
129 | mypy_extensions==1.1.0
130 | nbclient==0.10.2
131 | nbconvert==7.16.6
132 | nbformat==5.10.4
133 | nest-asyncio==1.6.0
134 | networkx==3.3
135 | nh3==0.2.21
136 | ninja==1.11.1.3
137 | notebook==7.3.3
138 | notebook_shim==0.2.4
139 | numpy==1.26.4
140 | nvidia-cublas-cu12==12.4.5.8
141 | nvidia-cuda-cupti-cu12==12.4.127
142 | nvidia-cuda-nvrtc-cu12==12.4.127
143 | nvidia-cuda-runtime-cu12==12.4.127
144 | nvidia-cudnn-cu12==9.1.0.70
145 | nvidia-cufft-cu12==11.2.1.3
146 | nvidia-curand-cu12==10.3.5.147
147 | nvidia-cusolver-cu12==11.6.1.9
148 | nvidia-cusparse-cu12==12.3.1.170
149 | nvidia-ml-py==12.570.86
150 | nvidia-nccl-cu12==2.21.5
151 | nvidia-nvjitlink-cu12==12.4.127
152 | nvidia-nvtx-cu12==12.4.127
153 | nvitop==1.4.2
154 | openai==1.75.0
155 | opencensus==0.11.4
156 | opencensus-context==0.1.3
157 | opencv-python-headless==4.11.0.86
158 | openpyxl==3.1.5
159 | openrlhf==0.6.1.post1
160 | optimum==1.24.0
161 | outlines==0.1.11
162 | outlines_core==0.1.26
163 | overrides==7.7.0
164 | packaging==24.2
165 | pandas==2.2.3
166 | pandocfilters==1.5.1
167 | parso==0.8.4
168 | partial-json-parser==0.2.1.1.post5
169 | pathspec==0.12.1
170 | pdfminer.six==20250327
171 | pdfplumber==0.11.6
172 | peft==0.14.0
173 | pexpect==4.9.0
174 | pillow==11.0.0
175 | platformdirs==4.3.6
176 | pluggy==1.5.0
177 | prometheus-fastapi-instrumentator==7.1.0
178 | prometheus_client==0.21.1
179 | prompt_toolkit==3.0.50
180 | propcache==0.3.0
181 | proto-plus==1.26.1
182 | protobuf==5.29.3
183 | psutil==7.1.2
184 | ptyprocess==0.7.0
185 | pure_eval==0.2.3
186 | py-cpuinfo==9.0.0
187 | py-spy==0.4.0
188 | pyarrow==19.0.1
189 | pyasn1==0.6.1
190 | pyasn1_modules==0.4.1
191 | pybtex==0.24.0
192 | pycodestyle==2.13.0
193 | pycountry==24.6.1
194 | pycparser==2.22
195 | pydantic==2.10.6
196 | pydantic_core==2.27.2
197 | pyflakes==3.3.2
198 | Pygments==2.19.1
199 | pylatexenc==2.10
200 | pynvml==12.0.0
201 | pyparsing==3.2.3
202 | pypdfium2==4.30.1
203 | pytest==8.3.5
204 | python-dateutil==2.9.0.post0
205 | python-dotenv==1.1.0
206 | python-json-logger==3.3.0
207 | pytz==2025.1
208 | PyYAML==6.0.2
209 | pyzmq==26.4.0
210 | ray==2.42.0
211 | rebiber==1.1.3
212 | referencing==0.36.2
213 | regex==2024.11.6
214 | requests==2.32.3
215 | rewardbench==0.1.3
216 | rfc3339-validator==0.1.4
217 | rfc3986-validator==0.1.1
218 | rich==14.0.0
219 | rpds-py==0.23.1
220 | rsa==4.9
221 | safetensors==0.5.3
222 | scipy==1.15.3
223 | seaborn==0.13.2
224 | Send2Trash==1.8.3
225 | sentencepiece==0.2.0
226 | sentry-sdk==2.22.0
227 | setproctitle==1.3.5
228 | six==1.17.0
229 | smart-open==7.1.0
230 | smmap==5.0.2
231 | sniffio==1.3.1
232 | soupsieve==2.6
233 | stack-data==0.6.3
234 | starlette==0.46.2
235 | sympy==1.13.1
236 | tabulate==0.9.0
237 | tensorboard==2.19.0
238 | tensorboard-data-server==0.7.2
239 | termcolor==3.1.0
240 | terminado==0.18.1
241 | tiktoken==0.6.0
242 | tinycss2==1.4.0
243 | tokenizers==0.21.1
244 | tomli==2.2.1
245 | torch==2.5.1+cu124
246 | torchaudio==2.5.1+cu124
247 | torchmetrics==1.6.3
248 | torchvision==0.20.1+cu124
249 | tornado==6.4.2
250 | tqdm==4.67.1
251 | traitlets==5.14.3
252 | transformers==4.48.3
253 | transformers-stream-generator==0.0.5
254 | triton==3.1.0
255 | trl==0.12.2
256 | tsv==1.2
257 | types-python-dateutil==2.9.0.20241206
258 | typing_extensions==4.12.2
259 | tzdata==2025.1
260 | Unidecode==1.4.0
261 | uri-template==1.3.0
262 | urllib3==2.3.0
263 | uvicorn==0.34.2
264 | uvloop==0.21.0
265 | virtualenv==20.29.3
266 | vllm==0.7.2
267 | wandb==0.19.8
268 | watchfiles==1.0.5
269 | wcwidth==0.2.13
270 | webcolors==24.11.1
271 | webencodings==0.5.1
272 | websocket-client==1.8.0
273 | websockets==15.0.1
274 | Werkzeug==3.1.3
275 | widgetsnbextension==4.0.13
276 | wrapt==1.17.2
277 | xformers==0.0.28.post3
278 | xgrammar==0.1.18
279 | xopen==2.0.2
280 | xxhash==3.5.0
281 | yarl==1.18.3
282 | zipp==3.21.0
283 | zlib-ng==0.5.1
284 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # POSITION BIAS MITIGATES POSITION BIAS: Mitigate Position Bias Through Inter-Position Knowledge Distillation
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | We propose **Pos2Distill**, a novel position-to-position knowledge distillation framework that transfers knowledge from advantageous positions to rectify responses at unfavorable ones, thereby mitigating position bias naturally.
13 | 
14 |
15 | Table of Contents
16 |
17 | - Installation
18 | - Data Preparation
19 | - Training
20 | - Evaluation
21 | - Citation
22 | - Acknowledgement
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 | ## News
39 | - [2025-11-03] We have released our test code, data, and model weights for Pos2Distill.
40 |
41 | ## Installation
42 | Create a conda environment and install the required packages:
43 | ```bash
44 | conda create -n pos2distill python=3.10
45 | conda activate pos2distill
46 |
47 | git clone https://github.com/AMAP-ML/Pos2Distill.git
48 | cd Pos2Distill
49 | pip install -r requirements.txt
50 | ```
51 | Our knowledge distillation framework is based on [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF), an open-source RLHF framework.
52 | ```bash
53 | pip install openrlhf==0.6.1.post1
54 | ```
55 | ## Data Preparation
56 | We provide the NaturalQuestions (20 docs) data in [Google Drive]():
57 | - First, generate high-quality responses from advantageous positions via **gen_advantageous_pos_for_R1.sh**; the resulting responses are saved in the raw_data/$dataset_name directory.
58 | - Second, construct advantageous-trivial position pairs for knowledge distillation in the training stage:
59 |     - example_num: how many different datapoints are used;
60 |     - position_sample_num: how many trivial positions are sampled per datapoint.
61 | ```bash
62 | model_name=Mistral-7B-Instruct-v0.3
63 | dataset_name=nq
64 | python create_datasets.py \
65 | --model_name $model_name \
66 | --position_sample_num 4 \
67 | --example_num 400 \
68 | --dataset_name $dataset_name
69 | ```
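With these settings, each filtered example expands into position_sample_num + strengthen = 5 training points (4 sampled positions plus 1 strengthened copy at position 0), so 400 examples yield 2000 datapoints. A minimal sketch of the output path, mirroring `generate_dataset` in create_datasets.py:
```python
# Reconstruction of the dataset path written by create_datasets.py.
model_name = "Mistral-7B-Instruct-v0.3"
dataset_name, total_doc = "nq", 20
position_sample_num, strengthen, train_size = 4, 1, 400
print(f"datasets/{dataset_name}/{model_name}_{total_doc}total_docs_filter_"
      f"{position_sample_num}random_{strengthen}strengthen_{train_size}")
# -> datasets/nq/Mistral-7B-Instruct-v0.3_20total_docs_filter_4random_1strengthen_400
```
This is exactly the name passed as `--dataset` in train_nq.sh.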
70 |
71 | ## Training
72 | We train our model through the script train_nq.sh. kd_coef controls which training approach is used: the four comma-separated values weight the vanilla KD loss, the LM loss, the rank loss, and the adaptive position-to-position KD loss, respectively (see kd/customized_kd_trainer.py), so "0.0,0.0,0.0,1.0" trains models with pure Pos2Distill.
73 | ```bash
74 | kd_coef="0.0,0.0,0.0,1.0"
75 | deepspeed --master_port 6666 \
76 | kd/train_kd.py \
77 | --max_len 32000 \
78 |     --dataset ${model_name}_${total_docs}total_docs_filter_${K}random_${strengthen}strengthen_${num} \
79 | --input_key question \
80 | --output_key oracle_answer \
81 | --train_batch_size 32 \
82 | --micro_train_batch_size 4 \
83 | --max_samples 500000 \
84 | --pretrain $model_name \
85 | --teacher_model $model_name \
86 |     --save_path checkpoints/${dataset_name} \
87 | --save_steps -1 \
88 | --logging_steps 1 \
89 | --eval_steps -1 \
90 | --zero_stage 3 \
91 | --max_epochs 1 \
92 | --l2 0.01 \
93 | --bf16 \
94 | --flash_attn \
95 | --kd_coef $kd_coef \
96 | --learning_rate 3e-6 \
97 | --teacher_offload \
98 | --apply_chat_template \
99 | --gradient_checkpointing \
100 | --perserve 1.0 \
101 | --dataset_name $dataset_name \
102 | --use_tensorboard tensorBoard
103 | ```
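For clarity, a minimal sketch of how the trainer combines the four losses (a simplification of the `fit` loop in kd/customized_kd_trainer.py; the tensor values here are placeholders for the real per-batch losses):
```python
import torch

kd_coef = [0.0, 0.0, 0.0, 1.0]  # kd, lm, rank, adaptive (pure Pos2Distill)

distil_loss = torch.tensor(0.0)       # MY_KDLoss: vanilla KD against the teacher logits
gpt_loss = torch.tensor(0.0)          # MY_GPTLMLoss: language-modeling cross-entropy
rank_loss = torch.tensor(0.0)         # MY_topkLoss: top-k ranking consistency
adaptive_kd_loss = torch.tensor(0.0)  # AdaptiveKLWeightedKLLoss: position-to-position KD

loss = (distil_loss * kd_coef[0] + gpt_loss * kd_coef[1]
        + rank_loss * kd_coef[2] + adaptive_kd_loss * kd_coef[3])
```
Losses whose coefficient is zero are skipped entirely in the trainer, so this configuration optimizes only the adaptive position-to-position term.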
104 | ## Evaluation
105 |
106 | 1. Currently, the model weights are saved in the checkpoints directory. You can run the test code (eval_data.sh) using the command below to evaluate the
107 | performance of Pos2Distill on the NaturalQuestions dataset (the outputs of each run are sketched after the commands below).
108 | ```bash
109 | total_docs=20
110 | num_gpus=8
111 | INPUT_PATH="nq-open-${total_docs}_total_documents.jsonl.gz"
112 | random_num=4
113 | strengthen=1
114 | model_name=Mistral-7B-Instruct-v0.3_20total_docs_filter_4random_1strengthen_400_kd0.0_lm0.0_rank0.0_adaptive1.0_1.0
115 | python eval_data.py \
116 | --input-path "$INPUT_PATH" \
117 | --model $model_name \
118 | --output-path evaluate \
119 | --sample_num 500 \
120 | --max-prompt-length 32768 \
121 | --max-new-tokens 100 \
122 | --num_gpus "$num_gpus" \
123 | --total_doc "$total_docs"
124 | ```
125 | 2. If you want to evaluate Pos2Distill on the WebQ dataset (20 docs), run the code below:
126 | ```bash
127 | INPUT_PATH="webq_dev.jsonl.gz"
128 | python eval_data.py \
129 | --input-path "$INPUT_PATH" \
130 | --model $model_name \
131 | --output-path evaluate \
132 | --sample_num 500 \
133 | --max-prompt-length 32768 \
134 | --max-new-tokens 100 \
135 | --num_gpus "$num_gpus" \
136 | --total_doc "$total_docs"
137 | ```
138 | 3. If you want to evaluate Pos2Distill on the TQA dataset (20 docs), run the code below:
139 | ```bash
140 | INPUT_PATH="tqa_dev.jsonl.gz"
141 | python eval_data.py \
142 | --input-path "$INPUT_PATH" \
143 | --model $model_name \
144 | --output-path evaluate \
145 | --sample_num 500 \
146 | --max-prompt-length 32768 \
147 | --max-new-tokens 100 \
148 | --num_gpus "$num_gpus" \
149 | --total_doc "$total_docs"
150 | ```
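In every run above, eval_data.py sweeps the gold document across context positions and reports per-position accuracy; a minimal sketch of that loop and its outputs (following `main` in eval_data.py):
```python
# The gold document is inserted at positions 0, 5, 10, 15, 20 (for 20 docs).
# Responses and per-example best_subspan_em scores for each position go to
# evaluate/{model_name}_{i}.jsonl.gz, and the per-position means are appended
# as one row of {dataset_name}_{total_doc}.xlsx.
total_doc = 20
for i in range(0, total_doc + 1, 5):
    print(f"evaluating with the gold document at position {i}")
```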
151 |
152 | ## Acknowledgement
153 | We would like to thank the following works for their code and models:
154 | - Training: [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF), [TRL](https://huggingface.co/docs/trl/sft_trainer)
155 | - Datasets: [NaturalQuestions](https://github.com/nelson-liu/lost-in-the-middle), [TQA](https://huggingface.co/datasets/vsearch/tqa), [WEBQA](https://huggingface.co/datasets/vsearch/webq)
156 |
157 | We are extremely grateful to **Linjing Li**, **Yong Wang**, **Xiangxiang Chu** and many other friends in our Machine Learning Group for their helpful feedback and insightful discussions.
158 |
159 | ## Citation
160 | If you find this project helpful, please consider citing our report :blush:. If you have any questions, please contact [Yifei Wang](https://scholar.google.com/citations?hl=zh-CN&user=AHD4c24AAAAJ).
161 | ```bibtex
162 | @article{wang2025position,
163 | title={Position Bias Mitigates Position Bias: Mitigate Position Bias Through Inter-Position Knowledge Distillation},
164 | author={Wang, Yifei and Xiong, Feng and Wang, Yong and Li, Linjing and Chu, Xiangxiang and Zeng, Daniel Dajun},
165 | journal={arXiv preprint arXiv:2508.15709},
166 | year={2025}
167 | }
168 | ```
169 |
--------------------------------------------------------------------------------
/eval_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import logging
5 | import pathlib
6 | import random
7 | import sys
8 | from copy import deepcopy
9 | import torch
10 | from tqdm import tqdm
11 | from transformers import AutoTokenizer,LlamaTokenizer
12 | from xopen import xopen
13 | from vllm import LLM, SamplingParams
14 | import pandas as pd
15 | import random
16 | import numpy as np
17 | from utils import *
18 | from transformers import AutoModelForCausalLM, AutoTokenizer
19 | from gen_advantageous_pos_for_R1 import seed_everything,get_prompt_at_gold_index
20 | cache_dir="/mnt/workspace/wangyifei/projects/huggingcache" # use your cache dir of model weights
21 | seed = random.randint(0, 2**32 - 1)
22 | seed_everything(seed)
23 | def write_responses(output_path,examples, responses,prompts):
24 | with xopen(output_path, "w") as f:
25 | for (example,response,prompt) in zip(examples,responses,prompts):
26 | is_correct = int(best_subspan_em(response.split("\n")[0],example["answers"]))
27 | example["current_answer"] = response
28 | example["current_prompt"] = prompt
29 | f.write(json.dumps(example) + "\n")
30 | average_metric_value = evaluate_qa_data(input_path=output_path)
31 | print(f"average_metric_value: {average_metric_value}")
32 | return average_metric_value
33 |
34 |
35 | def main(
36 | input_path,
37 | model_name,
38 | temperature,
39 | top_p,
40 | num_gpus,
41 | max_new_tokens,
42 | max_prompt_length,
43 | output_path,
44 | sample_num,
45 | gold_doc,
46 | total_doc,
47 | dataset_name
48 |
49 | ):
50 |
51 | # os.makedirs(f"{output_path}/{total_doc}",exist_ok=True)
52 | os.makedirs(output_path,exist_ok=True)
53 | logger.info(f"model name {model_name}")
54 | if model_name.lower().endswith("instruct") or model_name.lower().endswith("chat") or model_name.lower().endswith("v0.3") or model_name.lower().endswith("7b") or model_name.lower().endswith("lite") or model_name.lower().endswith("hf"):
55 | model_path = os.path.join(cache_dir,model_name)
56 | else:
57 | model_path = os.path.join(f"checkpoints/nq",model_name)
58 | model = LLM(
59 | model=model_path,
60 | tensor_parallel_size=num_gpus,
61 | # load_format="safetensors",
62 | # load_format="pt",
63 | trust_remote_code=True,
64 | max_num_batched_tokens=max_prompt_length,
65 | )
66 |     sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens, top_p=top_p, seed=seed)  # greedy decoding for evaluation; the CLI temperature is not used here
67 |
68 | results_average = {"Model_name": model_name}
69 | results_std = {"Model_name": model_name}
70 | for i in range(0,total_doc+1,5):
71 | output_path_gold_index = os.path.join(output_path, f"{model_name}_{i}.jsonl.gz")
72 | average_iterations = 1
73 | average_metric_values = []
74 | for j in range(average_iterations):
75 | prompts, examples = get_prompt_at_gold_index(input_path, i, model_name, total_doc, model.get_tokenizer())
76 | prompts = prompts[-sample_num:]
77 | examples = examples[-sample_num:]
78 | if i==0:
79 | # compute average tokens
80 | stats = []
81 | for prompt in prompts:
82 | tok=model.get_tokenizer()
83 | tok.pad_token = tok.eos_token
84 | prompt_len = tok(prompt,padding=True,add_special_tokens=False,return_tensors="pt")["input_ids"].size()[1]
85 | stats.append(prompt_len)
86 | logger.info(f"tokens stats: max={max(stats)} min={min(stats)} avg={np.mean(stats)}")
87 | sampling_params.seed+=1
88 | raw_responses = model.generate(prompts, sampling_params)
89 | responses = [output.outputs[0].text.strip() for output in raw_responses]
90 |
91 |
92 | average_metric_value = write_responses(output_path_gold_index, examples, responses,prompts)
93 | average_metric_values.append(average_metric_value)
94 |         # record the mean and standard deviation over the averaging iterations
95 |         results_average[i] = np.mean(average_metric_values)
96 |         results_std[i] = np.std(average_metric_values)
97 |     # convert the results to a DataFrame
98 | df_avg = pd.DataFrame([results_average])
99 | file_path = f"{dataset_name}_{total_doc}.xlsx"
100 | if os.path.exists(file_path):
101 | existing_df = pd.read_excel(file_path, engine='openpyxl')
102 | combined_df = pd.concat([existing_df, df_avg], ignore_index=True)
103 | else:
104 | combined_df = pd.concat([df_avg], ignore_index=True)
105 | combined_df.to_excel(file_path, index=False, engine='openpyxl')
106 | print(f"Results saved to {file_path}")
107 | return
108 |
109 |
110 |
111 |
112 | if __name__ == "__main__":
113 | logging.basicConfig(format="%(asctime)s - %(module)s - %(levelname)s - %(message)s", level=logging.INFO)
114 | parser = argparse.ArgumentParser()
115 | parser.add_argument("--input-path", help="Path to data with questions and documents to use.", required=True)
116 | parser.add_argument(
117 | "--model",
118 | help="Model to use in generating responses",
119 | required=True,
120 | )
121 | parser.add_argument("--temperature", help="Temperature to use in generation", type=float, default=0.6)
122 | parser.add_argument("--top-p", help="Top-p to use in generation", type=float, default=1.0)
123 | parser.add_argument("--num-gpus", help="Number of GPUs to use", type=int, default=1)
124 | parser.add_argument("--output-path", help="Path to write output file of generated responses", required=True)
125 | parser.add_argument(
126 | "--max-new-tokens",
127 | help="Maximum number of new tokens to generate",
128 | type=int,
129 | default=100,
130 | )
131 | parser.add_argument(
132 | "--max-prompt-length",
133 | help="Maximum number of tokens in the prompt. Longer prompts will be skipped.",
134 | type=int,
135 | default=4096,
136 | )
137 | parser.add_argument(
138 | "--sample_num",
139 | help="sample size",
140 | type=int,
141 | default=500,
142 | )
143 | parser.add_argument(
144 | "--top_p",
145 | help="top_p",
146 | type=float,
147 | default=1.0,
148 | )
149 | parser.add_argument(
150 | "--gold_doc",
151 | type=int,
152 | default=0,
153 | )
154 | parser.add_argument(
155 | "--num_gpus",
156 | type=int,
157 | default=0,
158 | )
159 | parser.add_argument(
160 | "--total_doc",
161 | type=int,
162 | )
163 | parser.add_argument(
164 | "--dataset_name",
165 | type=str,
166 | )
167 |
168 | args = parser.parse_args()
169 |
170 | logger.info("running %s", " ".join(sys.argv))
171 | main(
172 | args.input_path,
173 | args.model,
174 | args.temperature,
175 | args.top_p,
176 | args.num_gpus,
177 | args.max_new_tokens,
178 | args.max_prompt_length,
179 | args.output_path,
180 | args.sample_num,
181 | args.gold_doc,
182 | args.total_doc,
183 | args.dataset_name
184 |
185 | )
186 | logger.info("finished running %s", sys.argv[0])
187 |
--------------------------------------------------------------------------------
/kd/customized_kd_trainer.py:
--------------------------------------------------------------------------------
1 | from openrlhf.trainer import KDTrainer
2 | from typing import Optional, Tuple
3 | from tqdm import tqdm
4 | import torch
5 | import torch.distributed as dist
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | from abc import ABC
10 | import torch
11 |
12 | from torch.optim import Optimizer
13 | from tqdm import tqdm
14 | from openrlhf.models import GPTLMLoss, KDLoss
15 | from openrlhf.utils.distributed_sampler import DistributedSampler
16 | import logging
17 | from typing import Optional, Tuple
18 | import torch
19 | import torch.distributed as dist
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from loss_design import AdaptiveKLWeightedKLLoss,MY_KDLoss,MY_rankLoss,MY_topkLoss,MY_GPTLMLoss,AdaptiveKLWeightedKLLoss_multi
23 | logger = logging.getLogger('customized kd trainer')
24 |
25 | class MY_KDTrainer(KDTrainer):
26 | def __init__(self, *args, **kwargs):
27 | attention_dillution_coef = kwargs.pop("attention_dillution_coef", 1.0)
28 | multi_gold_doc = kwargs.pop("multi_gold_doc", False)
29 | super().__init__(*args, **kwargs)
30 | self.kd_loss = MY_KDLoss()
31 | self.rank_loss = MY_topkLoss()
32 | self.loss_fn = MY_GPTLMLoss()
33 | self.adaptive_kd_loss = AdaptiveKLWeightedKLLoss(batch_norm=True,attention_dillution_coef = attention_dillution_coef) if not multi_gold_doc else AdaptiveKLWeightedKLLoss_multi(batch_norm=True,attention_dillution_coef = attention_dillution_coef)
34 | # self.loss_fn = MY_rankLoss()
35 |
36 | def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None):
37 | # get eval and save steps
38 | if args.eval_steps == -1:
39 | args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch
40 | if args.save_steps == -1:
41 | args.save_steps = float("inf") # do not save ckpt
42 |
43 | # Restore step and start_epoch
44 | step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1
45 | start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch
46 | consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size)
47 |
48 | epoch_bar = tqdm(
49 | range(start_epoch, self.epochs),
50 | desc="Train epoch",
51 | disable=not self.strategy.is_rank_0(),
52 | )
53 | loss_sum = 0
54 | for epoch in range(start_epoch, self.epochs):
55 | if isinstance(self.train_dataloader.sampler, DistributedSampler):
56 | self.train_dataloader.sampler.set_epoch(
57 | epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples
58 | )
59 | step_bar = tqdm(
60 | range(self.train_dataloader.__len__()),
61 | desc="Train step of epoch %d" % epoch,
62 | disable=not self.strategy.is_rank_0(),
63 | )
64 | # train
65 | self.model.train()
66 | self.teacher_model.eval()
67 |
68 | for teacher_prompt_id_lens,teacher_inputs,teacher_attention_masks, student_prompt_id_lens,student_inputs,student_attention_masks,infos in self.train_dataloader:
69 | student_inputs = student_inputs.squeeze(1).to(torch.cuda.current_device())
70 | student_attention_mask = student_attention_masks.squeeze(1).to(torch.cuda.current_device())
71 | student_output = self.model(student_inputs, attention_mask=student_attention_mask, return_output=True)
72 | student_labels = torch.where(
73 | student_attention_mask.bool(),
74 | student_inputs,
75 | self.loss_fn.IGNORE_INDEX,
76 | )
77 | teacher_inputs = teacher_inputs.squeeze(1).to(torch.cuda.current_device())
78 | teacher_attention_mask = teacher_attention_masks.squeeze(1).to(torch.cuda.current_device())
79 | teacher_labels = torch.where(
80 | teacher_attention_mask.bool(),
81 | teacher_inputs,
82 | self.loss_fn.IGNORE_INDEX,
83 | )
84 | if not self.pretrain_mode:
85 | for label, source_len in zip(student_labels, student_prompt_id_lens):
86 | label[:source_len] = self.loss_fn.IGNORE_INDEX
87 |
88 | for label, source_len in zip(teacher_labels, teacher_prompt_id_lens):
89 | label[:source_len] = self.loss_fn.IGNORE_INDEX
90 | if args.kd_coef[1] > 0:
91 | gpt_loss = self.loss_fn(student_output.logits, student_labels)
92 | else:
93 | gpt_loss = torch.tensor(0.0).to(student_output.logits.device)
94 |
95 | with torch.no_grad():
96 | teacher_logits = self.teacher_model(teacher_inputs, attention_mask=teacher_attention_mask, return_output=True)[
97 | "logits"
98 | ]
99 | if args.kd_coef[0] > 0:
100 | distil_loss = self.kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels)
101 | else:
102 | distil_loss = torch.tensor(0.0).to(teacher_logits.device)
103 |
104 | if args.kd_coef[2] > 0:
105 | rank_loss = self.rank_loss(student_output.logits, teacher_logits, student_labels,teacher_labels)
106 | else:
107 | rank_loss = torch.tensor(0.0).to(teacher_logits.device)
108 |
109 | if args.kd_coef[3] > 0:
110 | adaptive_kd_loss= self.adaptive_kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels,infos["gold_idx"])
111 | else:
112 | adaptive_kd_loss = torch.tensor(0.0).to(teacher_logits.device)
113 |                     # weighted combination of the four losses
114 |
115 |                     loss = distil_loss * args.kd_coef[0] + gpt_loss * args.kd_coef[1] + rank_loss * args.kd_coef[2] + adaptive_kd_loss * args.kd_coef[3]
116 |
117 | self.strategy.backward(loss, self.model, self.optimizer)
118 | self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler)
119 |
120 | loss_sum += loss.item()
121 | logs_dict = {
122 | "gpt_loss": gpt_loss.item(),
123 | "distil_loss": distil_loss.item(),
124 | "rank_loss": rank_loss.item(),
125 | "adaptive_kd_loss": adaptive_kd_loss.item(),
126 | # "lr": self.scheduler.get_last_lr()[0],
127 | }
128 | # step bar
129 | logs_dict = self.strategy.all_reduce(logs_dict)
130 | step_bar.set_postfix(logs_dict)
131 | step_bar.update()
132 |
133 | # logs/checkpoints/evaluation
134 | if step % self.strategy.accumulated_gradient == 0:
135 | logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient
136 | loss_sum = 0
137 | global_step = step // self.strategy.accumulated_gradient
138 | client_states = {"consumed_samples": global_step * args.train_batch_size}
139 | self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states)
140 |
141 | step += 1
142 |
143 | epoch_bar.update()
144 |
145 | if self._wandb is not None and self.strategy.is_rank_0():
146 | self._wandb.finish()
147 | if self._tensorboard is not None and self.strategy.is_rank_0():
148 | self._tensorboard.close()
149 |
150 | # logs/checkpoints/evaluation
151 | def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}):
152 | if global_step % args.logging_steps == 0:
153 | # wandb
154 | if self._wandb is not None and self.strategy.is_rank_0():
155 | logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()}
156 | self._wandb.log(logs)
157 | # TensorBoard
158 | elif self._tensorboard is not None and self.strategy.is_rank_0():
159 | for k, v in logs_dict.items():
160 | self._tensorboard.add_scalar(f"train/{k}", v, global_step)
161 |
162 | # eval
163 | # if global_step % args.eval_steps == 0:
164 | # # do eval when len(dataloader) > 0, avoid zero division in eval.
165 | # if len(self.eval_dataloader) > 0:
166 | # self.evaluate(self.eval_dataloader, global_step)
167 | # save ckpt
168 | # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric
169 | # if global_step % args.save_steps == 0:
170 | # tag = f"global_step{global_step}"
171 | # self.strategy.save_ckpt(
172 | # self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states
173 | # )
174 |
175 | def evaluate(self, eval_dataloader, steps=0):
176 | times = 0
177 | self.model.eval()
178 | self.teacher_model.eval()
179 | with torch.no_grad():
180 | loss_sum = 0
181 | step_bar = tqdm(
182 | range(eval_dataloader.__len__()),
183 | desc="Eval stage of steps %d" % steps,
184 | disable=not self.strategy.is_rank_0(),
185 | )
186 | for teacher_prompt_id_lens,teacher_inputs,teacher_attention_masks, student_prompt_id_lens,student_inputs,student_attention_masks,infos in eval_dataloader:
187 | student_inputs = student_inputs.squeeze(1).to(torch.cuda.current_device())
188 | student_attention_mask = student_attention_masks.squeeze(1).to(torch.cuda.current_device())
189 | student_output = self.model(student_inputs, attention_mask=student_attention_mask, return_output=True)
190 | student_labels = torch.where(
191 | student_attention_mask.bool(),
192 | student_inputs,
193 | self.loss_fn.IGNORE_INDEX,
194 | )
195 | teacher_inputs = teacher_inputs.squeeze(1).to(torch.cuda.current_device())
196 | teacher_attention_mask = teacher_attention_masks.squeeze(1).to(torch.cuda.current_device())
197 | teacher_labels = torch.where(
198 | teacher_attention_mask.bool(),
199 | teacher_inputs,
200 | self.loss_fn.IGNORE_INDEX,
201 | )
202 | if not self.pretrain_mode:
203 | for label, source_len in zip(student_labels, student_prompt_id_lens):
204 | label[:source_len] = self.loss_fn.IGNORE_INDEX
205 |
206 | for label, source_len in zip(teacher_labels, teacher_prompt_id_lens):
207 | label[:source_len] = self.loss_fn.IGNORE_INDEX
208 |                 if self.args.kd_coef[1] > 0:  # 'args' is not in scope inside evaluate(), so use self.args
209 | gpt_loss = self.loss_fn(student_output.logits, student_labels)
210 | else:
211 | gpt_loss = torch.tensor(0.0).to(student_output.logits.device)
212 |
213 | with torch.no_grad():
214 | teacher_logits = self.teacher_model(teacher_inputs, attention_mask=teacher_attention_mask, return_output=True)[
215 | "logits"
216 | ]
217 |                 if self.args.kd_coef[0] > 0:
218 | distil_loss = self.kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels)
219 | else:
220 | distil_loss = torch.tensor(0.0).to(teacher_logits.device)
221 |
222 |                 if self.args.kd_coef[2] > 0:
223 | rank_loss = self.rank_loss(student_output.logits, teacher_logits, student_labels,teacher_labels)
224 | else:
225 | rank_loss = torch.tensor(0.0).to(teacher_logits.device)
226 |
227 |                 if self.args.kd_coef[3] > 0:
228 | adaptive_kd_loss = self.adaptive_kd_loss(student_output.logits, teacher_logits, student_labels,teacher_labels,infos["gold_idx"])
229 | else:
230 | adaptive_kd_loss = torch.tensor(0.0).to(teacher_logits.device)
231 |                 # weighted combination of the four losses
232 |
233 |                 loss = distil_loss * self.args.kd_coef[0] + gpt_loss * self.args.kd_coef[1] + rank_loss * self.args.kd_coef[2] + adaptive_kd_loss * self.args.kd_coef[3]
234 |
235 |
236 |
237 |
238 | times += 1
239 | loss_sum += loss.item()
240 | bar_dict = {"eval loss": loss_sum / times}
241 | step_bar.update()
242 | logs = self.strategy.all_reduce(bar_dict)
243 | step_bar.set_postfix(logs)
244 |
245 | if self.strategy.is_rank_0():
246 | if self._wandb is not None:
247 | logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()}
248 | self._wandb.log(logs)
249 | elif self._tensorboard is not None:
250 | for k, v in logs.items():
251 | self._tensorboard.add_scalar(f"eval/{k}", v, steps)
252 |
253 | self.model.train() # reset model state
254 |
--------------------------------------------------------------------------------
/kd/train_kd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | from datetime import datetime
5 | import logging
6 | from transformers.trainer import get_scheduler
7 | import torch
8 | from customized_sft_dataset import SFTDataset
9 | from openrlhf.models import Actor
10 | # from openrlhf.trainer import KDTrainer
11 | from customized_kd_trainer import MY_KDTrainer
12 | from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer
13 | # from openrlhf.utils.deepspeed import DeepspeedStrategy
14 | from torch import distributed as dist
15 | # from torch.utils.data import DataLoader
16 | # from openrlhf.utils.distributed_sampler import DistributedSampler
17 | import numpy as np
18 | import random
19 | import transformers
20 |
21 | def seed_everything(seed):
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | np.random.seed(seed)
25 | random.seed(seed)
26 | torch.backends.cudnn.benchmark = False
27 | torch.backends.cudnn.deterministic = True
28 | torch.cuda.manual_seed_all(seed)
29 |
30 | seed_everything(42)
31 |
32 |
33 |
34 |
35 | logging.basicConfig(
36 |     level=logging.INFO, # log level: DEBUG, INFO, WARNING, ERROR, CRITICAL
37 | format='%(asctime)s - %(levelname)s - %(message)s',
38 | datefmt='%Y-%m-%d %H:%M:%S'
39 | )
40 | logger = logging.getLogger('kd_training')
41 | def train(args):
42 |     cache_dir = "/mnt/workspace/wangyifei/projects/huggingcache"  # use your cache dir of model weights
43 |     dataset_dir = "/mnt/workspace/wangyifei/projects/self_training/datasets"  # use your datasets dir
44 | if args.multi_gold_doc:
45 | assert args.dataset_name in ["musique"]
46 | else:
47 | assert args.dataset_name in ["nq","tqa","webq"]
48 | # configure strategy
49 | strategy = get_strategy(args)
50 | strategy.setup_distributed()
51 |
52 | # configure model
53 | # load huggingface model
54 | logger.info(f"Loading student model from {os.path.join(cache_dir, args.pretrain)}")
55 | model = Actor(
56 | pretrain_or_model = os.path.join(cache_dir, args.pretrain),
57 | # pretrain_or_model="/mnt/workspace/wangyifei/projects/self_training/kl_checkpoints/Qwen1.5-7B-Chat_20total_docs_filter_5random_2strengthen_400_kl_0.8",
58 | use_flash_attention_2=args.flash_attn,
59 | bf16=args.bf16,
60 | load_in_4bit=args.load_in_4bit,
61 | lora_rank=args.lora_rank,
62 | lora_alpha=args.lora_alpha,
63 | target_modules=args.target_modules,
64 | lora_dropout=args.lora_dropout,
65 | ds_config=strategy.get_ds_train_config(is_actor=True),
66 | )
67 | logger.info(f"Loading teacher model from {os.path.join(cache_dir, args.teacher_model)}")
68 | # load teacher model for inference
69 | teacher_model = Actor(
70 | os.path.join(cache_dir, args.teacher_model),
71 | use_flash_attention_2=args.flash_attn,
72 | bf16=args.bf16,
73 | load_in_4bit=args.load_in_4bit,
74 | ds_config=strategy.get_ds_eval_config(offload=args.teacher_offload),
75 | )
76 | if args.teacher_offload:
77 | teacher_model._offload = True
78 |
79 |
80 | # configure tokenizer
81 | logger.info(f"Loading tokenizer from {os.path.join(cache_dir, args.pretrain)}")
82 | tokenizer = get_tokenizer(os.path.join(cache_dir, args.pretrain), model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer)
83 |
84 | strategy.print(model)
85 |
86 |
87 | # configure optimizer
88 |     logger.info("Creating optimizer...")
89 | optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2)
90 |
91 |     logger.info("Preparing data and dataset...")
92 | # prepare for data and dataset
93 | logger.info(f"args.dataset_probs: {args.dataset_probs}")
94 | logger.info(f"args.dataset: {args.dataset}")
95 | dataset_dir = os.path.join(dataset_dir, args.dataset_name)
96 | train_data = blending_datasets(
97 | os.path.join(dataset_dir, args.dataset),
98 | args.dataset_probs,
99 | strategy,
100 | args.seed,
101 | return_eval=False,
102 | max_count=args.max_samples,
103 | train_split=args.train_split,
104 | eval_split=args.eval_split,
105 | )
106 |
107 | # eval_dataset_name = args.dataset.split("_unfilter")[0] + "_dev"
108 | # logger.info(f"Prepare for eval data and dataset....")
109 | # logger.info(f"eval dataset name: {eval_dataset_name}")
110 | # eval_data = blending_datasets(
111 | # os.path.join(dataset_dir, eval_dataset_name),
112 | # args.dataset_probs,
113 | # strategy,
114 | # args.seed,
115 | # return_eval=False,
116 | # max_count=args.max_samples,
117 | # train_split=args.train_split,
118 | # eval_split=args.eval_split,
119 | # )
120 |
121 | train_data = train_data.select(range(min(args.max_samples, len(train_data))))
122 | train_data = train_data.shuffle(seed=args.seed)
123 | # eval_data = eval_data.select(range(min(args.max_samples, len(eval_data))))
124 |
125 | train_dataset = SFTDataset(
126 | train_data,
127 | tokenizer,
128 | args.max_len,
129 | strategy,
130 | pretrain_mode=args.pretrain_mode,
131 | input_template=args.input_template,
132 | multi_gold_doc=args.multi_gold_doc,
133 | num_processors=1,
134 | )
135 |
136 | # eval_dataset = SFTDataset(
137 | # eval_data,
138 | # tokenizer,
139 | # args.max_len,
140 | # strategy,
141 | # pretrain_mode=args.pretrain_mode,
142 | # input_template=args.input_template,
143 | # )
144 |
145 | train_dataloader = strategy.setup_dataloader(
146 | replay_buffer = train_dataset,
147 | batch_size = args.micro_train_batch_size,
148 | pin_memory = True,
149 | shuffle = True,
150 | collate_fn = train_dataset.collate_fn,
151 | # use_block_split=True ,
152 | # block_size= args.micro_train_batch_size
153 | )
154 |
155 | num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size
156 | max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
157 |
158 | scheduler = get_scheduler(
159 | args.lr_scheduler,
160 | optim,
161 | num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
162 | num_training_steps=max_steps,
163 | scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1},
164 | )
165 |
166 | # gradient_checkpointing
167 | if args.gradient_checkpointing:
168 | model.gradient_checkpointing_enable(
169 | gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
170 | )
171 |
172 | # prepare models
173 | ((model, optim, scheduler), teacher_model) = strategy.prepare((model, optim, scheduler), teacher_model)
174 |
175 | # load checkpoint
176 | args.kd_coef = [float(x) for x in args.kd_coef.split(',')]
177 | consumed_samples = 0
178 | if args.load_checkpoint and os.path.exists(args.ckpt_path):
179 | _, states = strategy.load_ckpt(model.model, args.ckpt_path)
180 | consumed_samples = states["consumed_samples"]
181 | strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}")
182 | if args.output_key == "oracle_answer":
183 | checkpoint_name = args.dataset + f"_kd{str(args.kd_coef[0])}_lm{str(args.kd_coef[1])}_rank{str(args.kd_coef[2])}_adaptive{str(args.kd_coef[3])}_{args.perserve}"
184 | # args.save_path = os.path.join(args.save_path, args.dataset + f"_kd{str(args.kd_coef[0])}_lm{str(args.kd_coef[1])}_rank{str(args.kd_coef[2])}_adaptive{str(args.kd_coef[3])}")
185 | strategy.args.wandb_run_name = args.dataset + f"_kd{str(args.kd_coef[0])}_lm{str(args.kd_coef[1])}_rank{str(args.kd_coef[2])}_adaptive{str(args.kd_coef[3])}_{args.perserve}"
186 | else:
187 | checkpoint_name = args.dataset + f"_sft"
188 | # args.save_path = os.path.join(args.save_path, args.dataset + f"_sft")
189 | strategy.args.wandb_run_name = args.dataset + f"_sft"
190 | assert args.kd_coef[0] == 0.0
191 |
192 |
193 | # configure Trainer
194 | attention_dillution_coef = args.perserve
195 | trainer = MY_KDTrainer(
196 | model=model,
197 | teacher_model=teacher_model,
198 | strategy=strategy,
199 | optim=optim,
200 | train_dataloader=train_dataloader,
201 | eval_dataloader=None,
202 | scheduler=scheduler,
203 | max_norm=args.max_norm,
204 | pretrain_mode=args.pretrain_mode,
205 | batch_size=args.train_batch_size,
206 | max_epochs=args.max_epochs,
207 | tokenizer=tokenizer,
208 | multi_gold_doc=args.multi_gold_doc,
209 | attention_dillution_coef=round(attention_dillution_coef, 1),
210 | # **({"attention_dillution_coef": round(attention_dillution_coef, 1)} if args.output_key == "oracle_answer" else {}),
211 |
212 | )
213 |
214 | trainer.fit(args, consumed_samples, num_update_steps_per_epoch)
215 | # save model checkpoint after fitting on only rank0
216 | save_path = os.path.join(args.save_path, checkpoint_name)
217 | os.makedirs(save_path, exist_ok=True)
218 | strategy.save_model(model, tokenizer, save_path)
219 |
220 |
221 |
222 | if __name__ == "__main__":
223 | parser = argparse.ArgumentParser()
224 | # Checkpoints
225 | parser.add_argument("--save_path", type=str, default="./ckpt")
226 | parser.add_argument("--save_steps", type=int, default=-1)
227 | parser.add_argument("--logging_steps", type=int, default=1)
228 | parser.add_argument("--eval_steps", type=int, default=-1)
229 | parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_kd")
230 | parser.add_argument("--max_ckpt_num", type=int, default=3)
231 | parser.add_argument("--max_ckpt_mem", type=int, default=1e8)
232 | parser.add_argument("--load_checkpoint", action="store_true", default=False)
233 |
234 | # DeepSpeed
235 | parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU")
236 | parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size")
237 | parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping")
238 | parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
239 | parser.add_argument("--seed", type=int, default=42)
240 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
241 | parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage")
242 | parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
243 | parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size")
244 | parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer")
245 | parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2")
246 | parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss")
247 | parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type")
248 | parser.add_argument("--overlap_comm", action="store_true", default=False)
249 | parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
250 | parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
251 |
252 | # LoRA
253 | parser.add_argument("--load_in_4bit", action="store_true", default=False)
254 | parser.add_argument("--lora_rank", type=int, default=0)
255 | parser.add_argument("--lora_alpha", type=int, default=16)
256 | parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear")
257 | parser.add_argument("--lora_dropout", type=float, default=0)
258 |
259 | # KD
260 | parser.add_argument("--pretrain", type=str, default=None)
261 | parser.add_argument("--teacher_model", type=str, default=None)
262 | parser.add_argument("--max_epochs", type=int, default=1)
263 | parser.add_argument("--kd_coef", type=str)
264 | parser.add_argument("--learning_rate", type=float, default=5e-6)
265 | parser.add_argument("--lr_warmup_ratio", type=float, default=0.05)
266 | parser.add_argument("--pretrain_mode", action="store_true", default=False, help="Use pretrain loss")
267 | parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr")
268 | parser.add_argument("--l2", type=float, default=0, help="weight decay loss")
269 | parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
270 | parser.add_argument("--teacher_offload", action="store_true", default=False)
271 |
272 | # Custom dataset
273 | parser.add_argument("--dataset", type=str, default=None)
274 | parser.add_argument("--dataset_name", type=str, default=None)
275 | parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets")
276 | parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset")
277 | parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset")
278 |
279 | parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key")
280 | parser.add_argument("--output_key", type=str, default="output", help="JSON dataset key")
281 | parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ")
282 | parser.add_argument(
283 | "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template"
284 | )
285 |
286 |     parser.add_argument("--max_samples", type=int, default=int(1e8), help="Max number of samples")
287 | parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples")
288 |
289 | # wandb parameters
290 | parser.add_argument("--use_wandb", type=str, default=None)
291 | parser.add_argument("--wandb_org", type=str, default=None)
292 | parser.add_argument("--wandb_group", type=str, default=None)
293 | parser.add_argument("--wandb_project", type=str, default="openrlhf_train_sft")
294 | # parser.add_argument(
295 | # "--wandb_run_name",
296 | # type=str,
297 | # default="sft_%s" % datetime.now().strftime("%m%dT%H:%M"),
298 | # )
299 | parser.add_argument(
300 | "--wandb_run_name",
301 | type=str,
302 | default="temp")
303 | # TensorBoard parameters
304 | parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path")
305 |
306 | # ModelScope parameters
307 | parser.add_argument("--use_ms", action="store_true", default=False)
308 | # kd trainer parameters
309 | parser.add_argument("--perserve", type=float,default=1.0)
310 | parser.add_argument("--multi_gold_doc",action="store_true",default=False)
311 |
312 | args = parser.parse_args()
313 |
314 | if args.input_template and "{}" not in args.input_template:
315 | print("[Warning] {} not in args.input_template, set to None")
316 | args.input_template = None
317 |
318 | if args.input_template and "\\n" in args.input_template:
319 | print(
320 |             "[Warning] input_template contains \\n characters instead of newlines. "
321 | "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell."
322 | )
323 |
324 | if args.use_ms:
325 | from modelscope.utils.hf_util import patch_hub
326 |
327 |         # Patch the hub so models download from ModelScope, which can be faster.
328 |         patch_hub()
329 |
330 | train(args)
331 |
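332 | # A minimal sketch (hypothetical values) of how --kd_coef is parsed and how it names
333 | # the checkpoint / wandb run when --output_key is "oracle_answer":
334 | #   args.kd_coef = [float(x) for x in "0.0,0.0,0.0,1.0".split(",")]  # -> [0.0, 0.0, 0.0, 1.0]
335 | #   checkpoint_name = f"{args.dataset}_kd0.0_lm0.0_rank0.0_adaptive1.0_{args.perserve}"
336 | # With any other output_key the run falls back to plain SFT and kd_coef[0] must be 0.0.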
--------------------------------------------------------------------------------
/kd/customized_sft_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import pdb
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.utils.data import Dataset
6 | from openrlhf.datasets.utils import zero_pad_sequences
7 | import logging
8 | import torch.distributed as dist
9 | logger = logging.getLogger('customized sft dataset')
10 | def get_teacher_student_prompt(data,multi_gold_doc=False):
11 | # logger.info(f"Get teacher student prompt")
12 | question = data["question"]
13 | distractors = data["distractors"]
14 | gold_doc = data["gold_document"]
15 |
16 | # key = "idx" if multi_gold_doc else "id"
17 | # distractor_indexs = [x[key] for x in distractors]
18 | # gold_indexs = [x[key] for x in gold_doc]
19 |
20 |     if not multi_gold_doc:
21 |         distractors.insert(0, gold_doc)  # the single gold document goes first (advantageous position)
22 |     else:
23 |         distractors.extend(gold_doc)  # all gold documents are appended at the end
24 | # complex_indexs = [x[key] for x in distractors]
25 | # assert complex_indexs[-len(gold_indexs):] == gold_indexs
26 | # print(f"complex_indexs: {complex_indexs}")
27 | # print(f"distractor_indexs: {distractor_indexs}")
28 | # print(f"gold_indexs: {gold_indexs}")
29 |
30 |
31 |
32 | teacher_documents = distractors
33 | student_documents = data["documents"]
34 |
35 | student_context = ''.join([f"- Title: {doc['title']}\n{doc['text']}\n" for doc in student_documents])
36 | teacher_context = ''.join([f"- Title: {doc['title']}\n{doc['text']}\n" for doc in teacher_documents])
37 | if not multi_gold_doc:
38 |         task_instruction = "Please write a high-quality answer for the given question using only the provided search documents (some of which might be irrelevant)."
39 | else:
40 | task_instruction = """Let’s first identify the relevant information from the long context and list it. Then, carry out step-by-step reasoning based on that information, and \
41 | finally, provide the answer. The final answer must end with “The answer is:”."""
42 | student_prompt = f"{task_instruction}\n{student_context}\nQuestion: {question}\n"
43 | teacher_prompt = f"{task_instruction}\n{teacher_context}\nQuestion: {question}\n"
44 |
45 | return student_prompt, teacher_prompt
46 |
47 |
48 |
49 | def preprocess_data(
50 | data, input_template=None, input_key="input", output_key=None, apply_chat_template=None, multiturn=False,multi_gold_doc=False
51 | ):
52 | if apply_chat_template:
53 | # logger.info(f"Apply chat template")
54 | if output_key:
55 | # question = data[input_key]
56 | # oracle_answer = data[output_key]
57 | student_prompt, teacher_prompt = get_teacher_student_prompt(data,multi_gold_doc)
58 | if not multi_gold_doc:
59 | response = data[output_key].split("\n")[0]
60 | else:
61 | response = data[output_key]
62 |
63 |
64 | if isinstance(student_prompt, str) and isinstance(teacher_prompt, str) and isinstance(response, str):
65 | student_prompt_message = [{"role": "user", "content": student_prompt}]
66 | teacher_prompt_message = [{"role": "user", "content": teacher_prompt}]
67 | response_message = [{"role": "assistant", "content": response}]
68 |
69 | student_template_prompt = apply_chat_template(student_prompt_message, tokenize=False, add_generation_prompt=True)
70 | teacher_template_prompt = apply_chat_template(teacher_prompt_message, tokenize=False, add_generation_prompt=True)
71 | student_response = apply_chat_template(student_prompt_message + response_message, tokenize=False)[len(student_template_prompt):]
72 | teacher_response = apply_chat_template(teacher_prompt_message + response_message, tokenize=False)[len(teacher_template_prompt):]
73 | assert student_response == teacher_response
74 |         else:
75 |             prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True)
76 |             student_response = apply_chat_template(data[input_key], tokenize=False)[len(prompt) :]
77 |             student_template_prompt = teacher_template_prompt = prompt
78 |     else:
79 |         prompt = input_template.format(data[input_key]) if input_template else data[input_key]
80 |         # output_key is None for continued pretraining; fall back to a shared prompt for teacher and student
81 |         student_response = data[output_key] if output_key else ""
82 |         student_template_prompt = teacher_template_prompt = prompt
83 |     return student_template_prompt, teacher_template_prompt, student_response
84 |
85 |
86 | class SFTDataset(Dataset):
87 | """
88 | Dataset for SFT model
89 |
90 | Args:
91 | dataset: dataset for SFT model
92 | tokenizer: tokenizer for SFT model
93 | max_length: max length of input
94 | """
95 |
96 | def __init__(
97 | self,
98 | dataset,
99 | tokenizer: Callable,
100 | max_length: int,
101 | strategy,
102 | input_template=None,
103 | pretrain_mode=False,
104 | num_processors= 48, # Specify the number of processors you want to use
105 | multiple_of=1,
106 | multiturn=False,
107 | multi_gold_doc=False
108 | ) -> None:
109 | super().__init__()
110 |
111 | self.tokenizer = tokenizer
112 | self.strategy = strategy
113 | self.pretrain_mode = pretrain_mode
114 | self.max_length = max_length
115 | self.multiple_of = multiple_of
116 | self.multiturn = multiturn
117 | self.multi_gold_doc = multi_gold_doc
118 |
119 | # chat template
120 | self.input_template = input_template
121 | self.input_key = getattr(self.strategy.args, "input_key", None)
122 | self.output_key = getattr(self.strategy.args, "output_key", None)
123 | self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False)
124 |
125 | if self.apply_chat_template:
126 | self.apply_chat_template = self.tokenizer.apply_chat_template
127 | tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None)
128 | if tokenizer_chat_template:
129 | self.tokenizer.chat_template = tokenizer_chat_template
130 |
131 | # Parallel loading datasets
132 |
133 | processed_dataset = dataset.map(
134 | self.process_data,
135 | remove_columns=dataset.column_names,
136 | num_proc=num_processors,
137 | )
138 | processed_dataset = processed_dataset.filter(lambda x: x["teacher_prompt"] is not None and x["student_prompt"] is not None and x["response"] is not None)
139 |
140 | # Store the processed data in class attributes
141 | self.student_prompts = processed_dataset["student_prompt"]
142 | self.teacher_prompts = processed_dataset["teacher_prompt"]
143 | self.responses = processed_dataset["response"]
144 | self.student_prompts_ids_lens = processed_dataset["student_prompt_ids_len"]
145 | self.teacher_prompts_ids_lens = processed_dataset["teacher_prompt_ids_len"]
146 | self.response_ranges = processed_dataset["response_ranges"] if self.multiturn else None
147 | self.gold_idx = processed_dataset["gold_idx"]
148 |
149 | def process_data(self, data):
150 | if self.multiturn and self.output_key:
151 | data[self.input_key].append(data[self.output_key])
152 | data[self.output_key] = None
153 |
154 | if self.multiturn:
155 | assert (
156 | not self.output_key or not data[self.output_key]
157 |             ), "You should put the whole trajectory into data[input_key] and not set output_key"
158 | input_key = self.input_key
159 | apply_chat_template = self.apply_chat_template
160 | response_ranges = []
161 | for idx, message in enumerate(data[input_key]):
162 | if message["role"] == "assistant":
163 | prompt = apply_chat_template(data[input_key][:idx], tokenize=False, add_generation_prompt=True)
164 | response = apply_chat_template(data[input_key][: idx + 1], tokenize=False)[len(prompt) :]
165 |
166 | start_idx = (
167 | self.tokenizer(
168 | prompt,
169 | max_length=self.max_length,
170 | padding=False,
171 | truncation=True,
172 | return_tensors="pt",
173 | add_special_tokens=False,
174 | )["attention_mask"]
175 | .int()
176 | .sum()
177 | .item()
178 | )
179 |
180 | end_idx = (
181 | start_idx
182 | + self.tokenizer(
183 | response,
184 | max_length=self.max_length,
185 | padding=False,
186 | truncation=True,
187 | return_tensors="pt",
188 | add_special_tokens=False,
189 | )["attention_mask"]
190 | .int()
191 | .sum()
192 | .item()
193 | - 1
194 | )
195 |                     response_ranges.append((start_idx, end_idx))  # left-closed, right-open
196 |
197 | student_template_prompt, teacher_template_prompt, student_response = preprocess_data(
198 | data,
199 | None if self.pretrain_mode else self.input_template,
200 | self.input_key,
201 | self.output_key,
202 | apply_chat_template=None if self.pretrain_mode else self.apply_chat_template,
203 | multiturn=self.multiturn,
204 | multi_gold_doc=self.multi_gold_doc
205 | )
206 |
207 | if not self.pretrain_mode:
208 | student_prompt_token = self.tokenizer(
209 | student_template_prompt,
210 | max_length=self.max_length,
211 | padding=False,
212 | truncation=True,
213 | return_tensors="pt",
214 | add_special_tokens=False,
215 | )
216 | student_prompt_ids_len = student_prompt_token["attention_mask"].int().sum().item()
217 |
218 | teacher_prompt_token = self.tokenizer(
219 | teacher_template_prompt,
220 | max_length=self.max_length,
221 | padding=False,
222 | truncation=True,
223 | return_tensors="pt",
224 | add_special_tokens=False,
225 | )
226 | teacher_prompt_ids_len = teacher_prompt_token["attention_mask"].int().sum().item()
227 |
228 |             # drop samples whose prompt exceeds max_length (reserving 2 tokens for the answer)
229 | if not student_template_prompt or not student_response or student_prompt_ids_len >= self.max_length - 2:
230 | student_template_prompt = None
231 | else:
232 | prompt_ids_len = 0
233 |
234 | return {
235 | "teacher_prompt": teacher_template_prompt,
236 | "student_prompt":student_template_prompt,
237 | "response": student_response,
238 | "student_prompt_ids_len": student_prompt_ids_len,
239 | "teacher_prompt_ids_len": teacher_prompt_ids_len,
240 | "response_ranges": response_ranges if self.multiturn else None,
241 | "gold_idx":data["gold_idx"]
242 | }
243 |
244 | def __len__(self):
245 | length = len(self.teacher_prompts)
246 | return length
247 |
248 | def __getitem__(self, idx):
249 | teacher_prompt = self.teacher_prompts[idx]
250 | student_prompt = self.student_prompts[idx]
251 | response = self.responses[idx]
252 | teacher_prompt_ids_len = self.teacher_prompts_ids_lens[idx]
253 | student_prompt_ids_len = self.student_prompts_ids_lens[idx]
254 |
255 | if not self.pretrain_mode:
256 | teacher_text = (teacher_prompt + response).rstrip("\n")
257 | student_text = (student_prompt + response).rstrip("\n")
258 |
259 | if not teacher_text.endswith(self.tokenizer.eos_token):
260 | teacher_text += " " + self.tokenizer.eos_token
261 | if not student_text.endswith(self.tokenizer.eos_token):
262 | student_text += " " + self.tokenizer.eos_token
263 |         else:
264 |             raise NotImplementedError("pretrain_mode is not supported by this teacher/student dataset")
265 |
266 | teacher_input_token = self.tokenizer(
267 | teacher_text,
268 | max_length=self.max_length,
269 | padding=False,
270 | truncation=True,
271 | return_tensors="pt",
272 | add_special_tokens=False,
273 | )
274 | student_input_token = self.tokenizer(
275 | student_text,
276 | max_length=self.max_length,
277 | padding=False,
278 | truncation=True,
279 | return_tensors="pt",
280 | add_special_tokens=False,
281 | )
282 |
283 | if not self.pretrain_mode:
284 | # to avoid EOS_token truncation
285 | teacher_input_token["input_ids"][0][-1] = self.tokenizer.eos_token_id
286 | teacher_input_token["attention_mask"][0][-1] = True
287 |
288 | student_input_token["input_ids"][0][-1] = self.tokenizer.eos_token_id
289 | student_input_token["attention_mask"][0][-1] = True
290 | info = {
291 | "teacher_input": teacher_prompt,
292 | "student_input": student_prompt,
293 | "output": response,
294 | "teacher_input_length": teacher_input_token["attention_mask"].int().sum().item(),
295 | "student_input_length": student_input_token["attention_mask"].int().sum().item(),
296 | "response_ranges": self.response_ranges[idx] if self.multiturn else None,
297 | "gold_idx":self.gold_idx[idx]
298 | }
299 | # print(info)
300 | return teacher_input_token["input_ids"], teacher_input_token["attention_mask"], teacher_prompt_ids_len, \
301 | student_input_token["input_ids"], student_input_token["attention_mask"],student_prompt_ids_len, info
302 |
303 | def collate_fn(self, item_list):
304 | # logger.info(f"Collate fn ....")
305 | teacher_prompt_id_lens = []
306 | student_prompt_id_lens = []
307 |
308 | teacher_input_ids = []
309 | student_input_ids = []
310 |
311 | teacher_attention_masks = []
312 | student_attention_masks = []
313 | infos = {"teacher_prompt": [], "output": [],\
314 | "student_prompt": [],"gold_idx":[]}
315 |
316 | for teacher_input_id, teacher_attention_mask, teacher_prompt_ids_len, \
317 | student_input_id, student_attention_mask, student_prompt_ids_len, info in item_list:
318 |
319 | teacher_prompt_id_lens.append(teacher_prompt_ids_len)
320 | student_prompt_id_lens.append(student_prompt_ids_len)
321 | teacher_input_ids.append(teacher_input_id)
322 | student_input_ids.append(student_input_id)
323 | teacher_attention_masks.append(teacher_attention_mask)
324 | student_attention_masks.append(student_attention_mask)
325 | infos["teacher_prompt"].append(info["teacher_input"])
326 | infos["student_prompt"].append(info["student_input"])
327 | infos["output"].append(info["output"])
328 | infos["gold_idx"].append(info["gold_idx"])
329 |
330 |
331 | teacher_input_ids = zero_pad_sequences(teacher_input_ids, "right", self.tokenizer.pad_token_id)
332 | teacher_attention_masks = zero_pad_sequences(teacher_attention_masks, "right")
333 |
334 | student_input_ids = zero_pad_sequences(student_input_ids, "right", self.tokenizer.pad_token_id)
335 | student_attention_masks = zero_pad_sequences(student_attention_masks, "right")
336 | return teacher_prompt_id_lens,teacher_input_ids,teacher_attention_masks, \
337 | student_prompt_id_lens,student_input_ids,student_attention_masks,infos
338 |
339 | def packing_collate_fn(self, item_list):
340 | packed_input_ids = []
341 | packed_attention_masks = []
342 | prompt_ids_lens = []
343 | infos = {"input_length": [], "response_ranges": [] if self.multiturn else None}
344 | index = 1
345 | for prompt_ids_len, input_id, attention_mask, info in item_list:
346 | packed_input_ids.append(input_id.flatten())
347 | packed_attention_masks.append(torch.full_like(input_id.flatten(), index))
348 | prompt_ids_lens.append(prompt_ids_len)
349 | infos["input_length"].append(info["input_length"])
350 | if self.multiturn:
351 | if len(infos["response_ranges"]) >= 1:
352 | for i in range(len(info["response_ranges"])):
353 | info["response_ranges"][i][0] += infos["response_ranges"][-1][-1][
354 | 1
355 | ] # end_index of the last response of the last item
356 | info["response_ranges"][i][1] += infos["response_ranges"][-1][-1][1]
357 | infos["response_ranges"].append(info["response_ranges"])
358 | index += 1
359 |
360 | packed_input_ids = torch.cat(packed_input_ids, dim=0).unsqueeze(0)
361 | packed_attention_masks = torch.cat(packed_attention_masks, dim=0).unsqueeze(0)
362 |
363 | if (
364 | self.multiple_of > 1 and packed_input_ids.numel() % self.multiple_of != 0
365 | ): # not divisible by multiple_of; here we align for grouping
366 | padding_len = self.multiple_of - (packed_input_ids.numel() % self.multiple_of)
367 | packed_input_ids = F.pad(packed_input_ids, (0, padding_len), value=self.tokenizer.pad_token_id)
368 | packed_attention_masks = F.pad(packed_attention_masks, (0, padding_len), value=0)
369 |
370 | return prompt_ids_lens, packed_input_ids, packed_attention_masks, infos
371 |
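372 | # A minimal usage sketch of get_teacher_student_prompt (toy fields, hypothetical values):
373 | #   data = {
374 | #       "question": "who wrote hamlet",
375 | #       "gold_document": {"title": "Hamlet", "text": "... by William Shakespeare ..."},
376 | #       "distractors": [{"title": "Macbeth", "text": "..."}],
377 | #       "documents": [...],  # student-side ordering, with the gold document at its sampled position
378 | #   }
379 | #   student_prompt, teacher_prompt = get_teacher_student_prompt(data, multi_gold_doc=False)
380 | # The teacher prompt always places the gold document first; the student prompt keeps
381 | # data["documents"] as-is, so both share the same instruction, question, and response.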
--------------------------------------------------------------------------------
/kd/loss_design.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List
3 | 
4 | import torch
5 | import torch.distributed as dist
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.distributed import all_gather, get_rank
9 | 
10 | logger = logging.getLogger('loss_design')
11 | 
23 | class AdaptiveKLWeightedKLLoss(nn.Module):
24 | """
25 | KL Divergence Loss with Adaptive KL-Based Weighting
26 | """
27 | def __init__(self, temperature=1.0, batch_norm=False,attention_dillution_coef=1):
28 | super().__init__()
29 | self.IGNORE_INDEX = -100
30 | self.temperature = temperature
31 |         self.epsilon = 1e-4  # guard against division by zero
32 | self.attention_dillution_coef = attention_dillution_coef
33 | self.batch_norm = batch_norm
34 |
35 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor,
36 | student_label: torch.Tensor, teacher_label: torch.Tensor,gold_idxs: List[int]) -> torch.Tensor:
37 | """
38 | Inputs:
39 | - logits: Student logits, shape [B, L, C]
40 | - teacher_logits: Teacher logits, shape [B, L, C]
41 | - *_label: token-level masks, shape [B, L]
42 | """
43 | gold_idxs = torch.tensor(gold_idxs, device=logits.device)
44 |
45 | teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L]
46 | student_mask = (student_label != self.IGNORE_INDEX) # [B, L]
47 | B = teacher_mask.shape[0]
48 | assert torch.all(torch.sum(teacher_mask, dim=1) == torch.sum(student_mask, dim=1)), "Mask mismatch"
49 |
50 | teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) # [B, L, C]
51 | student_log_probs = F.log_softmax(logits / self.temperature, dim=-1) # [B, L, C]
52 | kl_divs = torch.stack([
53 | F.kl_div(
54 | student_log_probs[i][student_mask[i]],
55 | teacher_log_probs[i][teacher_mask[i]].exp(),
56 | reduction="batchmean"
57 | )
58 | for i in range(B)
59 | ])
60 | # mask = gold_idxs == 0
61 | # local_invaried_tensor = kl_divs[mask].sum() # [1]
62 | # local_kl_tensor = kl_divs[~mask] # [3]
63 | local_kl_tensor = kl_divs
64 | print(f"rank = {get_rank()}")
65 | print(f"local_kl_tensor = {local_kl_tensor}")
66 | print(f"gold idxs = {gold_idxs}")
67 | if self.batch_norm:
68 | world_size = dist.get_world_size()
69 | gathered_kl_list = [torch.zeros_like(local_kl_tensor) for _ in range(world_size)]
70 | gathered_gold_indexs = [torch.zeros_like(gold_idxs) for _ in range(world_size)]
71 | all_gather(gathered_kl_list, local_kl_tensor)
72 | all_gather(gathered_gold_indexs, gold_idxs)
73 | all_kl = torch.stack(gathered_kl_list, dim=0) # [8, 4]
74 | all_gold_indexs = torch.stack(gathered_gold_indexs, dim=0) #[8,4]
75 | print(f"all_kl.shape = {all_kl.shape}")
76 | print(f"all_gold_index shape = {all_gold_indexs.shape}")
77 |             # compute the weights on rank 0 only, then broadcast so all ranks agree
78 | if dist.get_rank() == 0:
79 | weights = torch.zeros_like(all_gold_indexs,dtype=torch.float32)
80 | mask = all_gold_indexs == 0
81 | weights[mask] = self.attention_dillution_coef
82 | unique_gold_indexs = torch.unique(all_gold_indexs[~mask])
83 | # unique_gold_indexs = torch.unique(all_gold_indexs)
84 | pos_mask_list = {}
85 | pos_kl_list = []
86 | for unique_gold_index in unique_gold_indexs:
87 | pos_mask = all_gold_indexs == unique_gold_index
88 | avg_pos_kl = torch.sum(all_kl[pos_mask])/torch.sum(pos_mask)
89 | # add frequency
90 |
91 | pos_kl_list.append(avg_pos_kl)
92 |                 print(f"unique gold indexes = {unique_gold_indexs}")
93 | print(f"pos_kl_list = {pos_kl_list}")
94 |
95 | pos_kl_tensor = torch.stack(pos_kl_list,dim=0)
96 | # pos_kl_coefs = pos_kl_tensor*pos_kl_tensor.numel()/(torch.sum(pos_kl_tensor)+self.epsilon)
97 | pos_kl_coefs =F.softmax(pos_kl_tensor,dim=0)*pos_kl_tensor.numel()
98 | # pos_kl_coefs =F.softmax(pos_kl_tensor,dim=0)
99 |
100 |
101 | print(f"pos_kl_coefs = {pos_kl_coefs}")
102 |
103 | for pos_idx,unique_gold_index in enumerate(unique_gold_indexs):
104 | pos_mask = all_gold_indexs == unique_gold_index
105 | weights[pos_mask] = pos_kl_coefs[pos_idx]/torch.sum(pos_mask)*(all_kl[pos_mask]/max(all_kl[pos_mask]))
106 |
107 | else:
108 | weights = torch.zeros_like(all_gold_indexs,dtype=torch.float32)
109 | dist.broadcast(weights, src=0)
110 |
111 | print(f"rank = {get_rank()}, weights = {weights}")
112 | print(f"rank = {get_rank()}, all_kl = {all_kl}")
113 | total_loss = torch.sum(weights[get_rank()].detach() * local_kl_tensor)
114 | else:
115 |             total_loss = torch.sum(local_kl_tensor)  # the gold_idx split above is commented out, so sum all per-sample KLs
116 | # total_loss = torch.sum(torch.clamp((local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor)).detach(),min=0.3,max=self.attention_dillution_coef) * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef
117 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef
118 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor)
119 |
120 | print(f"total_loss = {total_loss}")
121 | return total_loss * (self.temperature ** 2)
122 |
123 |
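# A worked note on the batch_norm weighting above (hypothetical numbers): with
# per-position mean KLs pos_kl_tensor = [0.2, 0.6], F.softmax(pos_kl_tensor, dim=0) * 2
# gives coefficients of roughly [0.80, 1.20], so gold positions where the student
# diverges more from the teacher receive proportionally larger distillation weight.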
124 | class AdaptiveKLWeightedKLLoss_multi(nn.Module):
125 | """
126 | KL Divergence Loss with Adaptive KL-Based Weighting
127 | """
128 | def __init__(self, temperature=1.0, batch_norm=False,attention_dillution_coef=1):
129 | super().__init__()
130 | self.IGNORE_INDEX = -100
131 | self.temperature = temperature
132 |         self.epsilon = 1e-4  # guard against division by zero
133 | self.attention_dillution_coef = attention_dillution_coef
134 | self.batch_norm = batch_norm
135 |
136 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor,
137 | student_label: torch.Tensor, teacher_label: torch.Tensor,gold_idxs: List[int]) -> torch.Tensor:
138 | """
139 | Inputs:
140 | - logits: Student logits, shape [B, L, C]
141 | - teacher_logits: Teacher logits, shape [B, L, C]
142 | - *_label: token-level masks, shape [B, L]
143 | """
144 | gold_idxs = torch.tensor(gold_idxs, device=logits.device)
145 |
146 | teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L]
147 | student_mask = (student_label != self.IGNORE_INDEX) # [B, L]
148 | B = teacher_mask.shape[0]
149 | assert torch.all(torch.sum(teacher_mask, dim=1) == torch.sum(student_mask, dim=1)), "Mask mismatch"
150 |
151 | teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1) # [B, L, C]
152 | student_log_probs = F.log_softmax(logits / self.temperature, dim=-1) # [B, L, C]
154 | kl_divs = torch.stack([
155 | F.kl_div(
156 | student_log_probs[i][student_mask[i]],
157 | teacher_log_probs[i][teacher_mask[i]].exp(),
158 | reduction="batchmean"
159 | )
160 | for i in range(B)
161 | ])
162 | # mask = gold_idxs == 0
163 | # local_invaried_tensor = kl_divs[mask].sum() # [1]
164 | # local_kl_tensor = kl_divs[~mask] # [3]
165 | local_kl_tensor = kl_divs
166 | print(f"rank = {get_rank()}")
167 | print(f"local_kl_tensor = {local_kl_tensor}")
168 | print(f"gold idxs = {gold_idxs}")
169 | print(f"observed pos: {torch.sum(teacher_mask, dim=1)}")
170 | if self.batch_norm:
171 | world_size = dist.get_world_size()
172 | gathered_kl_list = [torch.zeros_like(local_kl_tensor) for _ in range(world_size)]
173 | gathered_gold_indexs = [torch.zeros_like(gold_idxs) for _ in range(world_size)]
174 | all_gather(gathered_kl_list, local_kl_tensor)
175 | all_gather(gathered_gold_indexs, gold_idxs)
176 | all_kl = torch.stack(gathered_kl_list, dim=0) # [8, 4]
177 | all_gold_indexs = torch.stack(gathered_gold_indexs, dim=0) #[8,4,2]
178 | print(f"all_kl.shape = {all_kl.shape}")
179 | print(f"all_gold_index shape = {all_gold_indexs.shape}")
180 |             # compute the weights on rank 0 only, then broadcast so all ranks agree
181 | if dist.get_rank() == 0:
182 | m,n,_ = all_gold_indexs.size()
183 | weights = torch.zeros((m,n),dtype=torch.float32,device=all_gold_indexs.device)
184 | # weights = F.softmax(all_kl.flatten())*all_kl.numel().view(m,n)
185 | # mask = (all_gold_indexs == torch.tensor([0, 1],device=all_gold_indexs.device)).all(dim=-1)
186 | # weights[mask] = self.attention_dillution_coef
187 | # weights[~mask] = all_kl[~mask]/torch.sum(all_kl[~mask])*all_kl[~mask].numel()
188 | # weights = F.softmax(all_kl.flatten(),dim=0).view(m,n)*all_kl.numel()
189 |                 weights = (all_kl.flatten() / (torch.sum(all_kl.flatten()) + self.epsilon) * all_kl.numel()).view(m, n)  # normalize, then reshape to the (m, n) buffer expected by the broadcast below
190 | 
220 | else:
221 | m,n,_ = all_gold_indexs.size()
222 | weights = torch.zeros((m,n),dtype=torch.float32,device=all_gold_indexs.device)
223 | dist.broadcast(weights, src=0)
224 |
225 | print(f"rank = {get_rank()}, weights = {weights}")
226 | print(f"rank = {get_rank()}, all_kl = {all_kl}")
227 | total_loss = torch.sum(weights[get_rank()].detach() * local_kl_tensor)
228 | else:
229 |             total_loss = torch.sum(local_kl_tensor)  # the gold_idx split above is commented out, so sum all per-sample KLs
230 | # total_loss = torch.sum(torch.clamp((local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor)).detach(),min=0.3,max=self.attention_dillution_coef) * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef
231 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor) + local_invaried_tensor * self.attention_dillution_coef
232 | # total_loss = torch.sum(local_kl_tensor.numel() * local_kl_tensor / torch.sum(local_kl_tensor).detach() * local_kl_tensor)
233 |
234 | print(f"total_loss = {total_loss}")
235 | return total_loss * (self.temperature ** 2)
236 | 
237 | 
313 | class MY_KDLoss(nn.Module):
314 | """
315 |     Position-weighted token-level knowledge distillation loss (earlier response tokens count more)
316 | """
317 | def __init__(self):
318 | super().__init__()
319 | self.IGNORE_INDEX = -100
320 | self.temperature = 1.0
321 | self.chunk_size = 64
322 | self.decay_factor = 0.99
323 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor,
324 | student_label: torch.Tensor, teacher_label: torch.Tensor) -> torch.Tensor:
325 |         # valid-token masks for teacher and student
326 | teacher_mask = (teacher_label != self.IGNORE_INDEX) # [B, L]
327 | student_mask = (student_label != self.IGNORE_INDEX) # [B, L]
328 |         # ensure both masks select the same number of valid tokens
329 | assert torch.sum(teacher_mask) == torch.sum(student_mask), "Teacher and student masks must have the same number of valid tokens"
330 | B = teacher_mask.shape[0]
331 | L = teacher_mask.shape[1]
332 |         teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1, dtype=torch.float32)  # teacher log-probabilities
333 | student_log_probs = F.log_softmax(logits / self.temperature, dim=-1, dtype=torch.float32) # [B, L, V]
334 |
335 | position_weights = torch.tensor([3*self.decay_factor ** i for i in range(L)], dtype=torch.float32, device=teacher_logits.device)
336 |         # per-sample KL over valid positions, scaled by decaying position weights
337 | kl_divs = []
338 | for i in range(B):
339 |             kl_loss = F.kl_div(student_log_probs[i][student_mask[i]], teacher_log_probs[i][teacher_mask[i]].exp(), reduction="none")  # [seq, V]
341 |
342 | kl_loss = torch.mean(torch.sum(kl_loss,dim=-1)*position_weights[:torch.sum(student_mask[i])])
343 |
344 | kl_divs.append(kl_loss)
345 | kl_divs = torch.stack(kl_divs)
346 |         # mean KL across the batch
347 | return kl_divs.mean()
348 |
379 | class MY_rankLoss(nn.Module):
380 | """
381 |     Pairwise margin-ranking distillation loss over the teacher's top-k tokens
382 | """
383 | def __init__(self,top_k=30, margin=0.5):
384 | super().__init__()
385 | self.IGNORE_INDEX = -100
386 | self.temperature = 2.0
387 | self.top_k = top_k
388 | self.margin = margin
389 |         self.ranking_loss = nn.MarginRankingLoss(margin=self.margin, reduction="none")  # per-pair losses, weighted below
390 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor,
391 | student_label: torch.Tensor, teacher_label: torch.Tensor) -> torch.Tensor:
392 |         # valid-token masks for teacher and student
393 | teacher_mask = (teacher_label != self.IGNORE_INDEX).unsqueeze(-1) # [B, L, 1]
394 | student_mask = (student_label != self.IGNORE_INDEX).unsqueeze(-1) # [B, L, 1]
395 |         # ensure both masks select the same number of valid tokens
396 | assert torch.sum(teacher_mask) == torch.sum(student_mask), "Teacher and student masks must have the same number of valid tokens"
397 | teacher_indices = teacher_mask.squeeze(-1).nonzero(as_tuple=True)
398 | student_indices = student_mask.squeeze(-1).nonzero(as_tuple=True)
399 |         # gather logits at the valid token positions
400 |         teacher_logits = teacher_logits[teacher_indices]  # [num_valid_tokens, C]
401 |         student_logits = logits[student_indices]  # [num_valid_tokens, C]
402 |         topk_values, topk_indices = torch.topk(teacher_logits, self.top_k, dim=-1)  # teacher's top-k tokens
403 |         student_topk_logits = torch.gather(student_logits, dim=-1, index=topk_indices)
404 |         # enumerate all (i, j) index pairs with i < j
405 | i_indices, j_indices = torch.triu_indices(self.top_k, self.top_k, offset=1)
406 |         # pairwise teacher/student logits for each pair
407 | teacher_i = topk_values[:, i_indices]
408 | teacher_j = topk_values[:, j_indices]
409 | student_i = student_topk_logits[:, i_indices]
410 | student_j = student_topk_logits[:, j_indices]
411 |
412 |         # ranking labels: +1 if the teacher ranks token i above token j, else -1
413 | ranking_labels = (teacher_i > teacher_j).float() * 2 - 1
414 |
415 |         # down-weight pairs that are far apart in the teacher's ranking
416 | rank_diff = torch.abs(i_indices - j_indices).float()
417 | weights = 1.0 / (rank_diff + 1)
418 | weights = weights.to(student_i.device)
419 |
420 |         # per-pair weighted margin ranking loss
421 | ranking_loss = (self.ranking_loss(student_i, student_j, ranking_labels) * weights).mean()
422 |
423 | return ranking_loss
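# A quick worked note on the pair enumeration above (toy size, hypothetical):
# torch.triu_indices(3, 3, offset=1) gives i = [0, 0, 1] and j = [1, 2, 2], i.e. every
# (i, j) pair with i < j among the teacher's top-3 tokens; weights = 1 / (|i - j| + 1)
# then makes adjacently-ranked pairs count more than distant ones.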
424 | class MY_topkLoss(nn.Module):
425 | """
426 |     Top-k knowledge distillation loss: KL restricted to the teacher's top-k tokens
427 | """
428 | def __init__(self):
429 | super().__init__()
430 | self.IGNORE_INDEX = -100
431 | self.temperature = 1.0
432 | self.top_k = 20
433 | def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor,
434 | student_label: torch.Tensor, teacher_label: torch.Tensor) -> torch.Tensor:
435 |         # valid-token masks for teacher and student
436 |         teacher_mask = (teacher_label != self.IGNORE_INDEX)  # [B, L]
437 |         student_mask = (student_label != self.IGNORE_INDEX)  # [B, L]
438 |         # ensure both masks select the same number of valid tokens
439 |         assert torch.sum(teacher_mask) == torch.sum(student_mask), "Teacher and student masks must have the same number of valid tokens"
440 |         teacher_indices = teacher_mask.nonzero(as_tuple=True)
441 |         student_indices = student_mask.nonzero(as_tuple=True)
442 |         # gather logits at the valid token positions
443 |         teacher_logits = teacher_logits[teacher_indices]  # [num_valid_tokens, C]
444 |         student_logits = logits[student_indices]  # [num_valid_tokens, C]
445 |         topk_values, topk_indices = torch.topk(teacher_logits, self.top_k, dim=-1)  # teacher's top-k tokens
446 | student_topk_logits = torch.gather(student_logits, dim=-1, index=topk_indices)
447 |
448 |
449 | teacher_probs = F.softmax(topk_values / self.temperature, dim=-1)
450 | student_log_probs = F.log_softmax(student_topk_logits / self.temperature, dim=-1)
451 | kl_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
452 | return kl_loss * (self.temperature ** 2)
453 | class MY_GPTLMLoss(nn.Module):
454 | """
455 | GPT Language Model Loss
456 | """
457 |
458 | def __init__(self, ring_attn_group=None):
459 | super().__init__()
460 | self.IGNORE_INDEX = -100
461 | self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
462 |
463 | self.ring_attn_group = ring_attn_group
464 | if self.ring_attn_group:
465 | self.ring_attn_rank = dist.get_rank(self.ring_attn_group)
466 | self.ring_attn_world_size = dist.get_world_size(self.ring_attn_group)
467 |
468 | def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
469 | # RingAttention
470 |
471 | if self.ring_attn_group is not None:
472 |             logger.info("GPTLMLoss: taking the ring_attn_group branch.")
473 | total_seq_len = labels.size(-1)
474 | seq_len_per_process = total_seq_len // self.ring_attn_world_size
475 | start_idx = self.ring_attn_rank * seq_len_per_process
476 | end_idx = min(start_idx + seq_len_per_process, total_seq_len)
477 | labels = labels[..., start_idx:end_idx]
478 |
479 | shift_logits = logits[..., :-1, :].contiguous()
480 | shift_labels = labels[..., 1:].contiguous()
481 |
482 | # if labels are all IGNORE_INDEX, then nn.CrossEntropyLoss will be nan
483 | if torch.all(shift_labels == self.IGNORE_INDEX):
484 | # Use mean of logits multiplied by 0 to maintain gradient flow
485 | loss = shift_logits.mean() * 0
486 | else:
487 | loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
488 |
489 | dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.ring_attn_group)
490 | loss = loss / self.ring_attn_world_size
491 | else:
493 | shift_logits = logits[..., :-1, :].contiguous()
494 | shift_labels = labels[..., 1:].contiguous()
495 | # loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
496 | all_logits = []
497 | all_labels = []
498 | batch_size = shift_labels.shape[0]
499 | shift_mask = (shift_labels != self.IGNORE_INDEX).int().unsqueeze(-1)
500 | for i in range(batch_size):
501 | shift_nonzero_indices = shift_mask[i].nonzero()[:,0]
502 | shift_nonzero_logits = shift_logits[i,shift_nonzero_indices,:]
503 | shift_nonzero_labels = shift_labels[i,shift_nonzero_indices]
504 | all_logits.append(shift_nonzero_logits)
505 | all_labels.append(shift_nonzero_labels)
506 |
507 |             # concatenate the valid tokens from every sample in the batch
508 | concatenated_logits = torch.cat(all_logits, dim=0) # shape: [total_valid_tokens, vocab_size]
509 | concatenated_labels = torch.cat(all_labels, dim=0) # shape: [total_valid_tokens]
510 |
511 |             # compute the loss over all valid tokens at once
512 | loss = self.loss(concatenated_logits, concatenated_labels)
513 |
514 | return loss
515 |
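516 | # A self-contained sketch (synthetic tensors, not part of the repo) of the masked,
517 | # per-sample KL pattern shared by the losses above:
518 | if __name__ == "__main__":
519 |     IGNORE_INDEX = -100
520 |     B, L, V = 2, 5, 11
521 |     student_logits = torch.randn(B, L, V)
522 |     teacher_logits = torch.randn(B, L, V)
523 |     labels = torch.full((B, L), IGNORE_INDEX)
524 |     labels[:, 2:] = 1  # pretend the last three positions are response tokens
525 | 
526 |     mask = labels != IGNORE_INDEX
527 |     s_logp = F.log_softmax(student_logits, dim=-1)
528 |     t_logp = F.log_softmax(teacher_logits, dim=-1)
529 |     kl_divs = torch.stack([
530 |         F.kl_div(s_logp[i][mask[i]], t_logp[i][mask[i]].exp(), reduction="batchmean")
531 |         for i in range(B)
532 |     ])
533 |     print(kl_divs)  # one KL value per sample, as in AdaptiveKLWeightedKLLoss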
--------------------------------------------------------------------------------