├── README.md
├── prolong
│   ├── run.py
│   ├── run_batch_multi_process.py
│   ├── run_batch_multinodes.py
│   ├── scripts
│   │   ├── filelist
│   │   ├── iphost
│   │   ├── run_multinodes.sh
│   │   ├── run_multiprocess.sh
│   │   └── run_single.sh
│   └── tools
│       └── sort_and_get.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
1 | # ProLong
2 |
3 | This repository contains code for the **ACL'2024 Oral** paper "**[Long Context is Not Long at All: A Prospector of Long-Dependency Data for Large Language Models](https://arxiv.org/abs/2405.17915)**".
4 |
5 | ## Abstract
6 |
7 | Long-context modeling capabilities are important for large language models (LLMs) in various applications. However, directly training LLMs with long context windows is insufficient to enhance this capability, since some training samples do not exhibit strong semantic dependencies across long contexts.
8 | In this study, we propose **ProLong**, a data mining framework that assigns each training sample a long-dependency score, which can be used to rank and filter samples that are more advantageous for enhancing long-context modeling abilities in LLM training. Specifically, we first use delta perplexity scores to measure the *Dependency Strength* between text segments in a given document. We then refine this metric based on the *Dependency Distance* of these segments to incorporate spatial relationships across long contexts. Finally, the results are calibrated with a *Dependency Specificity* metric to prevent trivial dependencies introduced by repetitive patterns. Moreover, a random sampling approach is proposed to optimize the computational efficiency of ProLong. Comprehensive experiments on multiple benchmarks indicate that ProLong effectively identifies documents that carry long dependencies, and LLMs trained on these documents exhibit significantly enhanced long-context modeling capabilities.
9 |
10 | ## Requirements
11 | ```
12 | torch==2.1.1
13 | transformers==4.36.0
14 | tqdm==4.66.1
15 | numpy==1.22.2
16 | matplotlib==3.8.0
17 | datasets==2.15.0
18 | ```
19 |
20 | Install the required packages:
21 | ```bash
22 | pip install -r requirements.txt
23 | ```
24 |
25 | ## Data Preparation
26 | The file to be processed should be in the following jsonl format:
27 | ```json
28 | {"text": "This is the first sentence. This is the second sentence. This is the third sentence."}
29 | {"text": "This is the first sentence. This is the second sentence. This is the third sentence."}
This is the third sentence."} 30 | ``` 31 | 32 | ## Usage 33 | 34 | ### process single file with single process 35 | ```bash 36 | bash scripts/run_single.sh 37 | ``` 38 | 39 | ### process single file with single node and multiple processes 40 | ```bash 41 | bash scripts/run_multiprocess.sh 42 | ``` 43 | 44 | ### process multiple files with multiple nodes multiple processes 45 | ```bash 46 | bash scripts/run_multinodes.sh 47 | ``` 48 | 49 | ## Key parameters 50 | * `chunk_size` - The chunk size to be used for processing the data, here we use 128 51 | * `window_size` - The maximum window size to be considered, here we use 32768 52 | * `dlt_ppl_threshold` - The threshold to be used for filter delta perplexity, here we use 0.1 53 | * `single_ppl_batch_size` - The batch size to be used for calculating single perplexity 54 | * `pair_ppl_batch_size` - The batch size to be used for calculating pair perplexity 55 | * `sample_size` - The sample size to be used when calculating pair perplexity, if sample size is set to -1, then sampling strategy will not be used, all pairs will be calculated 56 | 57 | ## ProLong Test Set 58 | * The toy test set constructed in the paper is relatively small. Subsequently, we will release a larger ProLong test set with a broader source to assist users in selecting hyperparameters based on their experimental settings. 59 | 60 | ## Citation 61 | 62 | If you find this repository helpful, please consider citing the following paper: 63 | 64 | ```bib 65 | @article{chen2024long, 66 | title={Long Context is Not Long at All: A Prospector of Long-Dependency Data for Large Language Models}, 67 | author={Chen, Longze and Liu, Ziqiang and He, Wanwei and Li, Yunshui and Luo, Run and Yang, Min}, 68 | journal={arXiv preprint arXiv:2405.17915}, 69 | year={2024} 70 | } 71 | ``` 72 | 73 | ## Contact 74 | 75 | 76 | If you have any questions, feel free to contact us at `lz.chen2@siat.ac.cn` or `ww.he@siat.ac.cn`. 
77 |
--------------------------------------------------------------------------------
/prolong/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from tqdm import tqdm
4 | import json
5 | import os
6 | import random
7 | import argparse
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
11 | from transformers import LlamaTokenizer, LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM
12 | from torch.nn import CrossEntropyLoss
13 |
14 |
15 | IGNORE_INDEX = -100
16 |
17 |
18 | def parse_config():
19 |     parser = argparse.ArgumentParser()
20 |
21 |     # data parameter
22 |     parser.add_argument('--data_file', type=str)
23 |
24 |     # lds parameter
25 |     parser.add_argument('--chunk_size', type=int, default=128)
26 |     parser.add_argument('--dlt_ppl_threshold', type=float, default=0.1)
27 |     parser.add_argument('--window_size', type=int, default=32768)
28 |
29 |     # save parameter
30 |     parser.add_argument('--save_file', type=str)
31 |     parser.add_argument('--score_file', type=str)
32 |     parser.add_argument('--pic_dir', type=str)
33 |
34 |     # model configuration
35 |     parser.add_argument('--model_name', type=str, default='facebook/opt-350m')
36 |     parser.add_argument('--use_flash_attention_2', action='store_true')
37 |
38 |     # other
39 |     parser.add_argument('--seed', type=int, default=11)
40 |
41 |     parser.add_argument('--single_ppl_batch_size', type=int, default=256)
42 |     parser.add_argument('--pair_ppl_batch_size', type=int, default=256)
43 |     parser.add_argument('--sample_size', type=int, default=500)
44 |     parser.add_argument('--need_draw', action='store_true')
45 |
46 |     return parser.parse_args()
47 |
48 |
49 | def set_seed(seed):
50 |     """ fix random seed """
51 |     torch.manual_seed(seed)
52 |     torch.cuda.manual_seed(seed)
53 |     random.seed(seed)
54 |     np.random.seed(seed)
55 |
56 | def sample_preserve_order(array, sample_size):
57 |     indices = list(range(len(array)))
58 |     assert sample_size <= len(indices)
59 |     sampled_indices = sorted(random.sample(indices, sample_size))
60 |     return [array[i] for i in sampled_indices]
61 |
62 | def construct_data(data, tokenizer, chunk_size, chunk_num):
63 |     tokenized_data = tokenizer(data)['input_ids']
64 |     data_list = [tokenized_data[i:i + chunk_size] for i in range(0, len(tokenized_data), chunk_size)]
65 |
66 |     if len(data_list[-1]) < chunk_size:
67 |         data_list = data_list[:-1]  # drop the short tail so every chunk has exactly chunk_size tokens
68 |     if len(data_list) > chunk_num:
69 |         data_list = sample_preserve_order(array=data_list, sample_size=chunk_num)
70 |     return data_list
71 |
72 | def compute_ppl(logits, labels, nums):
73 |     # Shift so that tokens < n predict n
74 |     shift_logits = logits[..., :-1, :].contiguous()
75 |     shift_labels = labels[..., 1:].contiguous()
76 |     # Flatten the tokens
77 |     loss_fct = CrossEntropyLoss(reduction='none')
78 |     shift_logits = shift_logits.view(-1, model.config.vocab_size)
79 |     shift_labels = shift_labels.view(-1)
80 |     # Enable model parallelism
81 |     shift_labels = shift_labels.to(shift_logits.device)
82 |     loss = loss_fct(shift_logits, shift_labels)
83 |     loss = loss.view(labels.size(0), -1)  # reshape loss back to sequence length
84 |
85 |     batch_ppl = []
86 |     for i, num in enumerate(nums):
87 |         avg_loss = loss[i, -num:].mean()
88 |         batch_ppl.append(torch.exp(avg_loss).float().cpu().item())
89 |     return batch_ppl
90 |
91 |
92 | def compute_single_ppl(data_list, batch_size):
93 |     single_ppl = [0 for _ in range(len(data_list))]
94 |     with torch.no_grad():
95 |         model.eval()
96 |         for i in range(0, len(data_list), batch_size):
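            # each iteration scores one mini-batch of chunks in a single forward
            # pass; torch.tensor(batch) requires rows of equal length, which
            # construct_data guarantees by dropping the short tail chunk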
97 | batch = data_list[i:i+batch_size] 98 | nums = [len(b) - 1 for b in batch] 99 | inputs = torch.tensor(batch).to(device) 100 | labels = inputs.clone() 101 | logits = model(input_ids=inputs)[0] 102 | batch_ppl = compute_ppl(logits, labels, nums) 103 | single_ppl[i:i+batch_size] = batch_ppl 104 | return single_ppl 105 | 106 | 107 | def compute_pair_ppl(data_list, batch_size, sample_size=-1): 108 | pair_ppl = [[float('inf') for _ in range(len(data_list))] for _ in range(len(data_list))] 109 | with torch.no_grad(): 110 | model.eval() 111 | pairs = [(i, j) for i in range(len(data_list)) for j in range(i)] 112 | if sample_size > 0: 113 | if len(pairs) < sample_size: 114 | return pair_ppl 115 | pairs = random.sample(pairs, sample_size) 116 | for batch_start in range(0, len(pairs), batch_size): 117 | batch_pairs = pairs[batch_start:batch_start+batch_size] 118 | nums = [len(data_list[i]) - 1 for i, _ in batch_pairs] 119 | inputs = [data_list[j] + data_list[i] for i, j in batch_pairs] 120 | inputs = torch.tensor(inputs).to(device) 121 | labels = torch.tensor([[IGNORE_INDEX] * (len(data_list[j]) + 1) + data_list[i][1:] for i, j in batch_pairs]).to(device) 122 | logits = model(input_ids=inputs)[0] 123 | batch_ppl = compute_ppl(logits, labels, nums) 124 | for k, (i, j) in enumerate(batch_pairs): 125 | pair_ppl[i][j] = batch_ppl[k] 126 | return pair_ppl 127 | 128 | def compute_de(logits): 129 | 130 | def _softmax(x): 131 | e_x = np.exp(x - np.max(x)) 132 | return e_x / e_x.sum(axis=0) 133 | 134 | def _compute_entropy(x): 135 | if 0 in x: 136 | x += 1e-12 137 | x /= np.sum(x) 138 | entropy = -np.sum(x * np.log(x)) 139 | return entropy 140 | 141 | if len(logits) == 1: 142 | return 1 143 | 144 | max_entropy = np.log(len(logits)) 145 | entropy = _compute_entropy(_softmax(logits)) 146 | return np.clip((max_entropy - entropy) / max_entropy, 0, 1) 147 | 148 | def compute_lds(single_ppl, pair_ppl): 149 | dlt_ppl = [[0 for _ in range(len(single_ppl))] for _ in range(len(single_ppl))] 150 | dependency_entropy = [0 for _ in range(len(single_ppl))] 151 | dis_scale = 1 / (args.chunk_num - 1) 152 | lds = 0 153 | 154 | for i in range(len(single_ppl)): 155 | row_logits = [] 156 | for j in range(i): 157 | dlt_ppl[i][j] = single_ppl[i] - pair_ppl[i][j] 158 | if pair_ppl[i][j] != float('inf'): 159 | row_logits.append(dlt_ppl[i][j]) 160 | if len(row_logits) > 0: 161 | dependency_entropy[i] = compute_de(logits=row_logits) 162 | 163 | for i in range(len(single_ppl)): 164 | for j in range(i): 165 | dlt_ppl[i][j] /= single_ppl[i] 166 | if dlt_ppl[i][j] > args.dlt_ppl_threshold: 167 | distance_gain = np.clip((i - j) * dis_scale, 0, 1) 168 | lds += (dlt_ppl[i][j] + distance_gain) * dependency_entropy[i] 169 | 170 | if np.isnan(lds): 171 | lds = 0. 
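    # at this point lds has summed, over every chunk pair whose normalized delta
    # perplexity clears dlt_ppl_threshold, the pair's dependency strength plus its
    # distance gain, weighted by the row's entropy term (the paper's Dependency
    # Specificity); dlt_ppl holds the normalized delta-perplexity matrix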
172 |     return lds, dlt_ppl, dependency_entropy
173 |
174 |
175 | def draw(save_path, matrix_data, lds):
176 |     matrix_array = np.array(matrix_data)
177 |
178 |     plt.imshow(matrix_array, cmap='viridis', interpolation='nearest')
179 |     plt.colorbar()
180 |     plt.title(f'LDS = {lds}')
181 |     plt.savefig(save_path)  # save before show(), which can clear the current figure
182 |     plt.show()
183 |     plt.clf()
184 |
185 |
186 | if __name__ == "__main__":
187 |     # cuda
188 |     cuda_available = torch.cuda.is_available()
189 |
190 |     if cuda_available:
191 |         print('Cuda is available.')
192 |         device = torch.device('cuda')
193 |     else:
194 |         print('Cuda is not available')
195 |         device = torch.device('cpu')
196 |
197 |     # args
198 |     args = parse_config()
199 |     args.chunk_num = args.window_size // args.chunk_size
200 |
201 |     # set seed
202 |     set_seed(seed=args.seed)
203 |
204 |     # tokenizer
205 |     print('Start loading tokenizer...')
206 |     tokenizer = AutoTokenizer.from_pretrained(
207 |         args.model_name,
208 |         trust_remote_code=True,
209 |         use_fast=False
210 |     )
211 |     if hasattr(tokenizer, "add_bos_token"):
212 |         setattr(tokenizer, "add_bos_token", False)
213 |     if hasattr(tokenizer, "add_eos_token"):
214 |         setattr(tokenizer, "add_eos_token", False)
215 |
216 |     # model
217 |     print('Start loading model...')
218 |     model = AutoModelForCausalLM.from_pretrained(
219 |         args.model_name,
220 |         torch_dtype=torch.float16,
221 |         use_flash_attention_2=args.use_flash_attention_2,
222 |         trust_remote_code=True,
223 |     )
224 |     model.to(device)
225 |     print('Model loaded')
226 |
227 |     with open(args.data_file, 'r', encoding='utf-8') as i:
228 |         all_data = i.readlines()
229 |     for idx, single_data in tqdm(enumerate(all_data), total=len(all_data)):
230 |         # Loading Data
231 |         if args.data_file.endswith('.jsonl'):
232 |             json_data = json.loads(single_data)
233 |             data = json_data['text']
234 |         else:
235 |             data = single_data
236 |         # Construct Data
237 |         data_list = construct_data(data=data, tokenizer=tokenizer, chunk_size=args.chunk_size, chunk_num=args.chunk_num)
238 |
239 |         start = time.time()
240 |         try:
241 |             single_ppl = compute_single_ppl(data_list, args.single_ppl_batch_size)
242 |             pair_ppl = compute_pair_ppl(data_list, args.pair_ppl_batch_size, args.sample_size)
243 |             long_dependency_score, dlt_ppl, dependency_entropy = compute_lds(single_ppl=single_ppl, pair_ppl=pair_ppl)
244 |         except Exception as e:
245 |             print(f'Error: {e}, set LDS to 0.')
246 |             long_dependency_score, dlt_ppl, dependency_entropy = 0., [], []  # define all three so the draw/write steps below never hit a NameError
247 |
248 |         print(f'long_dependency_score: {long_dependency_score}')
249 |         end = time.time()
250 |         print(f'cost time: {end - start} seconds')
251 |
252 |         # draw
253 |         if args.need_draw:
254 |             if idx % 1000 == 0:
255 |                 save_path = os.path.join(args.pic_dir, str(idx) + '.png')
256 |                 draw(save_path, dlt_ppl, long_dependency_score)
257 |
258 |         # write score to file
259 |         with open(args.score_file, 'a', encoding='utf-8') as o:
260 |             o.write(str(long_dependency_score) + '\n')
261 |
262 |         # store data with score
263 |         new_data = {'text': data, 'lds': long_dependency_score}
264 |         with open(args.save_file, 'a', encoding='utf-8') as o:
265 |             o.write(json.dumps(new_data, ensure_ascii=False) + '\n')
266 |
267 |
--------------------------------------------------------------------------------
/prolong/run_batch_multi_process.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from tqdm import tqdm
4 | import json
5 | import os
6 | import random
7 | import argparse
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | import multiprocessing
11 | from transformers import LlamaTokenizer, LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM
12 | from torch.nn import CrossEntropyLoss
13 |
14 | multiprocessing.set_start_method('spawn', force=True)
15 |
16 | IGNORE_INDEX = -100
17 |
18 | def set_seed(seed):
19 |     """ fix random seed """
20 |     torch.manual_seed(seed)
21 |     torch.cuda.manual_seed(seed)
22 |     random.seed(seed)
23 |     np.random.seed(seed)
24 |
25 | class DataProcessor:
26 |     def __init__(self, model_name, use_flash_attention_2, data_file, input_chunk_path, output_chunk_path, score_chunk_path, pic_dir, chunk_size, dlt_ppl_threshold, window_size, single_ppl_batch_size, pair_ppl_batch_size, sample_size, need_draw, seed, device):
27 |         self.model_name = model_name
28 |         self.data_file = data_file
29 |         self.input_chunk_path = input_chunk_path
30 |         self.output_chunk_path = output_chunk_path
31 |         self.score_chunk_path = score_chunk_path
32 |         self.pic_dir = pic_dir
33 |         self.chunk_size = chunk_size
34 |         self.dlt_ppl_threshold = dlt_ppl_threshold
35 |         self.window_size = window_size
36 |         self.chunk_num = window_size // chunk_size
37 |         self.single_ppl_batch_size = single_ppl_batch_size
38 |         self.pair_ppl_batch_size = pair_ppl_batch_size
39 |         self.sample_size = sample_size
40 |         self.need_draw = need_draw
41 |         self.seed = seed
42 |         self.device = device
43 |
44 |         self.model = AutoModelForCausalLM.from_pretrained(
45 |             model_name,
46 |             torch_dtype=torch.float16,
47 |             use_flash_attention_2=use_flash_attention_2,
48 |             trust_remote_code=True,
49 |         )
50 |         self.model.to(device)
51 |
52 |         self.tokenizer = AutoTokenizer.from_pretrained(
53 |             model_name,
54 |             use_fast=False,
55 |             trust_remote_code=True
56 |         )
57 |         if hasattr(self.tokenizer, "add_bos_token"):
58 |             setattr(self.tokenizer, "add_bos_token", False)
59 |         if hasattr(self.tokenizer, "add_eos_token"):
60 |             setattr(self.tokenizer, "add_eos_token", False)
61 |
62 |     def sample_preserve_order(self, array, sample_size):
63 |         indices = list(range(len(array)))
64 |         assert sample_size <= len(indices)
65 |         sampled_indices = sorted(random.sample(indices, sample_size))
66 |         return [array[i] for i in sampled_indices]
67 |
68 |
69 |     def construct_data(self, data, tokenizer, chunk_size, chunk_num):
70 |         tokenized_data = tokenizer(data)['input_ids']
71 |         data_list = [tokenized_data[i:i + chunk_size] for i in range(0, len(tokenized_data), chunk_size)]
72 |
73 |         if len(data_list[-1]) < chunk_size:
74 |
data_list = data_list[:-1] 75 | if len(data_list) > chunk_num: 76 | data_list = self.sample_preserve_order(array=data_list, sample_size=chunk_num) 77 | return data_list 78 | 79 | 80 | def compute_de(self, logits): 81 | 82 | def _softmax(x): 83 | e_x = np.exp(x - np.max(x)) 84 | return e_x / e_x.sum(axis=0) 85 | 86 | def _compute_entropy(x): 87 | if 0 in x: 88 | x += 1e-12 89 | x /= np.sum(x) 90 | entropy = -np.sum(x * np.log(x)) 91 | return entropy 92 | 93 | if len(logits) == 1: 94 | return 1 95 | 96 | max_entropy = np.log(len(logits)) 97 | entropy = _compute_entropy(_softmax(logits)) 98 | return np.clip((max_entropy - entropy) / max_entropy, 0, 1) 99 | 100 | 101 | def compute_lds(self, single_ppl, pair_ppl): 102 | dlt_ppl = [[0 for _ in range(len(single_ppl))] for _ in range(len(single_ppl))] 103 | dependency_entropy = [0 for _ in range(len(single_ppl))] 104 | dis_scale = 1 / (self.chunk_num - 1) 105 | lds = 0 106 | 107 | for i in range(len(single_ppl)): 108 | row_logits = [] 109 | for j in range(i): 110 | dlt_ppl[i][j] = single_ppl[i] - pair_ppl[i][j] 111 | if pair_ppl[i][j] != float('inf'): 112 | row_logits.append(dlt_ppl[i][j]) 113 | if len(row_logits) > 0: 114 | dependency_entropy[i] = self.compute_de(logits=row_logits) 115 | 116 | for i in range(len(single_ppl)): 117 | for j in range(i): 118 | dlt_ppl[i][j] /= single_ppl[i] 119 | if dlt_ppl[i][j] > self.dlt_ppl_threshold: 120 | distance_gain = np.clip((i - j) * dis_scale, 0, 1) 121 | lds += (dlt_ppl[i][j] + distance_gain) * dependency_entropy[i] 122 | if np.isnan(lds): 123 | lds = 0. 124 | return lds, dlt_ppl, dependency_entropy 125 | 126 | def compute_single_ppl(self, data_list, batch_size): 127 | single_ppl = [0 for _ in range(len(data_list))] 128 | with torch.no_grad(): 129 | self.model.eval() 130 | for i in range(0, len(data_list), batch_size): 131 | batch = data_list[i:i+batch_size] 132 | nums = [len(b) - 1 for b in batch] 133 | inputs = torch.tensor(batch).to(self.device) 134 | labels = inputs.clone() 135 | logits = self.model(input_ids=inputs)[0] 136 | batch_ppl = self.compute_ppl(logits, labels, nums) 137 | single_ppl[i:i+batch_size] = batch_ppl 138 | return single_ppl 139 | 140 | 141 | def compute_pair_ppl(self, data_list, batch_size, sample_size=-1): 142 | pair_ppl = [[float('inf') for _ in range(len(data_list))] for _ in range(len(data_list))] 143 | with torch.no_grad(): 144 | self.model.eval() 145 | pairs = [(i, j) for i in range(len(data_list)) for j in range(i)] 146 | if sample_size > 0: 147 | if len(pairs) < sample_size: 148 | return pair_ppl 149 | pairs = random.sample(pairs, sample_size) 150 | for batch_start in range(0, len(pairs), batch_size): 151 | batch_pairs = pairs[batch_start:batch_start+batch_size] 152 | nums = [len(data_list[i]) - 1 for i, _ in batch_pairs] 153 | inputs = [data_list[j] + data_list[i] for i, j in batch_pairs] 154 | inputs = torch.tensor(inputs).to(self.device) 155 | labels = torch.tensor([[IGNORE_INDEX] * (len(data_list[j]) + 1) + data_list[i][1:] for i, j in batch_pairs]).to(self.device) 156 | logits = self.model(input_ids=inputs)[0] 157 | batch_ppl = self.compute_ppl(logits, labels, nums) 158 | for k, (i, j) in enumerate(batch_pairs): 159 | pair_ppl[i][j] = batch_ppl[k] 160 | return pair_ppl 161 | 162 | 163 | def compute_ppl(self, logits, labels, nums): 164 | # Shift so that tokens < n predict n 165 | shift_logits = logits[..., :-1, :].contiguous() 166 | shift_labels = labels[..., 1:].contiguous() 167 | # Flatten the tokens 168 | loss_fct = CrossEntropyLoss(reduction='none') 169 | 
        shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
170 |         shift_labels = shift_labels.view(-1)
171 |         # Enable model parallelism
172 |         shift_labels = shift_labels.to(shift_logits.device)
173 |         loss = loss_fct(shift_logits, shift_labels)
174 |         loss = loss.view(labels.size(0), -1)  # reshape loss back to sequence length
175 |
176 |         batch_ppl = []
177 |         for i, num in enumerate(nums):
178 |             avg_loss = loss[i, -num:].mean()
179 |             batch_ppl.append(torch.exp(avg_loss).float().cpu().item())
180 |         return batch_ppl
181 |
182 |     def draw(self, save_path, matrix_data, lds):
183 |         matrix_array = np.array(matrix_data)
184 |
185 |         plt.imshow(matrix_array, cmap='viridis', interpolation='nearest')
186 |         plt.colorbar()
187 |         plt.title(f'LDS = {lds}')
188 |         plt.savefig(save_path)  # save before show(), which can clear the current figure
189 |         plt.show()
190 |         plt.clf()
191 |
192 |     def process_data(self):
193 |         base_name = self.data_file.split('.')[0]
194 |         input_file = os.path.join(self.input_chunk_path, self.data_file)
195 |         output_file = os.path.join(self.output_chunk_path, base_name+'.jsonl')
196 |         score_chunk_path = os.path.join(self.score_chunk_path, base_name+'.txt')
197 |         pic_dir = os.path.join(self.pic_dir, base_name)
198 |         if not os.path.exists(score_chunk_path):
199 |             # create file
200 |             with open(score_chunk_path, 'w', encoding='utf-8') as f:
201 |                 pass
202 |         chunk_num = self.window_size // self.chunk_size
203 |         if not os.path.exists(pic_dir):
204 |             os.makedirs(pic_dir)
205 |         exist_score_lines = sum(1 for _ in open(score_chunk_path, 'r', encoding='utf-8'))
206 |         print(f'[INFO] file {score_chunk_path} exist_score_lines: {exist_score_lines}')
207 |         with open(input_file, 'r', encoding='utf-8') as fin, \
208 |             open(score_chunk_path, 'a', encoding='utf-8') as fscore, \
209 |             open(output_file, 'a', encoding='utf-8') as fout:
210 |             total_lines = sum(1 for _ in fin)
211 |             fin.seek(0)
212 |             idx = 0
213 |             for single_data in tqdm(fin, total=total_lines):
214 |                 # Loading Data
215 |                 if idx < exist_score_lines:
216 |                     idx += 1
217 |                     continue
218 |                 if input_file.endswith('.jsonl'):
219 |                     json_data = json.loads(single_data)
220 |                     # fall back to the raw line when the record lacks a 'text' field
221 |                     data = json_data.get('text', single_data)
222 |                 else:
223 |                     data = single_data
224 |                 try:
225 |                     data_list = self.construct_data(data, self.tokenizer, self.chunk_size, chunk_num)
226 |                     start = time.time()
227 |                     single_ppl = self.compute_single_ppl(data_list, self.single_ppl_batch_size)
228 |                     pair_ppl = self.compute_pair_ppl(data_list, self.pair_ppl_batch_size, self.sample_size)
229 |                     long_dependency_score, dlt_ppl, dependency_entropy = self.compute_lds(single_ppl=single_ppl, pair_ppl=pair_ppl)
230 |                 except Exception as e:
231 |                     print(f'[Error]: {e}, [file_name]: {input_file}, [idx]: {idx}, set LDS to 0.')
232 |                     long_dependency_score, dlt_ppl, dependency_entropy = 0., [], []  # define all three so the draw/write steps below never hit a NameError
233 |                 end = time.time()
234 |
235 |                 # draw
236 |                 if self.need_draw:
237 |                     if idx % 5000 == 0:
238 |                         print(f'file: {base_name} idx: {idx}')
239 |                         save_pic_path = os.path.join(pic_dir, str(idx) + '.png')
240 |                         self.draw(save_pic_path, dlt_ppl, long_dependency_score)
241 |
242 |                 # write score to file
243 |                 fscore.write(str(long_dependency_score) + '\n')
244 |
245 |                 # store data with score
246 |                 new_data = {'text': data, 'lds': long_dependency_score}
247 |                 fout.write(json.dumps(new_data, ensure_ascii=False) + '\n')
248 |                 idx += 1
249 |         os.remove(input_file)
250 |
251 | def split_jsonl(input_file, output_path, lines_per_file, total_lines):
252 |     with open(input_file, 'r', encoding='utf-8') as f:
253 |         file_count, line_count = 0, 0
254 |         current_out = open(f"{output_path}/chunk_{file_count}.jsonl", 'w', encoding='utf-8')
255 |         for line in tqdm(f, total=total_lines):
256 |             if line_count < lines_per_file:
257 |                 current_out.write(line)
258 |                 line_count += 1
259 |             else:
260 |                 current_out.close()
261 |                 file_count += 1
262 |                 current_out = open(f"{output_path}/chunk_{file_count}.jsonl", 'w', encoding='utf-8')
263 |                 current_out.write(line)
264 |                 line_count = 1
265 |         current_out.close()
266 |
267 | def merge(file_list, output_file, output_chunk_path):
268 |     count = 0
269 |     with open(output_file, 'w', encoding='utf-8') as fout:
270 |         for file_name in file_list:
271 |             with open(os.path.join(output_chunk_path, file_name), 'r', encoding='utf-8') as fin:
272 |                 for line in fin:
273 |                     print(line, end='', file=fout)
274 |                     count += 1
275 |     return count
276 |
277 |
278 | def process_single_chunk(gpu_id, file_name, input_chunk_path, output_chunk_path, score_chunk_path, pic_dir, chunk_size, window_size, sample_size, single_ppl_batch_size, pair_ppl_batch_size, dlt_ppl_threshold, model_name, use_flash_attention_2, seed, need_draw):
279 |     pseed = int(file_name.split('.jsonl')[0][6:])  # chunk index parsed from 'chunk_<n>.jsonl', used to de-correlate per-process seeds
280 |     random.seed(seed + pseed)
281 |
282 |     process_id = os.getpid()
283 |     print(f'[PID-{process_id}] {file_name} start!')
284 |
285 |     # cuda
286 |     cuda_available = torch.cuda.is_available()
287 |     if cuda_available:
288 |         print('Cuda is available.')
289 |         device = torch.device(f'cuda:{gpu_id}')
290 |     else:
291 |         print('Cuda is not available')
292 |         device = torch.device('cpu')
293 |
294 |     data_processor = DataProcessor(model_name=model_name,
295 |                                    use_flash_attention_2=use_flash_attention_2,
296 |                                    data_file=file_name,
297 |                                    input_chunk_path=input_chunk_path,
298 |                                    output_chunk_path=output_chunk_path,
299 |                                    score_chunk_path=score_chunk_path,
300 |                                    pic_dir=pic_dir,
301 |                                    chunk_size=chunk_size,
302 |                                    dlt_ppl_threshold=dlt_ppl_threshold,
303 |                                    window_size=window_size,
304 |                                    single_ppl_batch_size=single_ppl_batch_size,
305 |                                    pair_ppl_batch_size=pair_ppl_batch_size,
306 |                                    sample_size=sample_size,
307 |                                    need_draw=need_draw,
308 |                                    seed=seed,
309 |                                    device=device)
310 |     data_processor.process_data()
311 |     print(f'[PID-{process_id}] {file_name} end!')
312 |     return process_id
313 |
314 | def parse_config():
315 |     parser = argparse.ArgumentParser()
316 |
317 |     # data parameter
318 |     parser.add_argument('--data_file', type=str)
319 |     parser.add_argument('--root_path', type=str)
320 |     # lds parameter
321 |     parser.add_argument('--chunk_size', type=int, default=128)
322 |     parser.add_argument('--dlt_ppl_threshold', type=float, default=0.1)
323 |     parser.add_argument('--window_size', type=int, default=32768)
324 |
325 |     # model configuration
326 |     parser.add_argument('--model_name', type=str, default='facebook/opt-350m')
327 |
parser.add_argument('--use_flash_attention_2', action='store_true') 328 | 329 | # other 330 | parser.add_argument('--seed', type=int, default=11) 331 | 332 | parser.add_argument('--single_ppl_batch_size', type=int, default=256) 333 | parser.add_argument('--pair_ppl_batch_size', type=int, default=256) 334 | parser.add_argument('--sample_size', type=int, default=5000) 335 | parser.add_argument('--need_draw', action='store_true') 336 | 337 | parser.add_argument('--gpu_ids', nargs='+', type=int, default=[0,1,2,3,4,5,6,7]) 338 | 339 | return parser.parse_args() 340 | 341 | 342 | if __name__ == "__main__": 343 | # args 344 | args = parse_config() 345 | seed = args.seed 346 | # set seed 347 | set_seed(seed=seed) 348 | 349 | # model 350 | model_name = args.model_name 351 | use_flash_attention_2 = args.use_flash_attention_2 352 | 353 | # data path 354 | data_file = args.data_file 355 | root_path = args.root_path 356 | 357 | # output path 358 | input_chunk_path = f'{root_path}/raw' 359 | output_chunk_path = f'{root_path}/processed' 360 | save_file = f'{root_path}/merged' 361 | score_chunk_path = f'{root_path}/scored' 362 | pic_dir = f'{root_path}/pic' 363 | 364 | # create dir 365 | if not os.path.exists(input_chunk_path): 366 | os.makedirs(input_chunk_path) 367 | if not os.path.exists(output_chunk_path): 368 | os.makedirs(output_chunk_path) 369 | if not os.path.exists(save_file): 370 | os.makedirs(save_file) 371 | if not os.path.exists(score_chunk_path): 372 | os.makedirs(score_chunk_path) 373 | if not os.path.exists(pic_dir): 374 | os.makedirs(pic_dir) 375 | 376 | # hyper-parameters 377 | chunk_size = args.chunk_size 378 | dlt_ppl_threshold = args.dlt_ppl_threshold 379 | window_size = args.window_size 380 | single_ppl_batch_size = args.single_ppl_batch_size 381 | pair_ppl_batch_size = args.pair_ppl_batch_size 382 | sample_size = args.sample_size 383 | need_draw = args.need_draw 384 | gpu_ids = args.gpu_ids 385 | num_process = len(gpu_ids) 386 | 387 | with open(data_file, 'r', encoding='utf-8') as f: 388 | total_lines = sum(1 for _ in f) 389 | f.seek(0) 390 | if total_lines != -1: 391 | print(f"The file {data_file} has {total_lines} lines.") 392 | lines_per_file = total_lines // num_process if total_lines % num_process == 0 else total_lines // num_process + 1 393 | 394 | # split data into chunks 395 | split_jsonl(data_file, input_chunk_path, lines_per_file, total_lines) 396 | 397 | assert num_process < multiprocessing.cpu_count() 398 | pool = multiprocessing.Pool(processes=num_process) 399 | 400 | 401 | file_list = [file_name for file_name in os.listdir(input_chunk_path) if file_name.endswith('.jsonl')] 402 | # sort file_list 403 | file_list.sort(key=lambda x: int(x.split('.jsonl')[0][6:])) 404 | print(f'[INFO] {len(file_list)} files found in {input_chunk_path}') 405 | assert len(file_list) == num_process 406 | 407 | 408 | results = [] 409 | for idx, file_name in enumerate(file_list): 410 | gpu_id = gpu_ids[idx] 411 | results.append( 412 | pool.apply_async( 413 | process_single_chunk, 414 | (gpu_id, file_name, input_chunk_path, output_chunk_path, score_chunk_path, pic_dir, chunk_size, window_size, sample_size, single_ppl_batch_size, pair_ppl_batch_size, dlt_ppl_threshold, model_name, use_flash_attention_2, seed, need_draw))) 415 | 416 | pool.close() 417 | pool.join() 418 | 419 | results = [result.get() for result in results] 420 | 421 | score_output_file = os.path.join(save_file, 'scores.txt') 422 | score_file_list = [file_name for file_name in os.listdir(score_chunk_path) if 
file_name.endswith('.txt')] 423 | score_file_list.sort(key=lambda x: int(x.split('.txt')[0][6:])) 424 | count = merge(file_list=score_file_list, output_file=score_output_file, output_chunk_path=score_chunk_path) 425 | print(f'[INFO] {score_output_file}: {count} lines in total.') 426 | 427 | output_file = os.path.join(save_file, 'merged.jsonl') 428 | count = merge(file_list=file_list, output_file=output_file, output_chunk_path=output_chunk_path) 429 | print(f'[INFO] {output_file}: {count} lines in total.') 430 | -------------------------------------------------------------------------------- /prolong/run_batch_multinodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | 5 | def get_arguments(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--input-dir', type=str) 8 | parser.add_argument('--output-dir', type=str) 9 | parser.add_argument('--ip-hostfile', type=str) 10 | parser.add_argument('--file-list', type=str) 11 | 12 | # lds parameter 13 | parser.add_argument('--chunk_size', type=int, default=128) 14 | parser.add_argument('--dlt_ppl_threshold', type=float, default=0.1) 15 | parser.add_argument('--window_size', type=int, default=32768) 16 | 17 | # model configuration 18 | parser.add_argument('--model_name', type=str, default='facebook/opt-350m') 19 | parser.add_argument('--use_flash_attention_2', action='store_true') 20 | 21 | # other 22 | parser.add_argument('--seed', type=int, default=11) 23 | 24 | parser.add_argument('--single_ppl_batch_size', type=int, default=256) 25 | parser.add_argument('--pair_ppl_batch_size', type=int, default=256) 26 | parser.add_argument('--sample_size', type=int, default=500) 27 | parser.add_argument('--need_draw', action='store_true') 28 | 29 | parser.add_argument('--gpu_ids', nargs='+', type=int, default=[0,1,2,3,4,5,6,7]) 30 | 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | def get_ips(hostfile_path): 36 | """Read ips from hostfiles""" 37 | with open(hostfile_path, 'r', encoding='utf-8') as f: 38 | lines = f.readlines() 39 | ip_list = [] 40 | for line in lines: 41 | ip = line.strip().split(' ')[0] 42 | ip_list.append(ip) 43 | return ip_list 44 | 45 | def get_files(file_list): 46 | """Read files from file lists""" 47 | with open(file_list, 'r', encoding='utf-8') as f: 48 | lines = f.readlines() 49 | 50 | lines = [line.strip() for line in lines] 51 | return lines 52 | 53 | def get_cmd(args, filename, output_dir): 54 | path = os.path.abspath("run_batch_multi_process.py") 55 | cmd = "python {}".format(path) 56 | cmd += " --data_file {}".format(filename) 57 | cmd += " --root_path {}".format(output_dir) 58 | cmd += " --chunk_size {}".format(args.chunk_size) 59 | cmd += " --dlt_ppl_threshold {}".format(args.dlt_ppl_threshold) 60 | cmd += " --window_size {}".format(args.window_size) 61 | cmd += " --model_name {}".format(args.model_name) 62 | if args.use_flash_attention_2: 63 | cmd += " --use_flash_attention_2" 64 | cmd += " --seed {}".format(args.seed) 65 | cmd += " --single_ppl_batch_size {}".format(args.single_ppl_batch_size) 66 | cmd += " --pair_ppl_batch_size {}".format(args.pair_ppl_batch_size) 67 | cmd += " --sample_size {}".format(args.sample_size) 68 | if args.need_draw: 69 | cmd += " --need_draw" 70 | gpu_ids = ' '.join([str(gpu_id) for gpu_id in args.gpu_ids]) 71 | cmd += " --gpu_ids {}".format(gpu_ids) 72 | 73 | return cmd 74 | 75 | def main(): 76 | args = get_arguments() 77 | ip_list = get_ips(args.ip_hostfile) 78 | filenames = 
get_files(args.file_list)
79 |
80 |     # get basename of filenames
81 |     basename_list = []
82 |     for filename in filenames:
83 |         basename = os.path.splitext(filename)[0]
84 |         basename_list.append(basename)
85 |
86 |     # output dir
87 |     output_path_list = []
88 |     for basename in basename_list:
89 |         output_path = os.path.join(args.output_dir, basename)
90 |         output_path_list.append(output_path)
91 |         os.makedirs(output_path, exist_ok=True)
92 |
93 |     for idx, filename in enumerate(filenames):
94 |         ip = ip_list[idx]
95 |         input_path = os.path.join(args.input_dir, filename)
96 |         # get input path lines number
97 |         with open(input_path, 'r', encoding='utf-8') as fin:
98 |             total_lines = sum(1 for _ in fin)
99 |             fin.seek(0)
100 |         if total_lines < 8:
101 |             print(f"[Warning]: current gpu_ids: {args.gpu_ids}")
102 |             args.gpu_ids = args.gpu_ids[:total_lines]  # note: this truncation persists for the remaining files in the loop
103 |             print(f"[Warning]: {input_path} has less than 8 lines, change gpu_ids to {args.gpu_ids}")
104 |         cmd = get_cmd(args, input_path, output_path_list[idx])
105 |
106 |         meta_cmd = 'pdsh -R ssh -w {} {} &'.format(ip, cmd)
107 |
108 |         print(f"ip: {ip}, cmd: {cmd}")
109 |         os.system(meta_cmd)
110 |
111 |
112 | if __name__ == '__main__':
113 |     main()
--------------------------------------------------------------------------------
/prolong/scripts/filelist:
--------------------------------------------------------------------------------
1 | file0.jsonl
2 | file1.jsonl
3 | file2.jsonl
--------------------------------------------------------------------------------
/prolong/scripts/iphost:
--------------------------------------------------------------------------------
1 | ip1 slots=8
2 | ip2 slots=8
3 | ip3 slots=8
--------------------------------------------------------------------------------
/prolong/scripts/run_multinodes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ux
3 |
4 | INPUT_DIR=
5 | FILE_LIST=scripts/filelist # this will be combined with the input dir to get the full path of each file, e.g. /file0.jsonl
6 | OUTPUT_DIR=
7 | IP_HOSTFILE=scripts/iphost
8 | MODEL_PATH=
9 | SAMPLE_SIZE=500
10 | CHUNK_SIZE=128
11 | WINDOW_SIZE=32768
12 | SINGLE_PPL_BATCH_SIZE=256
13 | PAIR_PPL_BATCH_SIZE=256
14 | SEED=11
15 | DLT_PPL_THRESHOLD=0.1
16 |
17 | python run_batch_multinodes.py \
18 |     --input-dir $INPUT_DIR \
19 |     --output-dir $OUTPUT_DIR \
20 |     --ip-hostfile $IP_HOSTFILE \
21 |     --file-list $FILE_LIST \
22 |     --model_name $MODEL_PATH \
23 |     --chunk_size $CHUNK_SIZE \
24 |     --window_size $WINDOW_SIZE \
25 |     --single_ppl_batch_size $SINGLE_PPL_BATCH_SIZE \
26 |     --pair_ppl_batch_size $PAIR_PPL_BATCH_SIZE \
27 |     --sample_size $SAMPLE_SIZE \
28 |     --use_flash_attention_2 \
29 |     --need_draw \
30 |     --seed $SEED \
31 |     --dlt_ppl_threshold $DLT_PPL_THRESHOLD
--------------------------------------------------------------------------------
/prolong/scripts/run_multiprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ux
3 |
4 | # hyper-parameters
5 | MODEL_PATH=facebook/opt-350m
6 | CHUNK_SIZE=128
7 | WINDOW_SIZE=32768
8 | SINGLE_PPL_BATCH_SIZE=256
9 | PAIR_PPL_BATCH_SIZE=256
10 | SAMPLE_SIZE=500
11 | SEED=11
12 | DLT_PPL_THRESHOLD=0.1
13 |
14 | # output settings
15 | DATA_FILE_PATH=
16 | ROOT_PATH=
17 |
18 | python run_batch_multi_process.py \
19 |     --data_file $DATA_FILE_PATH \
20 |     --root_path $ROOT_PATH \
21 |     --model_name $MODEL_PATH \
22 |     --use_flash_attention_2 \
23 |     --chunk_size $CHUNK_SIZE \
24 |     --window_size $WINDOW_SIZE \
25 |     --single_ppl_batch_size $SINGLE_PPL_BATCH_SIZE \
26 |     --pair_ppl_batch_size $PAIR_PPL_BATCH_SIZE \
27 |     --sample_size $SAMPLE_SIZE \
28 |     --need_draw \
29 |     --seed $SEED \
30 |     --dlt_ppl_threshold $DLT_PPL_THRESHOLD \
31 |     --gpu_ids 0 1 2 3
32 |
33 |
--------------------------------------------------------------------------------
/prolong/scripts/run_single.sh:
--------------------------------------------------------------------------------
1 | DATA_FILE_PATH=
2 | SAVE_PATH=
3 | SCORE_PATH=
4 | PIC_DIR=
5 | CHUNK_SIZE=128
6 | DLT_PPL_THRESHOLD=0.1
7 | WINDOW_SIZE=32768
8 | MODEL_PATH=facebook/opt-350m
9 | SEED=11
10 | SINGLE_PPL_BATCH_SIZE=256
11 | PAIR_PPL_BATCH_SIZE=256
12 | SAMPLE_SIZE=500
13 |
14 | python run.py \
15 |     --data_file $DATA_FILE_PATH \
16 |     --save_file $SAVE_PATH \
17 |     --score_file $SCORE_PATH \
18 |     --pic_dir $PIC_DIR \
19 |     --chunk_size $CHUNK_SIZE \
20 |     --dlt_ppl_threshold $DLT_PPL_THRESHOLD \
21 |     --window_size $WINDOW_SIZE \
22 |     --model_name $MODEL_PATH \
23 |     --seed $SEED \
24 |     --single_ppl_batch_size $SINGLE_PPL_BATCH_SIZE \
25 |     --pair_ppl_batch_size $PAIR_PPL_BATCH_SIZE \
26 |     --sample_size $SAMPLE_SIZE
--------------------------------------------------------------------------------
/prolong/tools/sort_and_get.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datasets import load_dataset
3 |
4 | def sort_scores(in_filename, out_filename):
5 |     with open(in_filename, 'r', encoding='utf-8') as f:
6 |         scores = [(float(line.strip()), i) for i, line in enumerate(f)]
7 |     scores.sort(reverse=True)
8 |
9 |     scores_json = [{'Index': i, 'Score': score} for score, i in scores]
10 |
11 |     with open(out_filename, 'w', encoding='utf-8') as f:
12 |         for score in scores_json:
13 |             f.write(json.dumps(score) + '\n')
14 |
15 |     print('Done!')
16 |
17 |
18 | def get_dataset(dataset_path, index_path, output_path):
19 |     '''
20 |     dataset_path: str, the jsonl path of the dataset
21 |         e.g. 
{'text': 'This is a sentence.'} 22 | {'text': 'This is another sentence.'} 23 | index_path: str, the jsonl path of the index file 24 | e.g. {'Index': 1, 'Score': 1.5} 25 | {'Index': 0, 'Score': 0.8} 26 | output_path: str, the jsonl path of the output file 27 | e.g. {'text': 'This is another sentence.'} 28 | ''' 29 | dataset = load_dataset('text', data_files=dataset_path) 30 | with open(index_path, 'r', encoding='utf-8') as i: 31 | idx_data = i.readlines() 32 | 33 | idx_data = idx_data[:len(idx_data)//2] # choose top 50% TODO: use parameter to control the percentage 34 | index_list = [int(json.loads(idx)['Index']) for idx in idx_data] 35 | new_dataset = dataset['train'].select(index_list) 36 | 37 | with open(output_path, 'w', encoding='utf-8') as f: 38 | for row in new_dataset: 39 | text = row['text'] 40 | f.write(text + '\n') 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.1 2 | transformers==4.36.0 3 | tqdm==4.66.1 4 | numpy==1.22.2 5 | matplotlib==3.8.0 6 | datasets==2.15.0 --------------------------------------------------------------------------------