├── src ├── scripts │ ├── run_cal_data_volumn.sh │ ├── run_tokenize_split.sh │ ├── run_fetch_data.sh │ ├── run_save_hf_dataset.sh │ ├── launch.sh │ └── run_train_multi.sh ├── config │ └── ds2_config.json ├── train_utils.py ├── cal_data_volumn.py ├── split_data.py ├── tokenize_text.py ├── save_hf_dataset.py ├── fetch_data.py └── train.py ├── assets ├── pipeline.png └── llama-3-syne-logo.png ├── README_zh.md ├── LICENSE └── README.md /src/scripts/run_cal_data_volumn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python cal_data_volumn.py 4 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUC-GSAI/Llama-3-SynE/HEAD/assets/pipeline.png -------------------------------------------------------------------------------- /assets/llama-3-syne-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUC-GSAI/Llama-3-SynE/HEAD/assets/llama-3-syne-logo.png -------------------------------------------------------------------------------- /src/scripts/run_tokenize_split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TOKENIZER_PATH=# TODO: Your tokenizer path, e.g., meta-llama/Meta-Llama-3-8B 4 | DATA_PATH=# TODO: Your data path 5 | MODEL_NAME=# TODO: Your model name 6 | 7 | python tokenize_text.py \ 8 | --tokenizer_path ${TOKENIZER_PATH} \ 9 | --data_path ${DATA_PATH} \ 10 | --model_name ${MODEL_NAME} \ 11 | --num_file 500000 \ 12 | --text_key text \ 13 | --num_worker 64 \ 14 | --skip_exist True 15 | 16 | FATHER_DATASETS=# TODO: Your father datasets 17 | 18 | python split_data.py \ 19 | --father_datasets ${FATHER_DATASETS} 20 | -------------------------------------------------------------------------------- /src/config/ds2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 2, 7 | "allgather_partitions": true, 8 | "allgather_bucket_size": 2e8, 9 | "overlap_comm": true, 10 | "reduce_scatter": true, 11 | "reduce_bucket_size": 2e8, 12 | "contiguous_gradients": true 13 | }, 14 | "gradient_accumulation_steps": "auto", 15 | "gradient_clipping": "auto", 16 | "steps_per_print": 16, 17 | "train_batch_size": "auto", 18 | "train_micro_batch_size_per_gpu": "auto", 19 | "wall_clock_breakdown": false 20 | } -------------------------------------------------------------------------------- /src/scripts/run_fetch_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TOTAL_TOKEN_NUM=# TODO: Your total token number, e.g., 40 4 | CN_RATIO=# TODO: Your Chinese ratio, e.g., 0.1 5 | EN_RATIO=# TODO: Your English ratio, e.g., 0.7 6 | SYN_RATIO=# TODO: Your synthetic data ratio, e.g., 0.2 7 | ROOT_DIR=# TODO: Your data root directory, e.g., /data 8 | TOKENIZER_PATH=# TODO: Your tokenizer path, e.g., meta-llama/Meta-Llama-3-8B 9 | 10 | python fetch_data.py \ 11 | --total_token_num ${TOTAL_TOKEN_NUM} \ 12 | --cn_ratio ${CN_RATIO} \ 13 | --en_ratio ${EN_RATIO} \ 14 | --syn_ratio ${SYN_RATIO} \ 15 | --root_dir ${ROOT_DIR} \ 16 | --tokenizer_path ${TOKENIZER_PATH} \ 17 | -------------------------------------------------------------------------------- /src/scripts/run_save_hf_dataset.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TIMESTAMP_LST=# TODO: Your timestamp list provided by fetch_data.py 4 | TOKENIZER_PATH=# TODO: Your tokenizer path, e.g., meta-llama/Meta-Llama-3-8B 5 | MODEL_MAX_LENGTH=# TODO: Your model max length, e.g., 8192 6 | NUM_WORKERS=# TODO: Your number of workers, e.g., 32 7 | MIN_TEXT_LENGTH=# TODO: Your minimum text length, e.g., 10 8 | ROOT_DIR=# TODO: Your data root directory, e.g., /data 9 | 10 | python save_hf_dataset.py \ 11 | --timestamp_lst ${TIMESTAMP_LST} \ 12 | --tokenizer_path ${TOKENIZER_PATH} \ 13 | --model_max_length ${MODEL_MAX_LENGTH} \ 14 | --num_workers ${NUM_WORKERS} \ 15 | --min_text_length ${MIN_TEXT_LENGTH} \ 16 | --root_dir ${ROOT_DIR} \ 17 | --show_case # Show the first case 18 | -------------------------------------------------------------------------------- /src/scripts/launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PWD=`pwd` 4 | 5 | ### Execute your job 6 | SCRIPT=run_train_multi.sh 7 | SCRIPT_PATH=${PWD}/train.py 8 | MODEL_NAME=# TODO: model name 9 | MODEL_PATH=# TODO: model path 10 | DATA_PREFIX=# TODO: data prefix 11 | DATA_PATH=# TODO: data path 12 | JOB_NAME=${MODEL_NAME}_CPT_${DATA_PREFIX} 13 | SAVE_DIR=# TODO: save dir, e.g., ${PWD}/model 14 | nodelist=(\ 15 | # "job-master-0" \ 16 | # "job-worker-0" \ 17 | ) 18 | MASTER_PORT=# TODO: master port 19 | 20 | NNODES="${#nodelist[@]}" 21 | LOCAL_HOST=`hostname` 22 | echo "'"$LOCAL_HOST"'" $NNODES $MASTER_PORT 23 | 24 | for ((i=0;i<${NNODES};i=i+1)) 25 | do 26 | echo "${nodelist[i]} => " "cd ${PWD} && bash ${SCRIPT} ${NNODES} $i ${nodelist[0]} ${MASTER_PORT} ${JOB_NAME} ${SCRIPT_PATH} ${MODEL_PATH} ${DATA_PATH} ${SAVE_DIR}" "&> ${PWD}/log/${JOB_NAME}_part${i}.log &" 27 | ssh -o ServerAliveInterval=60 "${nodelist[i]}" "cd ${PWD} && bash ${SCRIPT} ${NNODES} $i ${nodelist[0]} ${MASTER_PORT} ${JOB_NAME} ${SCRIPT_PATH} ${MODEL_PATH} ${DATA_PATH} ${SAVE_DIR}" &> ${PWD}/log/${JOB_NAME}_part${i}.log & 28 | done 29 | 30 | wait 31 | 32 | echo finished! 33 | -------------------------------------------------------------------------------- /src/scripts/run_train_multi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NNODES=$1 4 | NODE_RANK=$2 5 | MASTER_ADDR=$3 6 | MASTER_PORT=$4 7 | JOB_NAME=$5 8 | SCRIPT_PATH=$6 9 | MODEL_PATH=$7 10 | DATA_PATH=$8 11 | SAVE_DIR=$9 12 | 13 | export OMP_NUM_THREADS=24 14 | export WANDB_MODE=offline 15 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 16 | 17 | export CUDA_DEVICE_MAX_CONNECTIONS=1 18 | export NCCL_SOCKET_IFNAME=eth0 19 | export NCCL_IB_DISABLE=0 20 | export NCCL_IB_CUDA_SUPPORT=1 21 | export NCCL_IB_GID_INDEX=0 22 | 23 | MODEL_MAX_LENGTH=# TODO: Set model max length, e.g., 8192 24 | PER_DEVICE_TRAIN_BATCH_SIZE=# TODO: Set per device train batch size, e.g., 2 25 | GRADIENT_ACCUMULATION_STEPS=# TODO: Set gradient accumulation steps, e.g., 1 26 | SAVE_STEPS=# TODO: Set model save steps, e.g., 2000 27 | SAVE_TOTAL_LIMIT=# TODO: Set model save total limit, e.g., 10 28 | LEARNING_RATE=# TODO: Set learning rate, e.g., 1e-5 29 | WARMUP_RATIO=# TODO: Set warmup ratio, e.g., 0.0 30 | WEIGHT_DECAY=# TODO: Set weight decay, e.g., 0.0 31 | DEEPSPEED_CONFIG_PATH=# TODO: Set deepspeed config path, e.g., ./config/ds2_config.json 32 | GRADIENT_CHECKPOINTING=# TODO: Whether to use gradient checkpointing, e.g., True 33 | # Set `--flash_attention` to use FlashAttention 2. 
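# Note (informational): with `--nproc_per_node=8` below, the effective global batch size is
#   NNODES x 8 x PER_DEVICE_TRAIN_BATCH_SIZE x GRADIENT_ACCUMULATION_STEPS
# sequences of MODEL_MAX_LENGTH tokens each; the "auto" fields in ds2_config.json
# (train_batch_size, train_micro_batch_size_per_gpu, gradient_accumulation_steps) are
# resolved from these values by the Hugging Face Trainer / DeepSpeed integration.
# Using `--flash_attention` assumes the flash-attn package is installed in the training environment.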
34 | # Set `--use_wsd` to use the WSD optimizer. 35 | # Set `--no_shuffle` to disable shuffling. 36 | # Set `--load_text_dataset` to load raw text data and perform preprocessing (tokenization and grouping) before training. After preprocessing, the script will save the processed dataset to disk and exit. If False, it assumes that a preprocessed dataset is provided and loads it directly from disk. 37 | # Set `single_dataset` to load a single dataset from the specified `data_path`. If False, it will load and concatenate multiple datasets found in the `data_path` directory. 38 | # Set `--resume_from_checkpoint ` to resume training from a checkpoint. 39 | 40 | 41 | torchrun --nproc_per_node=8 \ 42 | --nnodes=${NNODES} \ 43 | --node_rank=${NODE_RANK} \ 44 | --master_addr=${MASTER_ADDR} \ 45 | --master_port=${MASTER_PORT} \ 46 | ${SCRIPT_PATH} \ 47 | --model_name_or_path ${MODEL_PATH} \ 48 | --data_path ${DATA_PATH} \ 49 | --bf16 True \ 50 | --output_dir ${SAVE_DIR}/${JOB_NAME} \ 51 | --num_train_epochs 1 \ 52 | --model_max_length ${MODEL_MAX_LENGTH} \ 53 | --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ 54 | --per_device_eval_batch_size 4 \ 55 | --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ 56 | --evaluation_strategy "no" \ 57 | --save_strategy "steps" \ 58 | --save_steps $SAVE_STEPS \ 59 | --save_total_limit $SAVE_TOTAL_LIMIT \ 60 | --learning_rate $LEARNING_RATE \ 61 | --warmup_ratio $WARMUP_RATIO \ 62 | --weight_decay $WEIGHT_DECAY \ 63 | --logging_steps 2 \ 64 | --deepspeed ${DEEPSPEED_CONFIG_PATH} \ 65 | --gradient_checkpointing ${GRADIENT_CHECKPOINTING} \ 66 | --report_to none \ 67 | --tf32 True \ 68 | --lr_scheduler_type "linear" \ 69 | --flash_attention \ 70 | --use_wsd \ 71 | &> ./log/${JOB_NAME}_part${NODE_RANK}.log 72 | -------------------------------------------------------------------------------- /src/train_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import torch 3 | import datasets 4 | import transformers 5 | from typing import Dict, Union 6 | from torch.utils.data import DataLoader, SequentialSampler 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from transformers import Trainer 9 | from transformers.trainer_utils import seed_worker 10 | from transformers.utils import is_datasets_available 11 | 12 | 13 | class NoShuffleSeq2SeqTrainer(Trainer): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | 17 | def get_train_dataloader(self) -> DataLoader: 18 | """ 19 | Returns the training [`~torch.utils.data.DataLoader`]. 20 | 21 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 22 | training if necessary) otherwise. 23 | 24 | Subclass and override this method if you want to inject some custom behavior. 
25 | """ 26 | if self.train_dataset is None: 27 | raise ValueError("Trainer: training requires a train_dataset.") 28 | 29 | train_dataset = self.train_dataset 30 | data_collator = self.data_collator 31 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): 32 | train_dataset = self._remove_unused_columns( 33 | train_dataset, description="training" 34 | ) 35 | else: 36 | data_collator = self._get_collator_with_removed_columns( 37 | data_collator, description="training" 38 | ) 39 | 40 | dataloader_params = { 41 | "batch_size": self._train_batch_size, 42 | "collate_fn": data_collator, 43 | "num_workers": self.args.dataloader_num_workers, 44 | "pin_memory": self.args.dataloader_pin_memory, 45 | } 46 | 47 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 48 | # dataloader_params["sampler"] = SequentialSampler(self.train_dataset) # Original 49 | dataloader_params["sampler"] = self._get_eval_sampler(self.train_dataset) 50 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 51 | dataloader_params["worker_init_fn"] = seed_worker 52 | dataloader_params["shuffle"] = False 53 | 54 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 55 | 56 | 57 | def get_wsd_scheduler( 58 | optimizer, num_warmup_steps, num_training_steps, last_epoch=-1, stable_ratio=1.0 59 | ): 60 | def lr_lambda(current_step): 61 | if current_step < num_warmup_steps: 62 | return float(current_step) / float(max(1, num_warmup_steps)) 63 | num_stable_steps = stable_ratio * num_training_steps 64 | if current_step < num_stable_steps: 65 | return 1.0 66 | return max( 67 | 0.1, 68 | float(num_training_steps - current_step) 69 | / float(max(1, num_training_steps - num_stable_steps)), 70 | ) 71 | 72 | return LambdaLR(optimizer, lr_lambda, last_epoch) 73 | 74 | 75 | class WSDTrainer(Trainer): 76 | def __init__(self, *args, **kwargs): 77 | super().__init__(*args, **kwargs) 78 | 79 | def create_scheduler( 80 | self, num_training_steps: int, optimizer: torch.optim.Optimizer = None 81 | ): 82 | if self.lr_scheduler is None: 83 | self.lr_scheduler = get_wsd_scheduler( 84 | optimizer=self.optimizer if optimizer is None else optimizer, 85 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps), 86 | num_training_steps=num_training_steps, 87 | ) 88 | self._created_lr_scheduler = True 89 | print("Using WSD scheduler") 90 | return self.lr_scheduler 91 | 92 | def create_optimizer_and_scheduler(self, num_training_steps: int): 93 | self.create_optimizer() 94 | optimizer = self.optimizer 95 | self.create_scheduler( 96 | num_training_steps=num_training_steps, optimizer=optimizer 97 | ) 98 | print("Scheduler", self.lr_scheduler) 99 | 100 | 101 | class WSDNoShuffleTrainer(NoShuffleSeq2SeqTrainer, WSDTrainer): 102 | def __init__(self, *args, **kwargs): 103 | super().__init__(*args, **kwargs) 104 | -------------------------------------------------------------------------------- /src/cal_data_volumn.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | This script calculates the data volume and ratios for two stages of a project. 4 | 5 | Stage I: Bilingual Adaptation 6 | Stage II: Synthetic Enhancement 7 | 8 | Variables: 9 | stage1_volumn (float): Volume for Stage I. 10 | stage2_volumn (float): Volume for Stage II. 11 | first_stage (list): Ratios for Chinese (cn) and English (en) in Stage I. 12 | second_stage (list): Ratios for Chinese (cn), English (en), and Synthetic (synth) in Stage II. 
13 | cn_ratio (list): Ratios for different Chinese data sources. 14 | en_ratio (list): Ratios for different English data sources. 15 | 16 | Calculations: 17 | stage1_ratio_lst (list): Calculated ratios for Stage I based on `first_stage` and `cn_ratio`/`en_ratio`. 18 | stage2_ratio_lst (list): Calculated ratios for Stage II based on `second_stage` and `cn_ratio`/`en_ratio`. 19 | data (list): Calculated data volumes for each data source. 20 | data_ratio (list): Calculated data ratios as percentages. 21 | 22 | Outputs: 23 | Prints the calculated ratios for Stage I and Stage II. 24 | Prints the sum of the ratios for Stage I and Stage II. 25 | Prints the calculated data volumes for each data source. 26 | Prints the sum of the data volumes. 27 | Prints the calculated data ratios as percentages. 28 | Prints the sum of the data ratios as percentages. 29 | """ 30 | 31 | if __name__ == "__main__": 32 | print("Calculating Data Volume and Ratios for Two Stages") 33 | 34 | stage1_volumn = 92.5 # Stage I: Bilingual Adpatation 35 | stage2_volumn = 7.5 # Stage II: Synthetic Enhancement 36 | 37 | first_stage = [ 38 | 0.2, # cn 39 | 0.8, # en 40 | ] 41 | second_stage = [ 42 | 0.1, # cn 43 | 0.7, # en 44 | 0.2, # synth 45 | ] 46 | 47 | cn_ratio = [ 48 | 0.7, # web-cn 49 | 0.05, # encyclopedia-cn 50 | 0.2, # book-cn 51 | 0.05, # qa_forum-cn 52 | 0.0, 53 | 0.0, 54 | 0.0, 55 | ] 56 | en_ratio = [ 57 | 0.4, # web-en 58 | 0.05, # encyclopedia-en 59 | 0.15, # book-en 60 | 0.05, # qa_forum-en 61 | 0.1, # paper-en 62 | 0.1, # math-en 63 | 0.15, # code-en 64 | ] 65 | 66 | # Calculate the data ratio for the two stages 67 | stage1_ratio_lst = [ 68 | first_stage[0] * cn_ratio[0], 69 | first_stage[1] * en_ratio[0], 70 | first_stage[0] * cn_ratio[1], 71 | first_stage[1] * en_ratio[1], 72 | first_stage[0] * cn_ratio[2], 73 | first_stage[1] * en_ratio[2], 74 | first_stage[0] * cn_ratio[3], 75 | first_stage[1] * en_ratio[3], 76 | first_stage[1] * en_ratio[4], 77 | first_stage[1] * en_ratio[5], 78 | first_stage[1] * en_ratio[6], 79 | ] 80 | 81 | stage2_ratio_lst = [ 82 | second_stage[0] * cn_ratio[0], 83 | second_stage[1] * en_ratio[0], 84 | second_stage[0] * cn_ratio[1], 85 | second_stage[1] * en_ratio[1], 86 | second_stage[0] * cn_ratio[2], 87 | second_stage[1] * en_ratio[2], 88 | second_stage[0] * cn_ratio[3], 89 | second_stage[1] * en_ratio[3], 90 | second_stage[1] * en_ratio[4], 91 | second_stage[1] * en_ratio[5], 92 | second_stage[1] * en_ratio[6], 93 | second_stage[2], 94 | ] 95 | 96 | stage1_ratio_lst = [round(r, 3) for r in stage1_ratio_lst] 97 | stage2_ratio_lst = [round(r, 3) for r in stage2_ratio_lst] 98 | 99 | print("Stage 1 Ratios: ", stage1_ratio_lst) 100 | print("Stage 2 Ratios: ", stage2_ratio_lst) 101 | print("Sum of Stage 1 Ratios: ", sum(stage1_ratio_lst)) 102 | print("Sum of Stage 2 Ratios: ", sum(stage2_ratio_lst)) 103 | 104 | data = [] 105 | 106 | for cn_r, en_r in zip(cn_ratio, en_ratio): 107 | data.append( 108 | stage1_volumn * first_stage[0] * cn_r 109 | + stage1_volumn * first_stage[1] * en_r 110 | + stage2_volumn * second_stage[0] * cn_r 111 | + stage2_volumn * second_stage[1] * en_r 112 | ) 113 | data.append(stage2_volumn * second_stage[2]) 114 | data = [round(d, 5) for d in data] 115 | 116 | print("Data Volume List: ", data) 117 | print("Sum of Data Volumes: ", sum(data)) 118 | 119 | data_ratio = [d / sum(data) * 100 for d in data] 120 | data_ratio = [round(d, 2) for d in data_ratio] 121 | 122 | print("Data Ratios (%): ", data_ratio) 123 | print("Sum of Data Ratios (%): ", sum(data_ratio)) 124 | 
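The index-matched lists above are easy to mis-edit when sources are added or removed. As a quick cross-check, the sketch below re-derives the same two-stage volume arithmetic with explicit loops; the source names and ratios simply mirror the constants in `cal_data_volumn.py` and the keys used in `fetch_data.py`, so it is an illustrative re-derivation, not part of the pipeline itself.

```python
# Illustrative cross-check of the volume arithmetic in cal_data_volumn.py.
# All numbers mirror the constants above; nothing is imported from the repo.

CN_SOURCES = ["web-cn", "encyclopedia-cn", "book-cn", "qa_forum-cn"]
EN_SOURCES = ["web-en", "encyclopedia-en", "book-en", "qa_forum-en",
              "paper-en", "math-en", "code-en"]
CN_RATIO = [0.7, 0.05, 0.2, 0.05]
EN_RATIO = [0.4, 0.05, 0.15, 0.05, 0.1, 0.1, 0.15]

STAGES = [
    # (stage name, volume in billions of tokens, cn share, en share, synthetic share)
    ("Stage I: Bilingual Adaptation", 92.5, 0.2, 0.8, 0.0),
    ("Stage II: Synthetic Enhancement", 7.5, 0.1, 0.7, 0.2),
]

volumes = {}
for _, volume, cn_share, en_share, syn_share in STAGES:
    for src, r in zip(CN_SOURCES, CN_RATIO):
        volumes[src] = volumes.get(src, 0.0) + volume * cn_share * r
    for src, r in zip(EN_SOURCES, EN_RATIO):
        volumes[src] = volumes.get(src, 0.0) + volume * en_share * r
    volumes["synthesis-en"] = volumes.get("synthesis-en", 0.0) + volume * syn_share

total = sum(volumes.values())
for src, v in volumes.items():
    print(f"{src:>16}: {v:7.3f}B tokens ({100 * v / total:5.2f}%)")
print(f"{'total':>16}: {total:7.3f}B tokens")
```

Summing each paired `*-cn`/`*-en` entry reproduces the per-source-type `data` list the script prints (e.g., web: 13.475B + 31.7B = 45.175B), and the grand total comes to the roughly 100B tokens consumed across the two CPT stages.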
-------------------------------------------------------------------------------- /src/split_data.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | split_data.py 4 | 5 | This script processes datasets by splitting large JSONL files into smaller parts based on a maximum token limit. 6 | 7 | Functions: 8 | parse_args(): Parses command line arguments. 9 | main(): Main function to process datasets. 10 | write_to_file(): Writes the current batch of data to a file. 11 | add_data(data): Adds data to the current batch. 12 | load_data(fin): Loads data from a file. 13 | 14 | Command Line Arguments: 15 | --father_datasets: Comma-separated list of dataset paths. 16 | 17 | Usage: 18 | python split_data.py --father_datasets 19 | 20 | Example: 21 | python split_data.py --father_datasets /path/to/dataset1,/path/to/dataset2 22 | """ 23 | 24 | import os 25 | import json 26 | import glob 27 | import pathlib 28 | import numpy as np 29 | from IPython import embed 30 | from tqdm import tqdm 31 | import argparse 32 | 33 | # Define the maximum number of tokens allowed in a single file 34 | MAX_TOKEN = int(0.01 * 1000 * 1000 * 1000) 35 | 36 | 37 | # Parse command line arguments 38 | def parse_args(): 39 | parser = argparse.ArgumentParser(description="Process some datasets.") 40 | parser.add_argument( 41 | "--father_datasets", 42 | type=str, 43 | required=True, 44 | help="Comma-separated list of dataset paths", 45 | ) 46 | return parser.parse_args() 47 | 48 | 49 | # Process datasets 50 | def main(): 51 | args = parse_args() 52 | father_datasets = args.father_datasets.split(",") 53 | 54 | for fd in father_datasets: 55 | datasets = os.listdir(fd) 56 | for dataset_name in tqdm(datasets): 57 | raw_src_folder = os.path.join(fd, dataset_name) 58 | print("Processing {} ...".format(raw_src_folder)) 59 | embed() 60 | folder2file = {} 61 | try: 62 | # Traverse the directory to find JSONL files 63 | for root_dir, _, files in os.walk(raw_src_folder, topdown=False): 64 | for fp in files: 65 | if "splited_part" in fp: 66 | continue 67 | if not fp.endswith(".jsonl"): 68 | continue 69 | if root_dir not in folder2file: 70 | folder2file[root_dir] = [] 71 | folder2file[root_dir].append(os.path.join(root_dir, fp)) 72 | 73 | except FileNotFoundError: 74 | print("Error Dataset: {}".format(dataset_name)) 75 | continue 76 | except NotADirectoryError: 77 | print("Error Dataset: {}".format(dataset_name)) 78 | continue 79 | 80 | if len(folder2file) == 0: 81 | print("Error Dataset: {}".format(dataset_name)) 82 | continue 83 | 84 | for src_folder, src_files in folder2file.items(): 85 | all_data = [] 86 | tokens_num = [] 87 | num_tokens = 0 88 | cur_idx = 0 89 | 90 | # Write data to a file 91 | def write_to_file(): 92 | nonlocal all_data, num_tokens, cur_idx, tokens_num 93 | tgt_path = os.path.join( 94 | src_folder, "splited_part-{}.jsonl".format(cur_idx) 95 | ) 96 | tokens_num_tgt_path = os.path.join( 97 | src_folder, 98 | "splited_part-{}-tokens_{}.jsonl".format(cur_idx, num_tokens), 99 | ) 100 | print(tgt_path) 101 | with open(tgt_path, "w") as fout: 102 | for tmp_data in all_data: 103 | fout.write(json.dumps(tmp_data, ensure_ascii=False) + "\n") 104 | with open(tokens_num_tgt_path, "w") as fout: 105 | for tmp_data in tokens_num: 106 | fout.write(str(tmp_data) + "\n") 107 | num_tokens = 0 108 | cur_idx = cur_idx + 1 109 | all_data = [] 110 | tokens_num = [] 111 | 112 | # Add data to the current batch 113 | def add_data(data): 114 | nonlocal all_data, num_tokens, cur_idx 115 
| all_data.append(data[0]) 116 | tokens_num.append(data[1]) 117 | num_tokens = num_tokens + data[1] 118 | if num_tokens > MAX_TOKEN: 119 | write_to_file() 120 | 121 | # Load data from a file 122 | def load_data(fin): 123 | data = fin.readline() 124 | if not data: 125 | return None 126 | else: 127 | json_data = json.loads(data) 128 | new_data = {"input_ids": json_data["input_ids"]} 129 | return (new_data, len(json_data["input_ids"])) 130 | 131 | src_fin = [] 132 | src_data = [] 133 | for fp in src_files: 134 | if "splited_part" in fp: 135 | continue 136 | fin = open(os.path.join(src_folder, fp)) 137 | src_fin.append(fin) 138 | src_data.append(load_data(fin)) 139 | 140 | # Process the data files 141 | while True: 142 | idx = None 143 | for i in range(len(src_data)): 144 | if src_data[i] is None: 145 | continue 146 | if idx is None: 147 | idx = i 148 | break 149 | 150 | if idx is None: 151 | break 152 | 153 | add_data(src_data[idx]) 154 | src_data[idx] = load_data(src_fin[idx]) 155 | 156 | if len(all_data) > 0: 157 | write_to_file() 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /src/tokenize_text.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | This script tokenizes text data using a specified tokenizer and saves the tokenized data to a target folder. 4 | It supports multiprocessing to speed up the tokenization process. 5 | 6 | Functions: 7 | get_tgt_folder(file_path, model_name): 8 | Get the target folder path based on the file path and model name. 9 | 10 | tokenize_text(dataset, tgt_folder, idx, text_key, is_first): 11 | Tokenize text data and save it to the target folder. 12 | 13 | start_mp(dataset, is_first): 14 | Start multiprocessing for tokenizing text data. 15 | 16 | Main: 17 | The script takes several command-line arguments: 18 | --tokenizer_path: Path to the tokenizer. 19 | --model_name: Name of the model. 20 | --data_path: Path to the data files. 21 | --num_files: Number of files to process. 22 | --text_key: Key to access text data in the dataset. 23 | --num_worker: Number of worker processes for multiprocessing. 24 | --skip_exist: Whether to skip existing processed files. 25 | 26 | The script processes each file in the specified data path, tokenizes the text data, and saves the tokenized data to the target folder. 
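    Example (illustrative; values mirror run_tokenize_split.sh and are placeholders):
        python tokenize_text.py \
            --tokenizer_path meta-llama/Meta-Llama-3-8B \
            --model_name llama-3 \
            --data_path /data \
            --num_files 500000 \
            --text_key text \
            --num_worker 64 \
            --skip_exist True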
27 | """ 28 | 29 | import argparse 30 | import os 31 | import json 32 | import random 33 | import pathlib 34 | import numpy as np 35 | import multiprocessing as mp 36 | from tqdm import tqdm, trange 37 | from transformers import AutoTokenizer 38 | 39 | random.seed(45) 40 | MAX_DATA = int(1e6) 41 | 42 | 43 | # Get the target folder path based on the file path and model name 44 | def get_tgt_folder(file_path, model_name): 45 | file_path = file_path.replace("/data", f"/{model_name}_data_ids") 46 | tgt_folder = file_path[: file_path.rfind(".")] 47 | tgt_folder = os.path.join(tgt_folder, "wo_ppl") 48 | if os.path.exists(tgt_folder) == True: 49 | is_exists = True 50 | else: 51 | is_exists = False 52 | pathlib.Path(tgt_folder).mkdir(parents=True, exist_ok=True) 53 | return tgt_folder, is_exists 54 | 55 | 56 | # Tokenize text data and save it to the target folder 57 | def tokenize_text(dataset, tgt_folder, idx, text_key, is_first): 58 | tgt_path = os.path.join(tgt_folder, "part-{}.jsonl".format(idx)) 59 | if is_first == False: 60 | write_mode = "a" 61 | else: 62 | write_mode = "w" 63 | fout = open(tgt_path, write_mode) 64 | for data in tqdm(dataset, desc="Process {}".format(idx)): 65 | input_ids = tokenizer(data[text_key], add_special_tokens=False)["input_ids"] 66 | new_data = {"input_ids": input_ids} 67 | fout.write(json.dumps(new_data, ensure_ascii=False) + "\n") 68 | fout.close() 69 | 70 | 71 | # Start multiprocessing for tokenizing text data 72 | def start_mp(dataset, is_first): 73 | if len(dataset) == 0: 74 | return 75 | if isinstance(dataset, list) == False: 76 | return 77 | try: 78 | assert args.text_key in dataset[0] 79 | text_key = args.text_key 80 | except AssertionError: 81 | print("Available Keys:", dataset[0].keys()) 82 | raise Exception("Unknown Key!") 83 | 84 | random.shuffle(dataset) 85 | part_num = args.num_worker 86 | slice_idx = np.linspace(0, len(dataset), part_num + 1).astype("int") 87 | p = mp.Pool(part_num) 88 | for start_id in range(part_num): 89 | start, end = slice_idx[start_id], slice_idx[start_id + 1] 90 | new_lines = dataset[start:end] 91 | p.apply_async( 92 | tokenize_text, args=(new_lines, tgt_folder, start_id, text_key, is_first) 93 | ) 94 | p.close() 95 | p.join() 96 | print("All of the child processes over!") 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument("--tokenizer_path", type=str) 102 | parser.add_argument("--model_name", type=str) 103 | parser.add_argument("--data_path", type=str) 104 | parser.add_argument("--num_files", type=int) 105 | parser.add_argument("--text_key", type=str) 106 | parser.add_argument("--num_worker", type=int) 107 | parser.add_argument("--skip_exist", type=bool, default=False) 108 | args = parser.parse_args() 109 | 110 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 111 | 112 | for root, _, files in os.walk(args.data_path, topdown=False): 113 | step = 0 114 | random.shuffle(files) 115 | for fp in tqdm(files): 116 | file_path = os.path.join(root, fp) 117 | tgt_folder, is_exists = get_tgt_folder(file_path, args.model_name) 118 | if is_exists == True and args.skip_exist == True: 119 | continue 120 | 121 | print("Process {}".format(file_path)) 122 | print("Target Folder: {}".format(tgt_folder)) 123 | 124 | fin = open(file_path, "r") 125 | is_jsonl = False 126 | if file_path.endswith(".json") == True: 127 | try: 128 | dataset = json.load(fin) 129 | start_mp(dataset, True) 130 | step = step + 1 131 | if step >= args.num_files: 132 | break 133 | continue 134 | except 
json.decoder.JSONDecodeError: 135 | is_jsonl = True 136 | fin.close() 137 | fin = open(file_path, "r") 138 | 139 | if file_path.endswith(".jsonl") == True or is_jsonl == True: 140 | is_finish = False 141 | is_first = True 142 | while True: 143 | dataset = [] 144 | for i in trange(MAX_DATA, desc="Reading Data"): 145 | tmp_data = fin.readline() 146 | if not tmp_data: 147 | is_finish = True 148 | break 149 | try: 150 | tmp_data = json.loads(tmp_data) 151 | dataset.append(tmp_data) 152 | except json.decoder.JSONDecodeError: 153 | continue 154 | 155 | start_mp(dataset, is_first) 156 | is_first = False 157 | if is_finish == True: 158 | break 159 | else: 160 | continue 161 | 162 | fin.close() 163 | step = step + 1 164 | if step >= args.num_files: 165 | break 166 | -------------------------------------------------------------------------------- /src/save_hf_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | This script processes and saves Hugging Face datasets based on provided configurations. 4 | 5 | Functions: 6 | main(timestamp_lst, tokenizer_path, model_max_length, num_workers, min_text_length, root_dir, show_case): 7 | Main function to process and save datasets. 8 | 9 | parse_args(): 10 | Parses command-line arguments. 11 | 12 | Arguments: 13 | --timestamp_lst (str): Comma-separated list of timestamps. 14 | --tokenizer_path (str): Path to the tokenizer. Default is "meta-llama/Meta-Llama-3-8B". 15 | --model_max_length (int): Maximum length of the model. Default is 8192. 16 | --num_workers (int): Number of workers. Default is 32. 17 | --min_text_length (int): Minimum text length to filter. Default is 20. 18 | --root_dir (str): Root directory. 19 | --show_case (bool): Whether to show the first case. 20 | 21 | Environment Variables: 22 | TMPDIR: Temporary directory path. 23 | HF_DATASETS_CACHE: Hugging Face datasets cache directory path. 24 | HF_HOME: Hugging Face home directory path. 25 | 26 | Usage: 27 | Run the script with the required arguments to process and save datasets. 
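    Example (illustrative; values mirror run_save_hf_dataset.sh, and the timestamp is a
    placeholder produced by an earlier fetch_data.py run):
        python save_hf_dataset.py \
            --timestamp_lst 20240801_120000 \
            --tokenizer_path meta-llama/Meta-Llama-3-8B \
            --model_max_length 8192 \
            --num_workers 32 \
            --min_text_length 10 \
            --root_dir /data \
            --show_case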
28 | """ 29 | 30 | import os 31 | 32 | # Set the environment variable to use the cache 33 | SAVE_PATH = "~/.cache" 34 | os.environ["TMPDIR"] = os.path.join(SAVE_PATH, "tmp") 35 | os.environ["HF_DATASETS_CACHE"] = os.path.join(SAVE_PATH, "hf_datasets_cache") 36 | os.environ["HF_HOME"] = os.path.join(SAVE_PATH, "hf_home") 37 | os.makedirs(os.environ["TMPDIR"], exist_ok=True) 38 | os.makedirs(os.environ["HF_DATASETS_CACHE"], exist_ok=True) 39 | os.makedirs(os.environ["HF_HOME"], exist_ok=True) 40 | import json 41 | import datasets 42 | from transformers import AutoTokenizer 43 | import argparse 44 | 45 | 46 | def main( 47 | timestamp_lst, 48 | tokenizer_path, 49 | model_max_length, 50 | num_workers, 51 | min_text_length, 52 | root_dir, 53 | show_case, 54 | ): 55 | # timestamp 56 | timestamp_lst = timestamp_lst.split(",") 57 | 58 | # tokenizer 59 | print(f"Loading tokenizer from {tokenizer_path}") 60 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 61 | assert tokenizer.eos_token_id is not None, "Tokenizer must have an EOS token" 62 | 63 | # model name 64 | model_name = tokenizer_path.split("/")[-1] 65 | 66 | # save train circle info 67 | train_circle_dir = os.path.join( 68 | root_dir, f"train_info/{model_name}/train_circle_info" 69 | ) 70 | 71 | def group_texts(examples): 72 | processed_data = [] 73 | tmp = [] 74 | for example in examples["input_ids"]: 75 | if len(example) < min_text_length: 76 | continue 77 | example.append(tokenizer.eos_token_id) 78 | tmp.extend(example) 79 | while len(tmp) >= model_max_length: 80 | processed_data.append(tmp[:model_max_length]) 81 | tmp = tmp[model_max_length:] 82 | return {"input_ids": processed_data} 83 | 84 | for timestamp in timestamp_lst: 85 | 86 | train_circle_info_file_path = os.path.join( 87 | train_circle_dir, f"{timestamp}.json" 88 | ) 89 | 90 | print(f"Reading train circle information from {train_circle_info_file_path}") 91 | with open(train_circle_info_file_path, "r", encoding="utf-8") as f: 92 | data_folder2token_ids_paths = json.load(f)["Token ID Files Information"] 93 | 94 | for data_folder, info in data_folder2token_ids_paths.items(): 95 | 96 | print(f"Processing {data_folder}") 97 | 98 | info_lst = [ 99 | ( 100 | i["group index"], 101 | i["huggingface datasets directory"], 102 | i["token id file paths"], 103 | i["token num(B)"], 104 | ) 105 | for i in info 106 | ] 107 | 108 | for idx, hf_datasets_dir, data_files, token_num in info_lst: 109 | 110 | print( 111 | f"Saving {data_folder} | group {idx} | {token_num}B tokens | {len(data_files)} files | to {hf_datasets_dir}" 112 | ) 113 | os.makedirs(hf_datasets_dir, exist_ok=True) 114 | 115 | # load dataset 116 | raw_train_dataset = datasets.load_dataset( 117 | "json", 118 | data_files=data_files, 119 | split="train", 120 | ) 121 | 122 | print(raw_train_dataset) 123 | print("=========" * 9) 124 | if show_case: 125 | print(raw_train_dataset[0]["input_ids"]) 126 | print("=========" * 9) 127 | print(len(raw_train_dataset[0]["input_ids"])) 128 | print("=========" * 9) 129 | 130 | # group texts 131 | print( 132 | f"Grouping texts with sequence length {model_max_length}. 
Filter texts shorter than {min_text_length}" 133 | ) 134 | if len(raw_train_dataset) < 1000 * num_workers: 135 | total_data = [sample["input_ids"] for sample in raw_train_dataset] 136 | total_data = group_texts({"input_ids": total_data}) 137 | if len(total_data["input_ids"]) == 0: 138 | os.removedirs(hf_datasets_dir) 139 | continue 140 | train_dataset = datasets.Dataset.from_dict(total_data) 141 | train_dataset = train_dataset.cast_column( 142 | "input_ids", datasets.features.Sequence(datasets.Value("int64")) 143 | ) 144 | else: 145 | train_dataset = raw_train_dataset.map( 146 | group_texts, 147 | batched=True, 148 | num_proc=num_workers, 149 | desc=f"Grouping texts with sequence length {model_max_length}", 150 | ) 151 | print(train_dataset) 152 | print("=========" * 9) 153 | if show_case: 154 | print(train_dataset[0]["input_ids"]) 155 | print("=========" * 9) 156 | print(len(train_dataset[0]["input_ids"])) 157 | print("=========" * 9) 158 | print(tokenizer.decode(train_dataset[0]["input_ids"])) 159 | print("=========" * 9) 160 | 161 | train_dataset.save_to_disk(hf_datasets_dir) 162 | print(f"Dataset saved to {hf_datasets_dir}") 163 | 164 | del raw_train_dataset 165 | del train_dataset 166 | 167 | print("Finished!") 168 | 169 | 170 | def parse_args(): 171 | parser = argparse.ArgumentParser(description="Process and save HF datasets.") 172 | parser.add_argument( 173 | "--timestamp_lst", 174 | type=str, 175 | required=True, 176 | help="Comma-separated list of timestamps.", 177 | ) 178 | parser.add_argument( 179 | "--tokenizer_path", 180 | type=str, 181 | required=True, 182 | default="meta-llama/Meta-Llama-3-8B", 183 | help="Path to the tokenizer", 184 | ) 185 | parser.add_argument( 186 | "--model_max_length", 187 | type=int, 188 | required=True, 189 | default=8192, 190 | help="Maximum length of the model.", 191 | ) 192 | parser.add_argument( 193 | "--num_workers", 194 | type=int, 195 | required=True, 196 | default=32, 197 | help="Number of workers.", 198 | ) 199 | parser.add_argument( 200 | "--min_text_length", 201 | type=int, 202 | required=True, 203 | default=20, 204 | help="Minimum text length to filter.", 205 | ) 206 | parser.add_argument( 207 | "--root_dir", 208 | type=str, 209 | required=True, 210 | help="Root directory.", 211 | ) 212 | parser.add_argument( 213 | "--show_case", 214 | action="store_true", 215 | help="Whether to show the first case.", 216 | ) 217 | return parser.parse_args() 218 | 219 | 220 | if __name__ == "__main__": 221 | args = parse_args() 222 | main( 223 | args.timestamp_lst, 224 | args.tokenizer_path, 225 | args.model_max_length, 226 | args.num_workers, 227 | args.min_text_length, 228 | args.root_dir, 229 | args.show_case, 230 | ) 231 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | 10 | 11 | 18 | 19 | 24 | 25 |

26 | 📄 [报告](https://arxiv.org/abs/2407.18743)   |   🤗 [Hugging Face 上的模型](https://huggingface.co/survivi/Llama-3-SynE)   |   📊 [继续预训练数据集](https://huggingface.co/datasets/survivi/Llama-3-SynE-Dataset) 27 |

28 | 29 |

30 | 🔍 [English](README.md)  |  [简体中文](README_zh.md) 31 |

32 | 33 | 46 | 47 | 60 | 61 | --- 62 | 63 | ## 更新 64 | 65 | - 🌟🌟 `2024/12/17`: 我们发布了用于继续预训练和数据准备的 [代码](https://github.com/RUC-GSAI/Llama-3-SynE/blob/main/src),代码中包含详尽的文档注释。 66 | - ✨✨ `2024/08/12`: 我们发布了 [继续预训练数据集](https://huggingface.co/datasets/survivi/Llama-3-SynE-Dataset)。 67 | - ✨✨ `2024/08/10`: 我们发布了 [Llama-3-SynE 模型](https://huggingface.co/survivi/Llama-3-SynE)。 68 | - ✨ `2024/07/26`: 我们发布了 Llama-3-SynE 的 [技术报告](https://arxiv.org/abs/2407.18743),欢迎查阅! 69 | 70 |
71 | 72 |

73 | 74 |

75 | 76 | ## 模型介绍 77 | 78 | **Llama-3-SynE**(**Syn**thetic data **E**nhanced Llama-3)是 [Llama-3(8B)](https://github.com/meta-llama/llama3)的增强版,通过继续预训练(continual pre-training,CPT)来提升其**中文语言能力和科学推理能力**。通过精心设计的数据混合和课程策略,Llama-3-SynE 成功地在保持原始模型性能的同时增强了新能力。这个增强过程包括利用现有数据集并合成专门为目标任务设计的高质量数据集。 79 | 80 | Llama-3-SynE 的主要特点包括: 81 | 82 | - **增强的中文语言能力**:通过基于主题的数据混合和基于困惑度的数据课程实现。 83 | - **改进的科学推理能力**:利用合成数据集来增强多学科的科学知识。 84 | - **高效的继续预训练**:只消耗约 1000 亿个 token,成本效益高。 85 | 86 | ## 模型列表 87 | 88 | | 模型 | 类型 | 序列长度 | 下载 | 89 | | :----------- | :--- | :------- | :------------------------------------------------------------ | 90 | | Llama-3-SynE | Base | 8K | [🤗 Huggingface](https://huggingface.co/survivi/Llama-3-SynE) | 91 | 92 | ## 基准测试 93 | 94 | 我们将所有评估基准分为两组。第一组是 _主要基准_,旨在评估大语言模型的综合能力。值得注意的是我们在这一组基准中包括了常用的数学和代码基准,因为使用这些基准评估各种通用大语言模型是标准做法。 95 | 96 | 第二组是 _科学基准_,涵盖了多学科的科学知识。 97 | 98 | 我们报告了在 GSM8K、ASDiv 和 MAWPS 上的 8-shot 性能,C-Eval、CMMLU、MMLU、MATH、GaoKao、SciQ、SciEval、SAT-Math 和 AQUA-RAT 上的 5-shot 推理性能,MBPP 上的 3-shot 性能。 99 | 对于 HumanEval 和 ARC,我们报告了 0-shot 性能。最佳和次佳结果分别以 **粗体** 和 _斜体_ 标出。 100 | 101 | ### 主要基准 102 | 103 | | **模型** | **MMLU** | **C-Eval** | **CMMLU** | **MATH** | **GSM8K** | **ASDiv** | **MAWPS** | **SAT-Math** | **HumanEval** | **MBPP** | 104 | | :---------------------- | :-------- | :--------- | :-------- | :-------- | :-------- | :-------- | :-------- | :----------- | :------------ | :-------- | 105 | | Llama-3-8B | **66.60** | 49.43 | 51.03 | 16.20 | 54.40 | 72.10 | 89.30 | 38.64 | _36.59_ | **47.00** | 106 | | DCLM-7B | 64.01 | 41.24 | 40.89 | 14.10 | 39.20 | 67.10 | 83.40 | _41.36_ | 21.95 | 32.60 | 107 | | Mistral-7B-v0.3 | 63.54 | 42.74 | 43.72 | 12.30 | 40.50 | 67.50 | 87.50 | 40.45 | 25.61 | 36.00 | 108 | | Llama-3-Chinese-8B | 64.10 | _50.14_ | _51.20_ | 3.60 | 0.80 | 1.90 | 0.60 | 36.82 | 9.76 | 14.80 | 109 | | MAmmoTH2-8B | 64.89 | 46.56 | 45.90 | **34.10** | **61.70** | **82.80** | _91.50_ | _41.36_ | 17.68 | 38.80 | 110 | | Galactica-6.7B | 37.13 | 26.72 | 25.53 | 5.30 | 9.60 | 40.90 | 51.70 | 23.18 | 7.31 | 2.00 | 111 | | **Llama-3-SynE (ours)** | _65.19_ | **58.24** | **57.34** | _28.20_ | _60.80_ | _81.00_ | **94.10** | **43.64** | **42.07** | _45.60_ | 112 | 113 | > 在 **中文评估基准**(如 C-Eval 和 CMMLU)上,Llama-3-SynE 显著优于基础模型 Llama-3(8B),表明我们的方法在提升中文语言能力方面非常有效。 114 | 115 | > 在 **英文评估基准**(如 MMLU、MATH 和代码评估基准)上,Llama-3-SynE 展现出与基础模型相当或更好的性能,表明我们的方法在继续预训练过程中有效解决了灾难性遗忘问题。 116 | 117 | ### 科学基准 118 | 119 | “PHY”、“CHE” 和 “BIO” 分别表示对应基准的物理、化学和生物子任务。 120 | 121 | | **模型** | **SciEval PHY** | **SciEval CHE** | **SciEval BIO** | **SciEval Avg.** | **SciQ** | **GaoKao MathQA** | **GaoKao CHE** | **GaoKao BIO** | **ARC Easy** | **ARC Challenge** | **ARC Avg.** | **AQUA-RAT** | 122 | | :---------------------- | :-------------- | :-------------- | :-------------- | :--------------- | :-------- | :---------------- | :------------- | :------------- | :----------- | :---------------- | :----------- | :----------- | 123 | | Llama-3-8B | 46.95 | 63.45 | 74.53 | 65.47 | 90.90 | 27.92 | 32.85 | 43.81 | 91.37 | 77.73 | 84.51 | _27.95_ | 124 | | DCLM-7B | **56.71** | 64.39 | 72.03 | 66.25 | **92.50** | 29.06 | 31.40 | 37.14 | 89.52 | 76.37 | 82.94 | 20.08 | 125 | | Mistral-7B-v0.3 | 48.17 | 59.41 | 68.89 | 61.51 | 89.40 | 30.48 | 30.92 | 41.43 | 87.33 | 74.74 | 81.04 | 23.23 | 126 | | Llama-3-Chinese-8B | 48.17 | 67.34 | 73.90 | _67.34_ | 89.20 | 27.64 | 30.43 | 38.57 | 88.22 | 70.48 | 79.35 | 27.56 | 127 | | MAmmoTH2-8B | 49.39 | **69.36** | _76.83_ | **69.60** | 90.20 | **32.19** | _36.23_ | _49.05_ | 
**92.85** | **84.30** | **88.57** | 27.17 | 128 | | Galactica-6.7B | 34.76 | 43.39 | 54.07 | 46.27 | 71.50 | 23.65 | 27.05 | 24.76 | 65.91 | 46.76 | 56.33 | 20.87 | 129 | | **Llama-3-SynE (ours)** | _53.66_ | _67.81_ | **77.45** | **69.60** | _91.20_ | _31.05_ | **51.21** | **69.52** | _91.58_ | _80.97_ | _86.28_ | **28.74** | 130 | 131 | > 在 **科学评估基准**(如 SciEval、GaoKao 和 ARC)上,Llama-3-SynE 显著优于基础模型,特别是在中文科学基准上表现出显著提升(例如,高考生物子测试中提升了 25.71%)。 132 | 133 | ## 快速开始 134 | 135 | 基于 transformers 进行推理: 136 | 137 | ```python 138 | from transformers import AutoTokenizer, AutoModelForCausalLM 139 | import torch 140 | 141 | model_path = "survivi/Llama-3-SynE" 142 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 143 | model = AutoModelForCausalLM.from_pretrained( 144 | model_path, torch_dtype=torch.bfloat16, trust_remote_code=True 145 | ) 146 | model.to("cuda:0") 147 | model.eval() 148 | prompt = "Hello world!" 149 | inputs = tokenizer(prompt, return_tensors="pt") 150 | inputs = inputs.to("cuda") 151 | pred = model.generate( 152 | **inputs, 153 | max_new_tokens=2048, 154 | repetition_penalty=1.05, 155 | temperature=0.5, 156 | top_k=5, 157 | top_p=0.85, 158 | do_sample=True 159 | ) 160 | pred = pred[0][len(inputs.input_ids[0]) :] 161 | output = tokenizer.decode(pred, skip_special_tokens=True) 162 | print(output) 163 | ``` 164 | 165 | 基于 vLLM 进行推理: 166 | 167 | ```python 168 | from transformers import AutoTokenizer 169 | from vllm import LLM, SamplingParams 170 | 171 | model_path = "survivi/Llama-3-SynE" 172 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 173 | sampling_params = SamplingParams( 174 | max_tokens=2048, 175 | repetition_penalty=1.05, 176 | temperature=0.5, 177 | top_k=5, 178 | top_p=0.85, 179 | ) 180 | llm = LLM( 181 | model=model_path, 182 | tensor_parallel_size=1, 183 | trust_remote_code=True, 184 | ) 185 | prompt = "Hello world!" 186 | output = llm.generate(prompt, sampling_params) 187 | output = output[0].outputs[0].text 188 | print(output) 189 | ``` 190 | 191 | ## 许可证 192 | 193 | 本项目基于 Meta 的 Llama-3 模型开发,Llama-3-SynE 模型权重的使用必须遵循 Llama-3 [许可协议](https://github.com/meta-llama/llama3/blob/main/LICENSE)。此开源代码库中的代码遵循 [Apache 2.0](LICENSE) 许可证。 194 | 195 | ## 引用 196 | 197 | 如果您觉得我们的工作对您有帮助,请考虑引用以下论文: 198 | 199 | ``` 200 | @article{jie2024llama3syne, 201 | title={Towards Effective and Efficient Continual Pre-training of Large Language Models}, 202 | author={Chen, Jie and Chen, Zhipeng and Wang, Jiapeng and Zhou, Kun and Zhu, Yutao and Jiang, Jinhao and Min, Yingqian and Zhao, Wayne Xin and Dou, Zhicheng and Mao, Jiaxin and others}, 203 | journal={arXiv preprint arXiv:2407.18743}, 204 | year={2024} 205 | } 206 | ``` 207 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/fetch_data.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | fetch_data.py 4 | 5 | This script is used to fetch and organize token data for training purposes. It reads token ID files from specified directories, selects a specified number of tokens based on given ratios, and organizes the selected token data into groups for further processing. 6 | 7 | Functions: 8 | main(total_token_num, cn_ratio, en_ratio, syn_ratio, root_dir, tokenizer_path) 9 | Main function to fetch and organize token data. 10 | 11 | Arguments: 12 | --total_token_num (int): Total number of tokens to be selected. 13 | --cn_ratio (float): Ratio of CN tokens. 14 | --en_ratio (float): Ratio of EN tokens. 15 | --syn_ratio (float): Ratio of SYN tokens. 16 | --root_dir (str): Root directory for token IDs. 17 | --tokenizer_path (str): Path to the tokenizer. 18 | 19 | Usage: 20 | python fetch_data.py --total_token_num 40 --cn_ratio 0.1 --en_ratio 0.7 --syn_ratio 0.2 --root_dir /path/to/root --tokenizer_path meta-llama/Meta-Llama-3-8B 21 | """ 22 | 23 | import os 24 | import argparse 25 | import json 26 | import glob 27 | from datetime import datetime 28 | 29 | # Set the environment variable to use the cache 30 | SAVE_PATH = "~/.cache" 31 | os.environ["TMPDIR"] = os.path.join(SAVE_PATH, "tmp") 32 | os.environ["HF_DATASETS_CACHE"] = os.path.join(SAVE_PATH, "hf_datasets_cache") 33 | os.environ["HF_HOME"] = os.path.join(SAVE_PATH, "hf_home") 34 | os.makedirs(os.environ["TMPDIR"], exist_ok=True) 35 | os.makedirs(os.environ["HF_DATASETS_CACHE"], exist_ok=True) 36 | os.makedirs(os.environ["HF_HOME"], exist_ok=True) 37 | 38 | 39 | def main( 40 | total_token_num, 41 | cn_ratio, 42 | en_ratio, 43 | syn_ratio, 44 | root_dir, 45 | tokenizer_path, 46 | ): 47 | print(f"Start | {datetime.now().strftime('%Y%m%d_%H%M%S')}") 48 | 49 | cn_token_num = total_token_num * cn_ratio # CN 50 | en_token_num = total_token_num * en_ratio # EN 51 | syn_token_num = total_token_num * syn_ratio # SYNTH 52 | 53 | cn_ratio_lst = [ 54 | 0.7, # web-cn 55 | 0.05, # encyclopedia-cn 56 | 0.2, # book-cn 57 | 0.05, # qa_forum-cn 58 | ] 59 | en_ratio_lst = [ 60 | 0.4, # web-en 61 | 0.05, # encyclopedia-en 62 | 0.15, # book-en 63 | 0.05, # qa_forum-en 64 | 0.1, # paper-en 65 | 0.1, # math-en 66 | 0.15, # code-en 67 | ] 68 | 69 | data_select_info = { 70 | ## CN 71 | "web-cn": cn_token_num * cn_ratio_lst[0], 72 | "encyclopedia-cn": cn_token_num * cn_ratio_lst[1], 73 | "book-cn": cn_token_num * cn_ratio_lst[2], 74 | "qa_forum-cn": cn_token_num * cn_ratio_lst[3], 75 | ## EN 76 | "web-en": en_token_num * en_ratio_lst[0], 77 | "encyclopedia-en": en_token_num * en_ratio_lst[1], 78 | "book-en": en_token_num * en_ratio_lst[2], 79 | "qa_forum-en": en_token_num * en_ratio_lst[3], 80 | "paper-en": en_token_num * en_ratio_lst[4], 81 | "math-en": en_token_num * en_ratio_lst[5], 82 | "code-en": en_token_num * en_ratio_lst[6], 83 | ## SYNTH 84 | "synthesis-en": syn_token_num, 85 | } 86 | 
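    # Worked example (assuming the sample values from run_fetch_data.sh:
    # --total_token_num 40, --cn_ratio 0.1, --en_ratio 0.7, --syn_ratio 0.2):
    #   cn_token_num = 4B  -> web-cn 2.8B, encyclopedia-cn 0.2B, book-cn 0.8B, qa_forum-cn 0.2B
    #   en_token_num = 28B -> web-en 11.2B, encyclopedia-en 1.4B, book-en 4.2B, qa_forum-en 1.4B,
    #                         paper-en 2.8B, math-en 2.8B, code-en 4.2B
    #   syn_token_num = 8B -> synthesis-en 8B
    #   The per-source targets sum back to the 40B requested by --total_token_num.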
print(data_select_info) 87 | 88 | # Billion 89 | B = 10**9 90 | 91 | # Model name 92 | model_name = tokenizer_path.split("/")[-1] 93 | 94 | # Token ids root directory 95 | token_ids_dir = os.path.join(root_dir, f"{model_name}_data_ids") 96 | 97 | # Time stamp 98 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 99 | 100 | # Save trained data info 101 | trained_info_dir = os.path.join( 102 | root_dir, f"train_info/{model_name}/trained_data_info/{timestamp}" 103 | ) 104 | os.makedirs(trained_info_dir, exist_ok=True) 105 | 106 | # Save train circle info 107 | train_circle_dir = os.path.join( 108 | root_dir, f"train_info/{model_name}/train_circle_info" 109 | ) 110 | os.makedirs(train_circle_dir, exist_ok=True) 111 | 112 | # Hf datasets save directory 113 | hf_datasets_save_dir = os.path.join(root_dir, "hf_dataset", model_name, timestamp) 114 | 115 | # Total tokens selected 116 | total_tokens = 0 117 | 118 | # Number of selected tokens for each data folder 119 | data_folder2num_tokens = dict() 120 | 121 | # Data folder name to selected token id path group 122 | data_folder2token_ids_paths = dict() 123 | 124 | # Hf dataset save step 125 | hf_dataset_save_step = 3 # 3B 126 | 127 | # Get trained files 128 | trained_files = [] 129 | for file in glob.glob(train_circle_dir + "/*.json"): 130 | print(f"Reading trained data info from {file}") 131 | with open(file, "r") as f: 132 | trained_file_info = json.load(f) 133 | path_lst = [ 134 | paths 135 | for info in trained_file_info["Token ID Files Information"].values() 136 | for paths in [ii["token id file paths"] for ii in info] 137 | ] 138 | trained_files.extend([iii for jjj in path_lst for iii in jjj]) 139 | 140 | trained_files = [i for i in trained_files] 141 | 142 | for data_folder_name, target_token_num in data_select_info.items(): 143 | 144 | if target_token_num == 0: 145 | continue 146 | 147 | print(f"Processing {data_folder_name}") 148 | 149 | # Read trained data info for this data folder 150 | save_path_data_info = os.path.join(trained_info_dir, data_folder_name) 151 | os.makedirs(save_path_data_info, exist_ok=True) 152 | 153 | # Get untrained token id files for this data folder 154 | dirpath_filename_lst = [ 155 | (dirpath, filename) 156 | for dirpath, dirnames, filenames in os.walk( 157 | os.path.join(token_ids_dir, data_folder_name) 158 | ) 159 | for filename in filenames 160 | if filename.endswith(".jsonl") 161 | and "splited_part" in filename 162 | and "tokens_" not in filename 163 | and os.path.join(dirpath, filename) not in trained_files 164 | ] 165 | 166 | if len(dirpath_filename_lst) == 0: 167 | continue 168 | 169 | # Sort by file name to ensure the order of ppl, etc 170 | dirpath_filename_lst.sort(key=lambda x: x[1]) 171 | 172 | # Train data info for this data folder 173 | token_infos = [] 174 | token_num = 0 175 | 176 | # Seleceted token id file paths for this data folder 177 | paths_groups = [] 178 | 179 | # Hf dataset piece info for this data folder 180 | file_paths_piece = [] 181 | token_num_piece = 0 182 | 183 | # Read file names, select token id files and get token num 184 | for dir_path, file_name in dirpath_filename_lst: 185 | file_path = os.path.join(dir_path, file_name) 186 | 187 | # Get token num 188 | # E.g. 
···/splited_part-1.jsonl => ···/splited_part-1-tokens_100073149.jsonl 189 | glob_files = glob.glob( 190 | os.path.join(dir_path, file_name.split(".")[0] + "-tokens_*.jsonl") 191 | ) 192 | if len(glob_files) == 0: 193 | continue 194 | assert len(glob_files) == 1 195 | file_token_num = int(glob_files[0].split("_")[-1].split(".")[0]) / B 196 | print(f"{file_path} has {file_token_num} tokens") 197 | 198 | # Update train data info 199 | token_infos.append( 200 | { 201 | "token_ids_path": file_path, 202 | "token_num(B)": file_token_num, 203 | } 204 | ) 205 | token_num += file_token_num 206 | print( 207 | f"Total token updated: {token_num}B/{target_token_num}B for {data_folder_name}" 208 | ) 209 | 210 | # Update hf dataset piece info 211 | file_paths_piece.append(file_path) 212 | token_num_piece += file_token_num 213 | 214 | # Update hf dataset piece info if token num exceeds save step 215 | if token_num_piece > hf_dataset_save_step: 216 | paths_groups.append([file_paths_piece, token_num_piece]) 217 | # Unset hf dataset piece info 218 | file_paths_piece = [] 219 | token_num_piece = 0 220 | 221 | # Check if target token num is reached 222 | if token_num > target_token_num: 223 | break 224 | 225 | # Save last hf dataset piece 226 | if len(file_paths_piece) > 0: 227 | paths_groups.append([file_paths_piece, token_num_piece]) 228 | 229 | # Save train data info 230 | with open( 231 | os.path.join(save_path_data_info, f"{timestamp}.jsonl"), 232 | "w", 233 | encoding="utf-8", 234 | ) as f: 235 | for i in token_infos: 236 | f.write(json.dumps(i, ensure_ascii=False) + "\n") 237 | 238 | # Update global train data info 239 | total_tokens += token_num 240 | data_folder2num_tokens[data_folder_name] = token_num 241 | data_folder2token_ids_paths[data_folder_name] = [ 242 | { 243 | "group index": idx, 244 | "huggingface datasets directory": os.path.join( 245 | hf_datasets_save_dir, 246 | f"{data_folder_name.replace('/', '')}_{idx}", 247 | ), 248 | "token num(B)": info[1], 249 | "token id file paths": info[0], 250 | } 251 | for idx, info in enumerate(paths_groups) 252 | ] 253 | 254 | # Print train data info for this data folder 255 | print( 256 | f"Token id files selected for {data_folder_name}:\n{data_folder2token_ids_paths[data_folder_name]}" 257 | ) 258 | print( 259 | f"Select {len([i for j in paths_groups for i in j])} token id files for {data_folder_name}" 260 | ) 261 | print( 262 | f"Select token num: {token_num}B/{target_token_num}B for {data_folder_name}" 263 | ) 264 | 265 | # Show train circle info 266 | train_circle_info = { 267 | "Total Tokens(B)": total_tokens, 268 | "Total Huggingface Datasets Directory": hf_datasets_save_dir, 269 | "Number of Selected Tokens(B)": data_folder2num_tokens, 270 | "Manual Selected Tokens(B)": data_select_info, 271 | "Token ID Files Information": data_folder2token_ids_paths, 272 | } 273 | print(f"Train circle info:\n{train_circle_info}") 274 | 275 | # Save train circle info 276 | train_circle_info_path = os.path.join(train_circle_dir, f"{timestamp}.json") 277 | print(f"Save train circle info to {train_circle_info_path}") 278 | with open(train_circle_info_path, "w", encoding="utf-8") as f: 279 | json.dump(train_circle_info, f, ensure_ascii=False, indent=4) 280 | 281 | print(f"Finished | {timestamp}") 282 | 283 | 284 | if __name__ == "__main__": 285 | parser = argparse.ArgumentParser(description="Fetch data script") 286 | parser.add_argument( 287 | "--total_token_num", 288 | type=int, 289 | required=True, 290 | default=40, 291 | help="Total number of tokens", 292 | ) 293 | 
parser.add_argument( 294 | "--cn_ratio", type=float, required=True, default=0.1, help="Ratio of CN tokens" 295 | ) 296 | parser.add_argument( 297 | "--en_ratio", type=float, required=True, default=0.7, help="Ratio of EN tokens" 298 | ) 299 | parser.add_argument( 300 | "--syn_ratio", 301 | type=float, 302 | required=True, 303 | default=0.2, 304 | help="Ratio of SYN tokens", 305 | ) 306 | parser.add_argument( 307 | "--root_dir", type=str, required=True, help="Root directory for token ids" 308 | ) 309 | parser.add_argument( 310 | "--tokenizer_path", 311 | type=str, 312 | required=True, 313 | default="meta-llama/Meta-Llama-3-8B", 314 | help="Path to the tokenizer", 315 | ) 316 | 317 | args = parser.parse_args() 318 | 319 | main( 320 | args.total_token_num, 321 | args.cn_ratio, 322 | args.en_ratio, 323 | args.syn_ratio, 324 | args.root_dir, 325 | args.tokenizer_path, 326 | ) 327 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 10 | 11 | 18 | 19 | 24 | 25 |

26 | 📄 Report   |   🤗 Model on Hugging Face  |   📊 CPT Dataset 27 |

28 | 29 |

30 | 🔍 English  |  简体中文 31 |

32 | 33 | 46 | 47 | 60 | 61 | --- 62 | 63 | ## News 64 | 65 | - 🌟🌟 `2024/12/17`: We released the [code](https://github.com/RUC-GSAI/Llama-3-SynE/blob/main/src) used for continual pre-training and data preparation. The code contains detailed documentation comments. 66 | - ✨✨ `2024/08/12`: We released the [continual pre-training dataset](https://huggingface.co/datasets/survivi/Llama-3-SynE-Dataset). 67 | - ✨✨ `2024/08/10`: We released the [Llama-3-SynE model](https://huggingface.co/survivi/Llama-3-SynE). 68 | - ✨ `2024/07/26`: We released the [technical report](https://arxiv.org/abs/2407.18743), welcome to check it out! 69 | 70 |
71 | 72 |

73 | 74 |

75 | 76 | ## Model Introduction 77 | 78 | **Llama-3-SynE** (Synthetic data Enhanced Llama-3) is a significantly enhanced version of [Llama-3 (8B)](https://github.com/meta-llama/llama3), achieved through continual pre-training (CPT) to improve its **Chinese language ability and scientific reasoning capability**. By employing a meticulously designed data mixture and curriculum strategy, Llama-3-SynE successfully enhances new abilities while maintaining the original model’s performance. This enhancement process involves utilizing existing datasets and synthesizing high-quality datasets specifically designed for targeted tasks. 79 | 80 | Key features of Llama-3-SynE include: 81 | 82 | - **Enhanced Chinese Language Capabilities**: Achieved through topic-based data mixture and perplexity-based data curriculum. 83 | - **Improved Scientific Reasoning**: Utilizing synthetic datasets to enhance multi-disciplinary scientific knowledge. 84 | - **Efficient CPT**: Consumes only around 100 billion tokens, making it a cost-effective solution. 85 | 86 | ## Model List 87 | 88 | | Model | Type | Seq Length | Download | 89 | | :----------- | :--- | :--------- | :------------------------------------------------------------ | 90 | | Llama-3-SynE | Base | 8K | [🤗 Huggingface](https://huggingface.co/survivi/Llama-3-SynE) | 91 | 92 | ## Benchmark 93 | 94 | We divide all evaluation benchmarks into two groups. The first group is _major benchmarks_, which aim to evaluate the comprehensive capabilities of LLMs. Note that we include commonly used math and code benchmarks in this group because it is standard practice to use these benchmarks for evaluating various general-purpose LLMs. 95 | 96 | The second group is _scientific benchmarks_, which have a broader coverage of multidisciplinary scientific knowledge. 97 | 98 | We report eight-shot performance on GSM8K, ASDiv, and MAWPS; five-shot performance on C-Eval, CMMLU, MMLU, MATH, GaoKao, SciQ, SciEval, SAT-Math, and AQUA-RAT; and three-shot performance on MBPP. 99 | For HumanEval and ARC, we report zero-shot performance. The best and second best results are in **bold** and underlined, respectively. 100 | 101 | ### Major Benchmarks 102 | 103 | | **Models** | **MMLU** | **C-Eval** | **CMMLU** | **MATH** | **GSM8K** | **ASDiv** | **MAWPS** | **SAT-Math** | **HumanEval** | **MBPP** | 104 | | :---------------------- | :--------------- | :--------------- | :--------------- | :--------------- | :--------------- | :--------------- | :--------------- | :--------------- | :--------------- | :--------------- | 105 | | Llama-3-8B | **66.60** | 49.43 | 51.03 | 16.20 | 54.40 | 72.10 | 89.30 | 38.64 | 36.59 | **47.00** | 106 | | DCLM-7B | 64.01 | 41.24 | 40.89 | 14.10 | 39.20 | 67.10 | 83.40 | 41.36 | 21.95 | 32.60 | 107 | | Mistral-7B-v0.3 | 63.54 | 42.74 | 43.72 | 12.30 | 40.50 | 67.50 | 87.50 | 40.45 | 25.61 | 36.00 | 108 | | Llama-3-Chinese-8B | 64.10 | 50.14 | 51.20 | 3.60 | 0.80 | 1.90 | 0.60 | 36.82 | 9.76 | 14.80 | 109 | | MAmmoTH2-8B | 64.89 | 46.56 | 45.90 | **34.10** | **61.70** | **82.80** | 91.50 | 41.36 | 17.68 | 38.80 | 110 | | Galactica-6.7B | 37.13 | 26.72 | 25.53 | 5.30 | 9.60 | 40.90 | 51.70 | 23.18 | 7.31 | 2.00 | 111 | | **Llama-3-SynE (ours)** | 65.19 | **58.24** | **57.34** | 28.20 | 60.80 | 81.00 | **94.10** | **43.64** | **42.07** | 45.60 | 112 | 113 | > On **Chinese evaluation benchmarks** (such as C-Eval and CMMLU), Llama-3-SynE significantly outperforms the base model Llama-3 (8B), indicating that our method is very effective in improving Chinese language capabilities.
114 | 115 | > On **English evaluation benchmarks** (such as MMLU, MATH, and code evaluation benchmarks), Llama-3-SynE demonstrates comparable or better performance than the base model, indicating that our method effectively addresses the issue of catastrophic forgetting during the CPT process. 116 | 117 | ### Scientific Benchmarks 118 | 119 | "PHY", "CHE", and "BIO" denote the physics, chemistry, and biology sub-tasks of the corresponding benchmarks. 120 | 121 | | **Models** | **SciEval PHY** | **SciEval CHE** | **SciEval BIO** | **SciEval Avg.** | **SciQ** | **GaoKao MathQA** | **GaoKao CHE** | **GaoKao BIO** | **ARC Easy** | **ARC Challenge** | **ARC Avg.** | **AQUA-RAT** | 122 | | :---------------------- | :--------------- | :--------------- | :--------------- | :--------------- | :--------------- | :---------------- | :--------------- | :--------------- | :--------------- | :---------------- | :--------------- | :--------------- | 123 | | Llama-3-8B | 46.95 | 63.45 | 74.53 | 65.47 | 90.90 | 27.92 | 32.85 | 43.81 | 91.37 | 77.73 | 84.51 | 27.95 | 124 | | DCLM-7B | **56.71** | 64.39 | 72.03 | 66.25 | **92.50** | 29.06 | 31.40 | 37.14 | 89.52 | 76.37 | 82.94 | 20.08 | 125 | | Mistral-7B-v0.3 | 48.17 | 59.41 | 68.89 | 61.51 | 89.40 | 30.48 | 30.92 | 41.43 | 87.33 | 74.74 | 81.04 | 23.23 | 126 | | Llama-3-Chinese-8B | 48.17 | 67.34 | 73.90 | 67.34 | 89.20 | 27.64 | 30.43 | 38.57 | 88.22 | 70.48 | 79.35 | 27.56 | 127 | | MAmmoTH2-8B | 49.39 | **69.36** | 76.83 | **69.60** | 90.20 | **32.19** | 36.23 | 49.05 | **92.85** | **84.30** | **88.57** | 27.17 | 128 | | Galactica-6.7B | 34.76 | 43.39 | 54.07 | 46.27 | 71.50 | 23.65 | 27.05 | 24.76 | 65.91 | 46.76 | 56.33 | 20.87 | 129 | | **Llama-3-SynE (ours)** | 53.66 | 67.81 | **77.45** | **69.60** | 91.20 | 31.05 | **51.21** | **69.52** | 91.58 | 80.97 | 86.28 | **28.74** | 130 | 131 | > On **scientific evaluation benchmarks** (such as SciEval, GaoKao, and ARC), Llama-3-SynE significantly outperforms the base model, particularly showing remarkable improvement in Chinese scientific benchmarks (for example, a 25.71% improvement in the GaoKao biology subtest). 132 | 133 | ## Quick Start 134 | 135 | Use the transformers backend for inference: 136 | 137 | ```python 138 | from transformers import AutoTokenizer, AutoModelForCausalLM 139 | import torch 140 | 141 | model_path = "survivi/Llama-3-SynE" 142 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 143 | model = AutoModelForCausalLM.from_pretrained( 144 | model_path, torch_dtype=torch.bfloat16, trust_remote_code=True 145 | ) 146 | model.to("cuda:0") 147 | model.eval() 148 | prompt = "Hello world!" 
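# Note: Llama-3-SynE is a base model obtained via continual pre-training (see the Model List above),
# so the prompt is plain text to be continued by the model; no chat template is applied here.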
149 | inputs = tokenizer(prompt, return_tensors="pt") 150 | inputs = inputs.to("cuda") 151 | pred = model.generate( 152 | **inputs, 153 | max_new_tokens=2048, 154 | repetition_penalty=1.05, 155 | temperature=0.5, 156 | top_k=5, 157 | top_p=0.85, 158 | do_sample=True 159 | ) 160 | pred = pred[0][len(inputs.input_ids[0]) :] 161 | output = tokenizer.decode(pred, skip_special_tokens=True) 162 | print(output) 163 | ``` 164 | 165 | Use the vLLM backend for inference: 166 | 167 | ```python 168 | from transformers import AutoTokenizer 169 | from vllm import LLM, SamplingParams 170 | 171 | model_path = "survivi/Llama-3-SynE" 172 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 173 | sampling_params = SamplingParams( 174 | max_tokens=2048, 175 | repetition_penalty=1.05, 176 | temperature=0.5, 177 | top_k=5, 178 | top_p=0.85, 179 | ) 180 | llm = LLM( 181 | model=model_path, 182 | tensor_parallel_size=1, 183 | trust_remote_code=True, 184 | ) 185 | prompt = "Hello world!" 186 | output = llm.generate(prompt, sampling_params) 187 | output = output[0].outputs[0].text 188 | print(output) 189 | ``` 190 | 191 | ## License 192 | 193 | This project is built upon Meta's Llama-3 model. The use of Llama-3-SynE model weights must follow the Llama-3 [license agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE). The code in this open-source repository follows the [Apache 2.0](LICENSE) license. 194 | 195 | ## Citation 196 | 197 | If you find our work helpful, please consider citing the following paper: 198 | 199 | ``` 200 | @article{jie2024llama3syne, 201 | title={Towards Effective and Efficient Continual Pre-training of Large Language Models}, 202 | author={Chen, Jie and Chen, Zhipeng and Wang, Jiapeng and Zhou, Kun and Zhu, Yutao and Jiang, Jinhao and Min, Yingqian and Zhao, Wayne Xin and Dou, Zhicheng and Mao, Jiaxin and others}, 203 | journal={arXiv preprint arXiv:2407.18743}, 204 | year={2024} 205 | } 206 | ``` 207 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # The code is adapted from tatsu-lab/stanford_alpaca. The original code is licensed under the Apache 2.0 License. 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Training script for CPT using Hugging Face Transformers and Datasets libraries. 18 | 19 | This script supports loading and preprocessing text datasets or loading preprocessed datasets from disk. It allows for customization of training behavior through various command-line arguments. 20 | 21 | Key Parameters: 22 | - `use_wsd` (bool): If set to True, the script uses the WSD optimizer for training, which may impact the training dynamics. Default is False. 23 | - `no_shuffle` (bool): If set to True, the training data will not be shuffled during training. 
This can be useful for certain training strategies where data order is important. Default is False. 24 | - `load_text_dataset` (bool): If set to True, the script will load raw text data and perform preprocessing (tokenization and grouping) before training. After preprocessing, the script will save the processed dataset to disk and exit. If False, it assumes that a preprocessed dataset is provided and loads it directly from disk. Default is False. 25 | - `single_dataset` (bool): If set to True, the script will load a single dataset from the specified `data_path`. If False, it will load and concatenate multiple datasets found in the `data_path` directory. Default is False. 26 | 27 | Usage: 28 | Run the script with the desired arguments to start training the model. Command-line arguments allow you to specify the model, data paths, and various training configurations. 29 | 30 | For more detailed configurations and options, refer to the argument definitions in the script. 31 | """ 32 | import os 33 | 34 | # Set the environment variable to use the cache 35 | SAVE_PATH = "~/.cache" 36 | os.environ["TMPDIR"] = os.path.join(SAVE_PATH, "tmp") 37 | os.environ["HF_DATASETS_CACHE"] = os.path.join(SAVE_PATH, "hf_datasets_cache") 38 | os.environ["HF_HOME"] = os.path.join(SAVE_PATH, "hf_home") 39 | os.makedirs(os.environ["TMPDIR"], exist_ok=True) 40 | os.makedirs(os.environ["HF_DATASETS_CACHE"], exist_ok=True) 41 | os.makedirs(os.environ["HF_HOME"], exist_ok=True) 42 | from dataclasses import dataclass, field 43 | from typing import Optional, Sequence, Dict 44 | import random 45 | import copy 46 | import torch 47 | import torch.distributed as dist 48 | from torch.utils.data import Dataset 49 | import transformers 50 | import datasets 51 | from transformers import ( 52 | Trainer, 53 | set_seed, 54 | AutoModelForCausalLM, 55 | AutoTokenizer, 56 | ) 57 | 58 | from train_utils import ( 59 | NoShuffleSeq2SeqTrainer, 60 | WSDTrainer, 61 | WSDNoShuffleTrainer, 62 | ) 63 | 64 | LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) 65 | 66 | 67 | @dataclass 68 | class ModelArguments: 69 | model_name_or_path: Optional[str] = field(default="") 70 | flash_attention: Optional[bool] = field(default=False) 71 | 72 | 73 | @dataclass 74 | class DataArguments: 75 | data_path: str = field( 76 | default=None, metadata={"help": "Path to the training data."} 77 | ) 78 | no_shuffle: bool = field( 79 | default=False, metadata={"help": "Whether to shuffle the training data."} 80 | ) 81 | preprocess_num_workers: int = field( 82 | default=32, 83 | metadata={"help": "The number of processes to use for the preprocessing."}, 84 | ) 85 | load_text_dataset: bool = field( 86 | default=False, metadata={"help": "Whether the dataset is text or input ids."} 87 | ) 88 | min_text_length: int = field( 89 | default=20, metadata={"help": "Minimum text length to include in the dataset."} 90 | ) 91 | single_dataset: bool = field( 92 | default=False, 93 | metadata={"help": "Whether to load a single dataset."}, 94 | ) 95 | 96 | 97 | @dataclass 98 | class TrainingArguments(transformers.TrainingArguments): 99 | cache_dir: Optional[str] = field(default=None) 100 | model_max_length: int = field( 101 | default=2048, 102 | metadata={ 103 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
104 | }, 105 | ) 106 | use_wsd: bool = field(default=False) 107 | 108 | 109 | class SupervisedDataset(Dataset): 110 | def __init__(self, train_dataset): 111 | super(SupervisedDataset, self).__init__() 112 | self.sources = train_dataset 113 | 114 | def __len__(self): 115 | return len(self.sources) 116 | 117 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]: 118 | ipt_ids = self.sources[idx]["input_ids"] 119 | return dict(input_ids=ipt_ids, labels=copy.deepcopy(ipt_ids)) 120 | 121 | 122 | @dataclass 123 | class DataCollatorForSupervisedDataset(object): 124 | data_args: DataArguments 125 | tokenizer: transformers.PreTrainedTokenizer 126 | 127 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 128 | return dict( 129 | input_ids=torch.tensor([d["input_ids"] for d in instances]), 130 | labels=torch.tensor([d["labels"] for d in instances]), 131 | ) 132 | 133 | 134 | def make_supervised_data_module(tokenizer, data_args, training_args, model_args): 135 | def process_example(example): 136 | tokenized_list = tokenizer(example["text"], add_special_tokens=False) 137 | token_ids = [ 138 | tokenized + [tokenizer.eos_token_id] 139 | for tokenized in tokenized_list["input_ids"] 140 | ] 141 | return {"input_ids": token_ids} 142 | 143 | def group_texts(examples): 144 | processed_data = [] 145 | tmp = [] 146 | for example in examples["input_ids"]: 147 | if len(example) < data_args.min_text_length: 148 | continue 149 | tmp.extend(example) 150 | while len(tmp) >= training_args.model_max_length: 151 | processed_data.append(tmp[: training_args.model_max_length]) 152 | tmp = tmp[training_args.model_max_length :] 153 | return {"input_ids": processed_data} 154 | 155 | # Load and preprocess the text dataset if needed 156 | if data_args.load_text_dataset: 157 | world_size = int(os.environ.get("WORLD_SIZE", "1")) 158 | rank = int(os.environ.get("RANK", "0")) 159 | if rank == 0 and LOCAL_RANK == 0: 160 | total_paths = [ 161 | os.path.join(data_args.data_path, file_name) 162 | for file_name in os.listdir(data_args.data_path) 163 | if file_name.endswith(".jsonl") or file_name.endswith(".json") 164 | ] 165 | # Process per 60G 166 | file_sizes = [os.path.getsize(path) for path in total_paths] 167 | file_size_limit = 60 * 1024**3 168 | agg_paths = [] 169 | temp_size = 0 170 | temp_paths = [] 171 | for i, size in enumerate(file_sizes): 172 | temp_size += size 173 | temp_paths.append(total_paths[i]) 174 | if temp_size >= file_size_limit: 175 | agg_paths.append(temp_paths) 176 | temp_size = 0 177 | temp_paths = [] 178 | if temp_paths: 179 | agg_paths.append(temp_paths) 180 | for i, paths in enumerate(agg_paths): 181 | data_save_dir = os.path.join( 182 | training_args.output_dir, 183 | data_args.data_path.split("/")[-1] + f"_{i}", 184 | ) 185 | print( 186 | f"Start processing data list: {paths} | {len(paths)} files | save to {data_save_dir}" 187 | ) 188 | raw_train_dataset = datasets.Dataset.from_json( 189 | path_or_paths=paths, 190 | num_proc=data_args.preprocess_num_workers, 191 | ) 192 | print(raw_train_dataset) 193 | print(f"Raw dataset size: {len(raw_train_dataset)}") 194 | print("Tokenizing dataset") 195 | train_dataset = raw_train_dataset.map( 196 | process_example, 197 | batched=True, 198 | num_proc=data_args.preprocess_num_workers, 199 | remove_columns=raw_train_dataset.column_names, 200 | desc="Running tokenizer on train dataset", 201 | ) 202 | print("Tokenizing dataset finished") 203 | print( 204 | f"Grouping texts with sequence length {training_args.model_max_length}. 
Filter texts shorter than {data_args.min_text_length}" 205 | ) 206 | train_dataset = train_dataset.map( 207 | group_texts, 208 | batched=True, 209 | num_proc=data_args.preprocess_num_workers, 210 | desc=f"Grouping texts with sequence length {training_args.model_max_length}", 211 | ) 212 | if not os.path.exists(training_args.output_dir): 213 | os.mkdir(training_args.output_dir) 214 | print(train_dataset) 215 | train_dataset.save_to_disk(data_save_dir) 216 | if world_size > 1: 217 | dist.barrier() 218 | print( 219 | f"Preprocess finished. Please set `load_text_dataset` to False and reload from {data_save_dir} with `single_dataset` set to True | Exit" 220 | ) 221 | exit(0) 222 | 223 | if data_args.single_dataset: 224 | train_dataset = datasets.load_from_disk(data_args.data_path) 225 | print(train_dataset) 226 | else: 227 | train_dataset = [] 228 | for data_name in os.listdir(data_args.data_path): 229 | train_dataset.append( 230 | datasets.load_from_disk(os.path.join(data_args.data_path, data_name)) 231 | ) 232 | print(f"Dataset {data_name} loaded") 233 | print(len(train_dataset)) 234 | train_dataset = datasets.concatenate_datasets(train_dataset) 235 | print(train_dataset) 236 | 237 | print(f"train dataset size: {len(train_dataset)}") 238 | if LOCAL_RANK == 0: 239 | for index in [0] + list(random.sample(range(len(train_dataset)), 1)): 240 | print(f"Sample {index} of the training set: {train_dataset[index]}.") 241 | print("---------" * 9) 242 | if isinstance(train_dataset[index]["input_ids"][0], list): 243 | print(tokenizer.decode(train_dataset[index]["input_ids"][0])) 244 | else: 245 | print(tokenizer.decode(train_dataset[index]["input_ids"])) 246 | print("=========" * 9) 247 | 248 | train_dataset = SupervisedDataset(train_dataset=train_dataset) 249 | data_collator = DataCollatorForSupervisedDataset( 250 | data_args=data_args, tokenizer=tokenizer 251 | ) 252 | return dict( 253 | train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator 254 | ) 255 | 256 | 257 | def get_model_tokenizer(model_args, data_args, training_args): 258 | model = AutoModelForCausalLM.from_pretrained( 259 | model_args.model_name_or_path, 260 | attn_implementation="flash_attention_2" if model_args.flash_attention else None, 261 | cache_dir=training_args.cache_dir, 262 | ) 263 | tokenizer = AutoTokenizer.from_pretrained( 264 | model_args.model_name_or_path, 265 | cache_dir=training_args.cache_dir, 266 | model_max_length=training_args.model_max_length, 267 | padding_side="right", 268 | ) 269 | assert tokenizer.eos_token_id is not None, "Tokenizer must have an EOS token" 270 | assert model.get_output_embeddings().weight.data.size(0) == len( 271 | tokenizer 272 | ), "The vocabulary size of the model and the tokenizer should be the same" 273 | return model, tokenizer 274 | 275 | 276 | def train(): 277 | parser = transformers.HfArgumentParser( 278 | (ModelArguments, DataArguments, TrainingArguments) 279 | ) 280 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 281 | 282 | if training_args.gradient_checkpointing: 283 | training_args.gradient_checkpointing_kwargs = { 284 | "use_reentrant": False 285 | } # OR gradient_checkpointing_kwargs={'use_reentrant':True}, please refer to https://github.com/huggingface/transformers/issues/26969 286 | 287 | model, tokenizer = get_model_tokenizer(model_args, data_args, training_args) 288 | 289 | set_seed(training_args.seed) 290 | 291 | data_module = make_supervised_data_module( 292 | tokenizer=tokenizer, 293 | data_args=data_args, 294 | 
training_args=training_args, 295 | model_args=model_args, 296 | ) 297 | model.is_parallelizable = True 298 | model.model_parallel = True 299 | trainer_class = Trainer 300 | if data_args.no_shuffle: 301 | if training_args.use_wsd: 302 | trainer_class = WSDNoShuffleTrainer 303 | else: 304 | trainer_class = NoShuffleSeq2SeqTrainer 305 | elif training_args.use_wsd: 306 | trainer_class = WSDTrainer 307 | trainer = trainer_class( 308 | model=model, 309 | tokenizer=tokenizer, 310 | args=training_args, 311 | **data_module, 312 | ) 313 | model.config.use_cache = False 314 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 315 | trainer.save_state() 316 | trainer.save_model(output_dir=training_args.output_dir) 317 | 318 | 319 | if __name__ == "__main__": 320 | train() 321 | --------------------------------------------------------------------------------