├── README.md ├── RuleRAG Rule-Guided Retrieval-Augmented Generation with Language Models for Question Answering.pdf ├── RuleRAG_eval_utils.py ├── RuleRAG_evaler.py ├── RuleRAG_inference.py ├── basic.py ├── config.py ├── data.png ├── framework.png ├── lab ├── constants.py ├── create.py ├── generate.py ├── model.py └── rotary.py ├── main_train.py ├── name.png ├── neox.py ├── paper.pdf └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | **RuleRAG: Rule-guided Retrieval-Augmented Generation with Language Models for Question Answering** 2 | ==== 3 | 4 | **** 5 | 6 | 7 | This is the official implementation repository of the paper *RuleRAG: Rule-Guided Retrieval-Augmented Generation with Language Models for Question Answering*. 8 | 9 | **** 10 | ![image](framework.png) 11 | The framework of our proposed RuleRAG, including RuleRAG-ICL and RuleRAG-FT. RuleRAG-ICL relies on in-context learning (ICL) with the guidance of rules. RuleRAG-FT fine-tunes the retrievers and generators in advance. 12 | 13 | 14 | **** 15 | 16 | The five RuleQA benchmarks we constructed are stored on [Google Drive](https://drive.google.com/drive/folders/13tbJS-Eq3Cswck3JRPU0LJIZ1Vrz3Bga?usp=sharing). 17 | 18 | ![image](data.png) 19 | The statistics of the five RuleQA benchmarks. 20 | 21 | 22 | **** 23 | 24 | Train the retrievers of RuleRAG: We use the public implementation of the retrievers. The training data is stored on [Google Drive](https://drive.google.com/drive/folders/13tbJS-Eq3Cswck3JRPU0LJIZ1Vrz3Bga?usp=sharing). 25 | 26 | Train the generators of RuleRAG: main_train.py 27 | 28 | Inference of RuleRAG: RuleRAG_inference.py 29 | 30 | -------------------------------------------------------------------------------- /RuleRAG Rule-Guided Retrieval-Augmented Generation with Language Models for Question Answering.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhongwu20/RuleRAG_ICL_FT/79a4a7742af9e172516044a79a75a417c1089f91/RuleRAG Rule-Guided Retrieval-Augmented Generation with Language Models for Question Answering.pdf -------------------------------------------------------------------------------- /RuleRAG_eval_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaForCausalLM, BitsAndBytesConfig 2 | from peft import PeftModel 3 | import torch 4 | # from vllm import LLM, SamplingParams 5 | # from vllm.lora.request import LoRARequest 6 | 7 | import argparse 8 | import csv 9 | import re 10 | from basic import get_file_extension 11 | 12 | def read_results(path_results, divider=', \n'): 13 | with open(path_results, 'r', encoding='utf-8') as file: 14 | content = file.read() 15 | tests = content.split(', \n\n') 16 | return [re.split(divider, test) for test in tests] 17 | 18 | def read_num_and_li_results(test): 19 | pattern_end = r'(.+?)( \n|"]})' 20 | num_Test = re.search(r'\d+', test[0]).group() if re.search(r'\d+', test[0]) else "" 21 | li_results = [re.match(pattern_end, answer).group(1) if re.match(pattern_end, answer) else answer \ 22 | for answer in test[1:]] 23 | return num_Test, li_results 24 | 25 | def read_test_and_divide(path): 26 | with open(path, 'r', encoding='utf-8') as file: 27 | content = file.read() 28 | tests = content.split('\n\n') 29 | return tests 30 | 31 | def read_test_an(pth_ans, col=2): 32 | file_type = get_file_extension(pth_ans) 33 | test_ans = [] 34 | if file_type == ".csv": 35 | with open(pth_ans, 
"r", encoding='utf-8') as f: 36 | reader = csv.reader(f) 37 | test_ans = [row1[col] for row1 in reader] 38 | test_ans = test_ans[1:] # 39 | else: 40 | with open(pth_ans, "r", encoding='utf-8') as f: 41 | lines = f.readlines() 42 | for i in range(len(lines)): 43 | test_ans.append(lines[i].split('\t')[col]) 44 | 45 | return test_ans 46 | 47 | def read_test_ans_rel(pth_ans, col=2): 48 | file_type = get_file_extension(pth_ans) 49 | test_ans = [] 50 | relation_query = [] 51 | if file_type == ".csv": 52 | with open(pth_ans, "r", encoding='utf-8') as f: 53 | reader = csv.reader(f) 54 | test_ans = [row1[col] for row1 in reader] 55 | test_ans = test_ans[1:] # 56 | else: 57 | with open(pth_ans, "r", encoding='utf-8') as f: 58 | lines = f.readlines() 59 | for i in range(len(lines)): 60 | test_ans.append(lines[i].split('\t')[col]) 61 | relation_query.append(lines[i].split('\t')[1]) 62 | 63 | return test_ans,relation_query 64 | 65 | def read_last_metric(last_metric): 66 | if last_metric != '': 67 | with open(last_metric, 'r') as file: 68 | 69 | lines = file.readlines() 70 | 71 | last_c_k = lines[-4:-1] # [-5:-2] 72 | 73 | last_c_k = [int( line.strip() ) for line in last_c_k] 74 | 75 | c1 = int(last_c_k[0]) 76 | 77 | else: # 78 | c1 = 0 # 79 | return {"c1": c1} #c 80 | 81 | def decide_model(args): 82 | if args.BIT_8: # 83 | model = LlamaForCausalLM.from_pretrained( 84 | "/home/" 85 | "Llama-2-7B-fp16", 86 | trust_remote_code=True, 87 | ).cuda() 88 | elif args.BIT_4: 89 | quant_config = BitsAndBytesConfig( 90 | load_in_4bit=True, 91 | bnb_4bit_use_double_quant=True, 92 | bnb_4bit_quant_type="nf4", 93 | bnb_4bit_compute_dtype=torch.bfloat16 94 | ) 95 | model = LlamaForCausalLM.from_pretrained( 96 | "/home/", 97 | 98 | quantization_config=quant_config, 99 | device_map="auto", 100 | trust_remote_code=True, 101 | ) 102 | 103 | return PeftModel.from_pretrained(model, args.LORA_CHECKPOINT_DIR) 104 | 105 | 106 | 107 | def parse_args(): 108 | parser = argparse.ArgumentParser(description="Config of Llama2") 109 | 110 | parser.add_argument('--MODEL_NAME', type=str, default="", help='Model name') 111 | parser.add_argument('--LORA_CHECKPOINT_DIR', type=str, default="", help='Your Lora checkpoint') 112 | parser.add_argument('--CONTEXT_LEN', type=int, default=20000, help='Truncation length of context (in json)') 113 | parser.add_argument('--BIT_8', default=False, action="store_true", help='Use 8-bit') 114 | parser.add_argument('--BIT_4', default=True, action="store_true", help='Use 4-bit') 115 | parser.add_argument('--TEMPERATURE', type=int, default=0, help='Temperature when inference') 116 | parser.add_argument('--PROMPT', type=str, default="Input your prompt", help='Your prompt when inference') 117 | parser.add_argument('--input_file', type=str, default="", help='Your history_facts file') 118 | parser.add_argument('--output_file', type=str, default="", help='Output text prediction') 119 | parser.add_argument('--test_ans_file', type=str, default="", help='Your ground truth file') 120 | parser.add_argument('--fulltest', type=str, default="", help='fulltest with dense quadruples. For whole set filtering') 121 | parser.add_argument('--time2id', type=str, default="", help='time2id json file. For whole set filtering') 122 | parser.add_argument('--begin', type=int, default=0, help='Where to continue. 
default to -1') 123 | parser.add_argument('--max_gen_len', type=int, default=27, help='18 for Gdelt 27 for Yago; 27 as default') 124 | parser.add_argument('--max_seq_len', type=int, default=4096, help='4096 for llama2 icl; 30 as default') 125 | parser.add_argument('--last_metric', type=str, default="", help='Last metric result *file* when interrupted. ') 126 | parser.add_argument('--FILTER', type=int, default=1, help='Set 1 to filter multiple objects. ') 127 | parser.add_argument('--path_results', type=str, default="", help='Path of the result file to be filtered. ') 128 | parser.add_argument('--ft', type=int, default=1, help='Set 0: unfinetuned model. ') 129 | parser.add_argument('--local-rank', type=int, default=0, help='for torch.distributed.launch. ') 130 | parser.add_argument('--instruct_yes', type=int, default=0, help='Set 0 to give no instruction in the pompts. ') 131 | 132 | return parser.parse_args() -------------------------------------------------------------------------------- /RuleRAG_evaler.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList 4 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 5 | from transformers.generation.stopping_criteria import ( 6 | MaxLengthCriteria, 7 | MaxTimeCriteria, 8 | StoppingCriteria, 9 | StoppingCriteriaList, 10 | validate_stopping_criteria, 11 | ) 12 | import numpy as np 13 | import copy 14 | import inspect 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | import torch 19 | from torch import nn 20 | import torch.distributed as dist 21 | # from vllm import LLM, SamplingParams 22 | # from vllm.lora.request import LoRARequest 23 | from transformers.deepspeed import is_deepspeed_zero3_enabled 24 | from transformers.utils import logging 25 | from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint 26 | from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer 27 | from transformers.generation.configuration_utils import GenerationConfig 28 | from transformers.generation.utils import ( 29 | GreedySearchEncoderDecoderOutput, 30 | GreedySearchDecoderOnlyOutput, 31 | BeamSearchEncoderDecoderOutput, 32 | BeamSearchDecoderOnlyOutput, 33 | ) 34 | from transformers.generation.logits_process import LogitsProcessorList 35 | from tqdm import tqdm 36 | import time 37 | from data_utils.basic import read_txt_as_list, read_json 38 | from RuleRAG_eval_utils import read_results, read_num_and_li_results 39 | from collections import defaultdict, Counter 40 | import json 41 | import os 42 | import sys 43 | import time 44 | from pathlib import Path 45 | from typing import List, Literal, Optional, Tuple, TypedDict 46 | 47 | import torch 48 | import torch.nn.functional as F 49 | from fairscale.nn.model_parallel.initialize import ( 50 | get_model_parallel_rank, 51 | initialize_model_parallel, 52 | model_parallel_is_initialized, 53 | ) 54 | 55 | from neox import init_neox, text_generation 56 | from basic import blockPrinting 57 | import os 58 | 59 | 60 | Role = Literal["system", "user", "assistant"] 61 | 62 | if TYPE_CHECKING: 63 | from transformers.modeling_utils import PreTrainedModel 64 | from transformers.generation.streamers import BaseStreamer 65 | 66 | logger = logging.get_logger(__name__) 67 | m_F1=[] 68 | 69 | class Evaler: 70 | def __init__(self, topk, tests, test_ans, 71 | 
eval_txt_path, args, 72 | model=None, tokenizer=None, patterns=None, 73 | early_stop_chars=None, obligations=[]): 74 | 75 | self.llama = 1 76 | self.model = model 77 | self.tokenizer = tokenizer 78 | self.patterns = patterns 79 | self.tests = tests 80 | self.test_ans = test_ans 81 | self.eval_txt_path = eval_txt_path 82 | self.topk = topk 83 | 84 | self.args = args 85 | 86 | self.obligations = obligations 87 | self.constraints = [] 88 | self.zone_zero = early_stop_chars 89 | 90 | self.first_check = 0 91 | self.top = 1 92 | f_entity2id = open( 93 | '/home/', 94 | 'r') 95 | entity_json = f_entity2id.read() 96 | self.entity_set = json.loads(entity_json) 97 | 98 | def restrict_list_hard(self, tokens, prev_pos, min_prompt_len, input_text_mask, eos_reached, m=0): 99 | logits = self.model.forward(tokens[:, prev_pos:min_prompt_len], prev_pos) 100 | logits_last = logits[:, -1] 101 | top_10_indices = torch.topk(logits_last, k=logits.shape[-1], dim=-1).indices 102 | values_to_extract = [29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929] 103 | top_10_indices_np = top_10_indices.cpu().numpy() 104 | mask = np.isin(top_10_indices_np, values_to_extract) 105 | extracted_elements = top_10_indices_np[mask][:10] 106 | # Convert back to Tensor type 107 | top_10_indices = torch.tensor(extracted_elements) 108 | 109 | next_token = top_10_indices[m] 110 | next_token = next_token.reshape(-1) 111 | 112 | next_token = torch.where( 113 | input_text_mask[:, min_prompt_len], tokens[:, min_prompt_len], next_token 114 | ) 115 | 116 | tokens[:, min_prompt_len] = next_token 117 | eos_reached |= (~input_text_mask[:, min_prompt_len]) & ( 118 | next_token == self.tokenizer.eos_id 119 | ) 120 | 121 | self.first_check = 1 # skip first_check 122 | return next_token, eos_reached 123 | 124 | def first_checking(self, next_tokens, next_tokens_scores): 125 | this_peer_finished = False 126 | if self.first_check == 0: # first check 127 | if self.obligations and (next_tokens not in self.obligations): 128 | this_peer_finished = True 129 | self.first_check = -1 # not begin with nums 130 | 131 | if self.constraints and (next_tokens in self.constraints): 132 | self.top += 1 133 | next_tokens = torch.argsort(next_tokens_scores, dim=-1, descending=True)[:, self.top - 1] 134 | self.constraints.append(next_tokens) 135 | self.first_check = -1 # breach of obligs 136 | else: 137 | self.constraints.append(next_tokens) 138 | self.first_check = 1 # check sign passed 139 | return this_peer_finished, next_tokens 140 | 141 | def gen_set_ans(self, tests='', dir_full_test='', dir_time2id=''): 142 | 143 | if tests == '': 144 | tests = self.tests 145 | dict_qu_ans = {} 146 | if dir_full_test == '': 147 | full_test_ans = self.test_ans 148 | for i in tqdm(range(0, len(tests) - 1)): 149 | try: 150 | query = tests[i].split('Question')[3] 151 | except: 152 | query = tests[i].split('Question')[2] 153 | if query == '': 154 | break 155 | if dict_qu_ans.get(query) == None: 156 | dict_qu_ans[query] = set() 157 | dict_qu_ans[query].add(full_test_ans[i]) # add answers to the set 158 | # time.sleep(0.001) 159 | else: 160 | dict_t2id = {} 161 | if dir_time2id != '': 162 | dict_t2id = read_json(dir_time2id) 163 | else: 164 | print("Attention: icews18 needs its ts2id file to convert time into time_id") 165 | fulltest = read_txt_as_list(dir_full_test) # only load essentially 166 | li_queries = [test.split('\n')[-1] for test in tests] 167 | # build sets 168 | for i in range(0, len(li_queries) - 1): 169 | query = li_queries[i] 170 | if query == '': 171 | 
break 172 | if dict_qu_ans.get(query) is None: 173 | dict_qu_ans[query] = set() 174 | end_time = li_queries[-3].split(':')[0] 175 | for line in fulltest: 176 | quadruple = line.strip().split('\t') 177 | time_quadruple = dict_t2id[quadruple[3]] if dir_time2id != '' else quadruple[3] 178 | if int(time_quadruple) > int(end_time): 179 | break 180 | built_query = f"{time_quadruple}: [{quadruple[0]}, {quadruple[1]}," 181 | if dict_qu_ans.get(built_query) is not None: 182 | dict_qu_ans[built_query].add(quadruple[2]) # add answers to the set 183 | print("duplicate answers checked") 184 | return dict_qu_ans 185 | 186 | def generate_extra_answers(self, m_inloop, k_inloop): 187 | if self.args.ft == 1: 188 | raw_answers, answer_regs = self.model_calling(m_inloop) # call for more generated ans 189 | elif self.llama == 1: # icl llama2 190 | answer_regs = self.text_completion(m_inloop, 191 | str(self.args.PROMPT), 192 | max_gen_len=self.args.max_gen_len, 193 | temperature=self.args.TEMPERATURE, 194 | # top_p=top_p, 195 | ) 196 | answer_regs = [answer_reg['generation'] for answer_reg in answer_regs] 197 | raw_answers = answer_regs 198 | else: # icl gpt neox 199 | raw_answers = text_generation(m_inloop, k_inloop, self.model, self.tokenizer, 200 | str(self.args.PROMPT), 201 | # icews14 28, icews18 34, ecola 18, GDELT 16, YAGO 25. 202 | max_seq_len=34, 203 | verbose=False) 204 | pattern = re.compile(r'\s*(\d+)\.(.*?)\]') 205 | answer_regs = re.match(pattern, raw_answers).group(2).strip() \ 206 | if re.match(pattern, raw_answers) else raw_answers 207 | answer_regs = [answer_regs] 208 | return raw_answers, answer_regs 209 | 210 | def my_generate_top10(self, model_instance, m, gen_length, **kwargs): 211 | # base_model = model_instance.base_model 212 | base_model = model_instance 213 | 214 | # original prepare_inputs_for_generation and generation_config 215 | original_prepare_inputs_for_generation = base_model.prepare_inputs_for_generation 216 | original_generation_config = getattr(base_model, "generation_config", None) 217 | 218 | # prepare_inputs_for_generation and generation_config 219 | base_model.prepare_inputs_for_generation = model_instance.prepare_inputs_for_generation 220 | if hasattr(base_model, "model"): 221 | base_model.model.generation_config = model_instance.generation_config 222 | else: 223 | base_model.generation_config = model_instance.generation_config 224 | 225 | try: 226 | # base_model generate_top10 227 | outputs = self.my_utils_generate_top10(base_model, m, gen_length, **kwargs) 228 | except Exception as e: 229 | # prepare_inputs_for_generation 230 | base_model.prepare_inputs_for_generation = original_prepare_inputs_for_generation 231 | if original_generation_config is not None: 232 | base_model.generation_config = original_generation_config 233 | raise e 234 | else: 235 | base_model.prepare_inputs_for_generation = original_prepare_inputs_for_generation 236 | # recover generation_config 237 | if original_generation_config is not None: 238 | base_model.generation_config = original_generation_config 239 | return outputs 240 | 241 | @torch.no_grad() 242 | def my_utils_generate_top10(self, 243 | model_instance, m, 244 | gen_length, 245 | inputs: Optional[torch.Tensor] = None, 246 | generation_config: Optional[GenerationConfig] = None, 247 | logits_processor: Optional[LogitsProcessorList] = None, 248 | stopping_criteria: Optional[StoppingCriteriaList] = None, 249 | # max_length=max_length, 250 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 251 | 
synced_gpus: Optional[bool] = None, 252 | assistant_model: Optional["PreTrainedModel"] = None, 253 | streamer: Optional["BaseStreamer"] = None, 254 | **kwargs, 255 | ): 256 | 257 | if synced_gpus is None: 258 | if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: 259 | synced_gpus = True 260 | else: 261 | synced_gpus = False 262 | 263 | if generation_config is None: 264 | if model_instance.generation_config._from_model_config: 265 | new_generation_config = GenerationConfig.from_model_config(model_instance.config) 266 | if new_generation_config != model_instance.generation_config: 267 | warnings.warn( 268 | "You have modified the pretrained model configuration to control generation. This is a" 269 | " deprecated strategy to control generation and will be removed soon, in a future version." 270 | " Please use a generation configuration file (see" 271 | " https://huggingface.co/docs/transformers/main_classes/text_generation)" 272 | ) 273 | model_instance.generation_config = new_generation_config 274 | generation_config = model_instance.generation_config 275 | 276 | generation_config = copy.deepcopy(generation_config) 277 | model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs 278 | generation_config.validate() 279 | model_instance._validate_model_kwargs(model_kwargs.copy()) 280 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 281 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 282 | 283 | if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: 284 | if model_kwargs.get("attention_mask", None) is None: 285 | logger.warning( 286 | "The attention mask and the pad token id were not set. As a consequence, you may observe " 287 | "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 288 | ) 289 | eos_token_id = generation_config.eos_token_id 290 | if isinstance(eos_token_id, list): 291 | eos_token_id = eos_token_id[0] 292 | logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") 293 | generation_config.pad_token_id = eos_token_id 294 | 295 | inputs_tensor, model_input_name, model_kwargs = model_instance._prepare_model_inputs( 296 | inputs, generation_config.bos_token_id, model_kwargs 297 | ) 298 | batch_size = inputs_tensor.shape[0] 299 | 300 | model_kwargs["output_attentions"] = generation_config.output_attentions 301 | model_kwargs["output_hidden_states"] = generation_config.output_hidden_states 302 | model_kwargs["use_cache"] = generation_config.use_cache 303 | 304 | accepts_attention_mask = "attention_mask" in set(inspect.signature(model_instance.forward).parameters.keys()) 305 | requires_attention_mask = "encoder_outputs" not in model_kwargs 306 | 307 | if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: 308 | model_kwargs["attention_mask"] = model_instance._prepare_attention_mask_for_generation( 309 | inputs_tensor, inputs_tensor, inputs_tensor 310 | ) 311 | 312 | if not model_instance.config.is_encoder_decoder: 313 | if ( 314 | generation_config.pad_token_id is not None 315 | and len(inputs_tensor.shape) == 2 316 | and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 317 | ): 318 | logger.warning( 319 | "A decoder-only architecture is being used, but right-padding was detected! 
For correct " 320 | "generation results, please set `padding_side='left'` when initializing the tokenizer." 321 | ) 322 | 323 | if model_instance.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: 324 | model_kwargs = model_instance._prepare_encoder_decoder_kwargs_for_generation( 325 | inputs_tensor, model_kwargs, model_input_name 326 | ) 327 | 328 | if model_instance.config.is_encoder_decoder: 329 | input_ids, model_kwargs = model_instance._prepare_decoder_input_ids_for_generation( 330 | batch_size=batch_size, 331 | model_input_name=model_input_name, 332 | model_kwargs=model_kwargs, 333 | decoder_start_token_id=generation_config.decoder_start_token_id, 334 | bos_token_id=generation_config.bos_token_id, 335 | device=inputs_tensor.device, 336 | ) 337 | else: 338 | input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") 339 | 340 | if streamer is not None: 341 | streamer.put(input_ids.cpu()) 342 | 343 | input_ids_seq_length = input_ids.shape[-1] 344 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None 345 | if has_default_max_length and generation_config.max_new_tokens is None: 346 | warnings.warn( 347 | f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " 348 | "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" 349 | " recommend using `max_new_tokens` to control the maximum length of the generation.", 350 | UserWarning, 351 | ) 352 | elif generation_config.max_new_tokens is not None: 353 | if not has_default_max_length: 354 | logger.warning( 355 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" 356 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " 357 | "Please refer to the documentation for more information. " 358 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" 359 | ) 360 | generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length 361 | 362 | if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: 363 | raise ValueError( 364 | f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" 365 | f" the maximum length ({generation_config.max_length})" 366 | ) 367 | if input_ids_seq_length >= generation_config.max_length: 368 | input_ids_string = "decoder_input_ids" if model_instance.config.is_encoder_decoder else "input_ids" 369 | logger.warning( 370 | f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" 371 | f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" 372 | " increasing `max_new_tokens`." 
373 | ) 374 | 375 | is_constraint_gen_mode = ( 376 | generation_config.constraints is not None or generation_config.force_words_ids is not None 377 | ) 378 | 379 | is_contrastive_search_gen_mode = ( 380 | (generation_config.num_beams == 1) 381 | and generation_config.top_k is not None 382 | and generation_config.top_k > 1 383 | and generation_config.do_sample is False 384 | and generation_config.penalty_alpha is not None 385 | and generation_config.penalty_alpha > 0 386 | ) 387 | 388 | is_greedy_gen_mode = ( 389 | (generation_config.num_beams == 1) 390 | and (generation_config.num_beam_groups == 1) 391 | and generation_config.do_sample is False 392 | and not is_constraint_gen_mode 393 | and not is_contrastive_search_gen_mode 394 | ) 395 | is_sample_gen_mode = ( 396 | (generation_config.num_beams == 1) 397 | and (generation_config.num_beam_groups == 1) 398 | and generation_config.do_sample is True 399 | and not is_constraint_gen_mode 400 | and not is_contrastive_search_gen_mode 401 | ) 402 | is_beam_gen_mode = ( 403 | (generation_config.num_beams > 1) 404 | and (generation_config.num_beam_groups == 1) 405 | and generation_config.do_sample is False 406 | and not is_constraint_gen_mode 407 | and not is_contrastive_search_gen_mode 408 | ) 409 | is_beam_sample_gen_mode = ( 410 | (generation_config.num_beams > 1) 411 | and (generation_config.num_beam_groups == 1) 412 | and generation_config.do_sample is True 413 | and not is_constraint_gen_mode 414 | and not is_contrastive_search_gen_mode 415 | ) 416 | is_group_beam_gen_mode = ( 417 | (generation_config.num_beams > 1) 418 | and (generation_config.num_beam_groups > 1) 419 | and not is_constraint_gen_mode 420 | and not is_contrastive_search_gen_mode 421 | ) 422 | is_assisted_gen_mode = False 423 | if assistant_model is not None: 424 | if not (is_greedy_gen_mode or is_sample_gen_mode): 425 | raise ValueError( 426 | "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate " 427 | "is only supported with Greedy Search and Sample." 428 | ) 429 | is_assisted_gen_mode = True 430 | 431 | if generation_config.num_beam_groups > generation_config.num_beams: 432 | raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") 433 | if is_group_beam_gen_mode and generation_config.do_sample is True: 434 | raise ValueError( 435 | "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." 436 | ) 437 | 438 | if streamer is not None and (generation_config.num_beams > 1): 439 | raise ValueError( 440 | "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." 441 | ) 442 | 443 | if model_instance.device.type != input_ids.device.type: 444 | warnings.warn( 445 | "You are calling .generate() with the `input_ids` being on a device type different" 446 | f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" 447 | f" is on {model_instance.device.type}. You may experience unexpected behaviors or slower generation." 
448 | " Please make sure that you have put `input_ids` to the" 449 | f" correct device by calling for example input_ids = input_ids.to('{model_instance.device.type}') before" 450 | " running `.generate()`.", 451 | UserWarning, 452 | ) 453 | 454 | logits_processor = model_instance._get_logits_processor( 455 | generation_config=generation_config, 456 | input_ids_seq_length=input_ids_seq_length, 457 | encoder_input_ids=inputs_tensor, 458 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 459 | logits_processor=logits_processor, 460 | ) 461 | 462 | stopping_criteria = model_instance._get_stopping_criteria( 463 | generation_config=generation_config, stopping_criteria=stopping_criteria 464 | ) 465 | if is_assisted_gen_mode: 466 | if generation_config.num_return_sequences > 1: 467 | raise ValueError( 468 | "num_return_sequences has to be 1 when doing assisted generate, " 469 | f"but is {generation_config.num_return_sequences}." 470 | ) 471 | if batch_size > 1: 472 | raise ValueError("assisted generate is only supported for batch_size = 1") 473 | if not model_kwargs["use_cache"]: 474 | raise ValueError("assisted generate requires `use_cache=True`") 475 | 476 | if assistant_model.config.is_encoder_decoder: 477 | assistant_model_kwargs = copy.deepcopy(model_kwargs) 478 | inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( 479 | inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs 480 | ) 481 | assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( 482 | inputs_tensor, assistant_model_kwargs, model_input_name 483 | ) 484 | model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] 485 | 486 | return model_instance.assisted_decoding( 487 | input_ids, 488 | assistant_model=assistant_model, 489 | do_sample=generation_config.do_sample, 490 | logits_processor=logits_processor, 491 | logits_warper=model_instance._get_logits_warper( 492 | generation_config) if generation_config.do_sample else None, 493 | stopping_criteria=stopping_criteria, 494 | pad_token_id=generation_config.pad_token_id, 495 | eos_token_id=generation_config.eos_token_id, 496 | output_scores=generation_config.output_scores, 497 | return_dict_in_generate=generation_config.return_dict_in_generate, 498 | synced_gpus=synced_gpus, 499 | streamer=streamer, 500 | **model_kwargs, 501 | ) 502 | if is_greedy_gen_mode: 503 | if generation_config.num_return_sequences > 1: 504 | raise ValueError( 505 | "num_return_sequences has to be 1 when doing greedy search, " 506 | f"but is {generation_config.num_return_sequences}." 507 | ) 508 | return self.my_utils_greedy_search_top10(model_instance, 509 | m, 510 | gen_length, 511 | input_ids, 512 | logits_processor=logits_processor, 513 | stopping_criteria=stopping_criteria, 514 | pad_token_id=generation_config.pad_token_id, 515 | eos_token_id=generation_config.eos_token_id, 516 | output_scores=generation_config.output_scores, 517 | return_dict_in_generate=generation_config.return_dict_in_generate, 518 | synced_gpus=synced_gpus, 519 | streamer=streamer, 520 | **model_kwargs, 521 | ) 522 | 523 | elif is_contrastive_search_gen_mode: 524 | if generation_config.num_return_sequences > 1: 525 | raise ValueError( 526 | "num_return_sequences has to be 1 when doing contrastive search, " 527 | f"but is {generation_config.num_return_sequences}." 
528 | ) 529 | if not model_kwargs["use_cache"]: 530 | raise ValueError("Contrastive search requires `use_cache=True`") 531 | 532 | return model_instance.contrastive_search( 533 | input_ids, 534 | top_k=generation_config.top_k, 535 | penalty_alpha=generation_config.penalty_alpha, 536 | logits_processor=logits_processor, 537 | stopping_criteria=stopping_criteria, 538 | pad_token_id=generation_config.pad_token_id, 539 | eos_token_id=generation_config.eos_token_id, 540 | output_scores=generation_config.output_scores, 541 | return_dict_in_generate=generation_config.return_dict_in_generate, 542 | synced_gpus=synced_gpus, 543 | streamer=streamer, 544 | **model_kwargs, 545 | ) 546 | 547 | elif is_sample_gen_mode: 548 | logits_warper = model_instance._get_logits_warper(generation_config) 549 | 550 | input_ids, model_kwargs = model_instance._expand_inputs_for_generation( 551 | input_ids=input_ids, 552 | expand_size=generation_config.num_return_sequences, 553 | is_encoder_decoder=model_instance.config.is_encoder_decoder, 554 | **model_kwargs, 555 | ) 556 | 557 | return model_instance.sample( 558 | input_ids, 559 | logits_processor=logits_processor, 560 | logits_warper=logits_warper, 561 | stopping_criteria=stopping_criteria, 562 | pad_token_id=generation_config.pad_token_id, 563 | eos_token_id=generation_config.eos_token_id, 564 | output_scores=generation_config.output_scores, 565 | return_dict_in_generate=generation_config.return_dict_in_generate, 566 | synced_gpus=synced_gpus, 567 | streamer=streamer, 568 | **model_kwargs, 569 | ) 570 | 571 | elif is_beam_gen_mode: 572 | if generation_config.num_return_sequences > generation_config.num_beams: 573 | raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") 574 | 575 | if stopping_criteria.max_length is None: 576 | raise ValueError("`max_length` needs to be a stopping_criteria for now.") 577 | 578 | beam_scorer = BeamSearchScorer( 579 | batch_size=batch_size, 580 | num_beams=generation_config.num_beams, 581 | device=inputs_tensor.device, 582 | length_penalty=generation_config.length_penalty, 583 | do_early_stopping=generation_config.early_stopping, 584 | num_beam_hyps_to_keep=generation_config.num_return_sequences, 585 | max_length=generation_config.max_length, 586 | ) 587 | input_ids, model_kwargs = model_instance._expand_inputs_for_generation( 588 | input_ids=input_ids, 589 | expand_size=generation_config.num_beams, 590 | is_encoder_decoder=model_instance.config.is_encoder_decoder, 591 | **model_kwargs, 592 | ) 593 | return model_instance.beam_search( 594 | input_ids, 595 | beam_scorer, 596 | logits_processor=logits_processor, 597 | stopping_criteria=stopping_criteria, 598 | pad_token_id=generation_config.pad_token_id, 599 | eos_token_id=generation_config.eos_token_id, 600 | output_scores=generation_config.output_scores, 601 | return_dict_in_generate=generation_config.return_dict_in_generate, 602 | synced_gpus=synced_gpus, 603 | **model_kwargs, 604 | ) 605 | 606 | elif is_beam_sample_gen_mode: 607 | logits_warper = model_instance._get_logits_warper(generation_config) 608 | 609 | if stopping_criteria.max_length is None: 610 | raise ValueError("`max_length` needs to be a stopping_criteria for now.") 611 | beam_scorer = BeamSearchScorer( 612 | batch_size=batch_size * generation_config.num_return_sequences, 613 | num_beams=generation_config.num_beams, 614 | device=inputs_tensor.device, 615 | length_penalty=generation_config.length_penalty, 616 | do_early_stopping=generation_config.early_stopping, 617 | 
max_length=generation_config.max_length, 618 | ) 619 | 620 | input_ids, model_kwargs = model_instance._expand_inputs_for_generation( 621 | input_ids=input_ids, 622 | expand_size=generation_config.num_beams * generation_config.num_return_sequences, 623 | is_encoder_decoder=model_instance.config.is_encoder_decoder, 624 | **model_kwargs, 625 | ) 626 | 627 | return model_instance.beam_sample( 628 | input_ids, 629 | beam_scorer, 630 | logits_processor=logits_processor, 631 | logits_warper=logits_warper, 632 | stopping_criteria=stopping_criteria, 633 | pad_token_id=generation_config.pad_token_id, 634 | eos_token_id=generation_config.eos_token_id, 635 | output_scores=generation_config.output_scores, 636 | return_dict_in_generate=generation_config.return_dict_in_generate, 637 | synced_gpus=synced_gpus, 638 | **model_kwargs, 639 | ) 640 | 641 | elif is_group_beam_gen_mode: 642 | if generation_config.num_return_sequences > generation_config.num_beams: 643 | raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") 644 | 645 | if generation_config.num_beams % generation_config.num_beam_groups != 0: 646 | raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") 647 | 648 | if stopping_criteria.max_length is None: 649 | raise ValueError("`max_length` needs to be a stopping_criteria for now.") 650 | 651 | has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 652 | if not has_default_typical_p: 653 | raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") 654 | 655 | beam_scorer = BeamSearchScorer( 656 | batch_size=batch_size, 657 | num_beams=generation_config.num_beams, 658 | device=inputs_tensor.device, 659 | length_penalty=generation_config.length_penalty, 660 | do_early_stopping=generation_config.early_stopping, 661 | num_beam_hyps_to_keep=generation_config.num_return_sequences, 662 | num_beam_groups=generation_config.num_beam_groups, 663 | max_length=generation_config.max_length, 664 | ) 665 | input_ids, model_kwargs = model_instance._expand_inputs_for_generation( 666 | input_ids=input_ids, 667 | expand_size=generation_config.num_beams, 668 | is_encoder_decoder=model_instance.config.is_encoder_decoder, 669 | **model_kwargs, 670 | ) 671 | return model_instance.group_beam_search( 672 | input_ids, 673 | beam_scorer, 674 | logits_processor=logits_processor, 675 | stopping_criteria=stopping_criteria, 676 | pad_token_id=generation_config.pad_token_id, 677 | eos_token_id=generation_config.eos_token_id, 678 | output_scores=generation_config.output_scores, 679 | return_dict_in_generate=generation_config.return_dict_in_generate, 680 | synced_gpus=synced_gpus, 681 | **model_kwargs, 682 | ) 683 | 684 | elif is_constraint_gen_mode: 685 | if generation_config.num_return_sequences > generation_config.num_beams: 686 | raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") 687 | 688 | if stopping_criteria.max_length is None: 689 | raise ValueError("`max_length` needs to be a stopping_criteria for now.") 690 | 691 | if generation_config.num_beams <= 1: 692 | raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") 693 | 694 | if generation_config.do_sample: 695 | raise ValueError("`do_sample` needs to be false for constrained generation.") 696 | 697 | if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1: 698 | raise ValueError("`num_beam_groups` not supported yet for constrained 
generation.") 699 | 700 | final_constraints = [] 701 | if generation_config.constraints is not None: 702 | final_constraints = generation_config.constraints 703 | 704 | if generation_config.force_words_ids is not None: 705 | 706 | def typeerror(): 707 | raise ValueError( 708 | "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" 709 | f"of positive integers, but is {generation_config.force_words_ids}." 710 | ) 711 | 712 | if ( 713 | not isinstance(generation_config.force_words_ids, list) 714 | or len(generation_config.force_words_ids) == 0 715 | ): 716 | typeerror() 717 | 718 | for word_ids in generation_config.force_words_ids: 719 | if isinstance(word_ids[0], list): 720 | if not isinstance(word_ids, list) or len(word_ids) == 0: 721 | typeerror() 722 | if any(not isinstance(token_ids, list) for token_ids in word_ids): 723 | typeerror() 724 | if any( 725 | any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) 726 | for token_ids in word_ids 727 | ): 728 | typeerror() 729 | 730 | constraint = DisjunctiveConstraint(word_ids) 731 | else: 732 | if not isinstance(word_ids, list) or len(word_ids) == 0: 733 | typeerror() 734 | if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): 735 | typeerror() 736 | 737 | constraint = PhrasalConstraint(word_ids) 738 | final_constraints.append(constraint) 739 | 740 | constrained_beam_scorer = ConstrainedBeamSearchScorer( 741 | constraints=final_constraints, 742 | batch_size=batch_size, 743 | num_beams=generation_config.num_beams, 744 | device=inputs_tensor.device, 745 | length_penalty=generation_config.length_penalty, 746 | do_early_stopping=generation_config.early_stopping, 747 | num_beam_hyps_to_keep=generation_config.num_return_sequences, 748 | max_length=generation_config.max_length, 749 | ) 750 | input_ids, model_kwargs = model_instance._expand_inputs_for_generation( 751 | input_ids=input_ids, 752 | expand_size=generation_config.num_beams, 753 | is_encoder_decoder=model_instance.config.is_encoder_decoder, 754 | **model_kwargs, 755 | ) 756 | return model_instance.constrained_beam_search( 757 | input_ids, 758 | constrained_beam_scorer=constrained_beam_scorer, 759 | logits_processor=logits_processor, 760 | stopping_criteria=stopping_criteria, 761 | pad_token_id=generation_config.pad_token_id, 762 | eos_token_id=generation_config.eos_token_id, 763 | output_scores=generation_config.output_scores, 764 | return_dict_in_generate=generation_config.return_dict_in_generate, 765 | synced_gpus=synced_gpus, 766 | **model_kwargs, 767 | ) 768 | 769 | def my_utils_greedy_search_top10(self, 770 | model_instance, 771 | gen_length, 772 | input_ids: torch.LongTensor, 773 | logits_processor: Optional[LogitsProcessorList] = None, 774 | stopping_criteria: Optional[StoppingCriteriaList] = None, 775 | max_length: Optional[int] = None, 776 | pad_token_id: Optional[int] = None, 777 | eos_token_id: Optional[Union[int, List[int]]] = None, 778 | output_attentions: Optional[bool] = None, 779 | output_hidden_states: Optional[bool] = None, 780 | output_scores: Optional[bool] = None, 781 | return_dict_in_generate: Optional[bool] = None, 782 | synced_gpus: bool = False, 783 | streamer: Optional["BaseStreamer"] = None, 784 | **model_kwargs, 785 | ): 786 | 787 | # init values 788 | stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=gen_length + input_ids.shape[1])]) 789 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 790 | stopping_criteria = 
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 791 | if max_length is not None: 792 | warnings.warn( 793 | "`max_length` is deprecated in this function, use" 794 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", 795 | UserWarning, 796 | ) 797 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 798 | pad_token_id = pad_token_id if pad_token_id is not None else model_instance.generation_config.pad_token_id 799 | eos_token_id = eos_token_id if eos_token_id is not None else model_instance.generation_config.eos_token_id 800 | if isinstance(eos_token_id, int): 801 | eos_token_id = [eos_token_id] 802 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 803 | output_scores = output_scores if output_scores is not None else model_instance.generation_config.output_scores 804 | output_attentions = ( 805 | output_attentions if output_attentions is not None else model_instance.generation_config.output_attentions 806 | ) 807 | output_hidden_states = ( 808 | output_hidden_states if output_hidden_states is not None else model_instance.generation_config.output_hidden_states 809 | ) 810 | return_dict_in_generate = ( 811 | return_dict_in_generate 812 | if return_dict_in_generate is not None 813 | else model_instance.generation_config.return_dict_in_generate 814 | ) 815 | 816 | scores = () if (return_dict_in_generate and output_scores) else None 817 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 818 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 819 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 820 | 821 | if return_dict_in_generate and model_instance.config.is_encoder_decoder: 822 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 823 | encoder_hidden_states = ( 824 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 825 | ) 826 | 827 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 828 | 829 | this_peer_finished = False # used by synced_gpus only 830 | 831 | 832 | while True: 833 | if synced_gpus: 834 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 835 | 836 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 837 | 838 | if this_peer_finished_flag.item() == 0.0: 839 | break 840 | 841 | model_inputs = model_instance.prepare_inputs_for_generation(input_ids, **model_kwargs) 842 | 843 | outputs = model_instance( 844 | **model_inputs, 845 | return_dict=True, 846 | output_attentions=output_attentions, 847 | output_hidden_states=output_hidden_states, 848 | ) 849 | 850 | if synced_gpus and this_peer_finished: 851 | continue 852 | 853 | next_token_logits = outputs.logits[:, -1, :] 854 | 855 | next_tokens_scores = logits_processor(input_ids, next_token_logits) 856 | 857 | if return_dict_in_generate: 858 | if output_scores: 859 | scores += (next_tokens_scores,) 860 | if output_attentions: 861 | decoder_attentions += ( 862 | (outputs.decoder_attentions,) if model_instance.config.is_encoder_decoder else ( 863 | outputs.attentions,) 864 | ) 865 | if model_instance.config.is_encoder_decoder: 866 | cross_attentions += (outputs.cross_attentions,) 867 | 868 | if output_hidden_states: 869 | decoder_hidden_states += ( 870 | 
(outputs.decoder_hidden_states,) 871 | if model_instance.config.is_encoder_decoder 872 | else (outputs.hidden_states,) 873 | ) 874 | 875 | top_sign = self.top - 1 if self.first_check == 0 else 0 # first check, or to generate the rest 876 | next_tokens = torch.argsort(next_tokens_scores, dim=-1, descending=True)[:, top_sign] 877 | 878 | this_peer_finished, next_tokens = self.first_checking(next_tokens, next_tokens_scores) 879 | 880 | if next_tokens in self.zone_zero: 881 | this_peer_finished = True 882 | 883 | if eos_token_id is not None: 884 | if pad_token_id is None: 885 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 886 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 887 | 888 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 889 | if streamer is not None: 890 | streamer.put(next_tokens.cpu()) 891 | model_kwargs = model_instance._update_model_kwargs_for_generation( 892 | outputs, model_kwargs, is_encoder_decoder=model_instance.config.is_encoder_decoder 893 | ) 894 | 895 | if eos_token_id_tensor is not None: 896 | unfinished_sequences = unfinished_sequences.mul( 897 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 898 | ) 899 | 900 | if unfinished_sequences.max() == 0: 901 | this_peer_finished = True 902 | 903 | if stopping_criteria(input_ids, scores): 904 | this_peer_finished = True 905 | 906 | if this_peer_finished and not synced_gpus: 907 | break 908 | 909 | if streamer is not None: 910 | streamer.end() 911 | 912 | if return_dict_in_generate: 913 | if model_instance.config.is_encoder_decoder: 914 | return GreedySearchEncoderDecoderOutput( 915 | sequences=input_ids, 916 | scores=scores, 917 | encoder_attentions=encoder_attentions, 918 | encoder_hidden_states=encoder_hidden_states, 919 | decoder_attentions=decoder_attentions, 920 | cross_attentions=cross_attentions, 921 | decoder_hidden_states=decoder_hidden_states, 922 | ) 923 | else: 924 | return GreedySearchDecoderOnlyOutput( 925 | sequences=input_ids, 926 | scores=scores, 927 | attentions=decoder_attentions, 928 | hidden_states=decoder_hidden_states, 929 | ) 930 | else: 931 | return input_ids 932 | 933 | def require_first_to_be(self, next_tokens_scores, values_to_extract=[29871]): 934 | top_k_indices = torch.topk(next_tokens_scores, k=next_tokens_scores.shape[-1], dim=-1).indices 935 | top_k_indices_np = top_k_indices.cpu().numpy() 936 | mask = np.isin(top_k_indices_np, values_to_extract) 937 | top_k_indices = top_k_indices_np[mask][0] 938 | top_k_indices = torch.tensor(top_k_indices) 939 | 940 | next_tokens = top_k_indices.item() 941 | next_tokens = torch.tensor(next_tokens).reshape(-1) 942 | current_device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' 943 | return next_tokens.to(current_device), top_k_indices.to(current_device) 944 | 945 | 946 | def remove_str(self , entity): 947 | entity = entity.replace('_', ' ').replace('-', ' ') 948 | entity = entity.replace('(', ' ').replace(')', ' ') 949 | entity = entity.replace('\\', ' ').replace('.', ' ') 950 | entity = entity.replace('\"', ' ').replace('/', ' ') 951 | entity = entity.replace('\'', ' ').replace('&', ' ') 952 | entity = entity.replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') 953 | entity = entity.lower() 954 | return entity 955 | 956 | 957 | 958 | 959 | def tell_entity_name(self, answer_regs, query, m_inloop, filter_m_count): 960 | entity_ans = [] 
961 | answer_regs = self.remove_str(answer_regs) 962 | for entity in self.entity_set: 963 | entity = self.remove_str(entity) 964 | if entity in answer_regs: 965 | entity_ans.append(entity) 966 | if len(entity_ans) == 1 and self.remove_str(entity_ans[0]) not in self.remove_str(query): 967 | return entity_ans[0], m_inloop, filter_m_count 968 | elif len(entity_ans) == 0: 969 | return "final", m_inloop, filter_m_count 970 | elif len(entity_ans) > 1: 971 | id_answ = 1000000000000000000 972 | for id in range(len(entity_ans)-1, -1, -1): 973 | current_position = answer_regs.find(entity_ans[id]) 974 | if current_position < id_answ and self.remove_str(entity_ans[id]) not in self.remove_str(query): 975 | id_answ = current_position 976 | id_right = id 977 | try: 978 | return entity_ans[id_right], m_inloop, filter_m_count 979 | except: 980 | return entity_ans[0], m_inloop, filter_m_count 981 | else: 982 | return "final", m_inloop , filter_m_count 983 | 984 | 985 | 986 | 987 | def model_calling(self, m_inloop, query, filter_m_count): 988 | ids = self.tokenizer.encode(self.args.PROMPT) 989 | input_ids = torch.LongTensor([ids]).to('cuda') 990 | self.first_check = 0 991 | out = self.my_generate_top10(model_instance=self.model, m=m_inloop, 992 | input_ids=input_ids, 993 | max_length=self.args.CONTEXT_LEN, 994 | gen_length=36, 995 | do_sample=False, 996 | ) 997 | out_text = self.tokenizer.decode(out[0]) 998 | answer = out_text.replace(self.args.PROMPT, "").replace("\nEND", "").strip() 999 | 1000 | answer = answer.replace("\n", "") 1001 | answer_regs, m_inloop, filter_m_count = self.tell_entity_name(answer, query, m_inloop, filter_m_count) 1002 | answer___ = [] 1003 | answer___.append(answer_regs) 1004 | return answer, answer___, m_inloop, filter_m_count 1005 | 1006 | def eval(self, c, cnt=0, path_results=None, filter_yes=True): 1007 | def clculate_f1(tok_pred, tok_gold): 1008 | if len(tok_gold) == 0: # do not generate anything 1009 | if len(tok_pred) == 0: 1010 | f1 = 1 1011 | m_F1.append(f1) 1012 | else: 1013 | f1 = 0 1014 | m_F1.append(f1) 1015 | else: 1016 | tok_gold_dict = Counter(tok_gold) 1017 | tok_pred_dict = Counter(tok_pred) 1018 | tokens = set([*tok_gold_dict] + [*tok_pred_dict]) 1019 | hit = 0 1020 | for token in tokens: 1021 | hit += min(tok_gold_dict.get(token, 0), tok_pred_dict.get(token, 0)) 1022 | p = hit / (sum(tok_pred_dict.values()) + 1e-10) 1023 | r = hit / (sum(tok_gold_dict.values()) + 1e-10) 1024 | F1 = 2 * p * r / (p + r + 1e-10) 1025 | m_F1.append(F1) 1026 | 1027 | query = "" 1028 | c1 = c["c1"] 1029 | 1030 | if path_results is not None: 1031 | test_results = read_results(path_results) 1032 | dict_qu_ans = self.gen_set_ans(dir_full_test=self.args.fulltest, dir_time2id=self.args.time2id) 1033 | set_checked_qu = set() 1034 | num_infer = len(self.tests) # 1035 | for i in tqdm(range(cnt, num_infer)): 1036 | his_query = self.tests[i] 1037 | try: 1038 | query = his_query.split('Question')[3] 1039 | except: 1040 | query = his_query.split('Question')[2] 1041 | val_trunc = -1 1042 | if len(his_query) - 1 > val_trunc and val_trunc != -1: 1043 | li_his_trunc = his_query.split('\n')[-val_trunc - 1:-1] # backward 1044 | li_his_trunc.append(query) 1045 | his_query = "\n".join(li_his_trunc) 1046 | 1047 | 1048 | 1049 | delete = False 1050 | if delete == True: 1051 | his_query = re.sub(r'\d+:\s', '', his_query) 1052 | 1053 | 1054 | ins = '''You must be able to correctly predict the next {object} from a given text consisting of multiple sentnences in the form of "At time {time} {subject} {relation} 
{object}." and the query in the form of "At time {time} what does {subject} {relation} ?" in the end. You must directly generate the missing {object}.\n''' 1055 | self.args.PROMPT = ins + his_query 1056 | 1057 | 1058 | if query not in set_checked_qu: 1059 | set_checked_qu.add(query) 1060 | hello = "For" 1061 | 1062 | else: 1063 | hello = "Duplicate query:" 1064 | print(hello, query) 1065 | if query == '': 1066 | continue 1067 | print("Given answers", dict_qu_ans[query], "with", self.test_ans[i], "as the gt") 1068 | 1069 | content_to_write = [] 1070 | content_to_write2 = [] 1071 | m_inloop = -1 1072 | filter_m_count = -1 1073 | k_inloop = 5 1074 | self.constraints = [] 1075 | self.top = 1 1076 | exist_num = 0 1077 | if path_results is not None: 1078 | num_Test, li_results = read_num_and_li_results(test_results[i]) 1079 | exist_num = len(li_results) 1080 | if int(num_Test) != i: 1081 | print(num_Test, i) 1082 | raise ValueError("Test id and i do not match.") 1083 | while m_inloop < k_inloop - 1 and m_inloop <= 5: 1084 | m_inloop += 1 1085 | filter_m_count += 1 1086 | with torch.no_grad(): 1087 | if path_results is None: 1088 | raw_ans, answer_regs, m_inloop, filter_m_count = self.model_calling(m_inloop, query, filter_m_count) 1089 | print(str(m_inloop) + "-th time, I would say, ", answer_regs) 1090 | else: 1091 | 1092 | if m_inloop >= exist_num: 1093 | if not filter_yes: 1094 | break 1095 | else: 1096 | print("call of duty") 1097 | raw_ans, answer_regs = self.generate_extra_answers(m_inloop, k_inloop) 1098 | print(str(m_inloop) + "-th time, I would say, ", answer_regs) 1099 | else: 1100 | raw_ans = answer_regs = [li_results[m_inloop]] 1101 | pattern = re.compile(r'.*?[\d:@][._](.*)\]') 1102 | answer_regs = [re.match(pattern, answer_regs[0]).group(2).strip()] \ 1103 | if re.match(pattern, answer_regs[0]) else answer_regs 1104 | print(str(m_inloop) + " read ", answer_regs) 1105 | self.top += 1 1106 | 1107 | content_to_write.append('\n' + str(answer_regs)) 1108 | content_to_write2.append('\n' + str(raw_ans)) 1109 | 1110 | bingo = False 1111 | dict_qu_ans_lower = [self.remove_str(ans).lower() for ans in dict_qu_ans[query]] 1112 | for answer in answer_regs: 1113 | answerlow = answer.lower() 1114 | 1115 | gtlow = self.test_ans[i].lower() 1116 | clculate_f1(answerlow, gtlow) 1117 | if answer == '': 1118 | content_to_write.append("(none string; removed)") 1119 | k_inloop += 1 1120 | filter_m_count -= 1 1121 | print("increased k: " + str(k_inloop)) 1122 | break 1123 | if ( 1124 | self.remove_str(answerlow) != self.remove_str(gtlow) and answerlow in dict_qu_ans_lower) and filter_yes: # first_check = -1 if to check breach of obligation 1125 | print("Got another answer: " + answer) 1126 | bingo = True 1127 | if filter_m_count == 0: 1128 | c1 += 1 1129 | print("Bingo! Line: ", i, "count after filtering: ", filter_m_count + 1, "all count: ", \ 1130 | m_inloop + 1, "answer: ", answer, "gt: ", self.test_ans[i]) 1131 | break 1132 | elif self.remove_str(answerlow) == self.remove_str(gtlow): 1133 | bingo = True 1134 | if filter_m_count == 0: 1135 | c1 += 1 1136 | print("Bingo! 
Line: ", i, "count after filtering: ", filter_m_count + 1, "all count: ", \ 1137 | m_inloop + 1, "answer: ", answer, "gt: ", self.test_ans[i]) 1138 | break 1139 | 1140 | if bingo: 1141 | break 1142 | hits_1 = c1 / (i + 1) 1143 | 1144 | with open(self.eval_txt_path, "a", encoding="utf-8") as fout: 1145 | if self.args.ft == 1: 1146 | fout.write('current model: ' + self.args.LORA_CHECKPOINT_DIR + ', \n') 1147 | else: 1148 | fout.write('current model: ' + self.args.MODEL_NAME + ', \n') 1149 | fout.write(self.args.output_file + ' currently finished: ' + str(i + 1) + '; results: \n') 1150 | fout.write("Hits@1: " + str(round(hits_1, 3)) + "\n") 1151 | fout.write(str(c1) + "\n") 1152 | fout.write("F1" + "\n") 1153 | F111 = sum(m_F1)/len(m_F1) 1154 | fout.write( str(F111) + "\n") 1155 | 1156 | with open(self.args.output_file, 'a', encoding='utf-8') as f: 1157 | f.write('{"Test' + str(i) + '": ["' + ', '.join(content_to_write) + '"]}, \n\n') 1158 | with open(self.args.output_file.replace(".txt", "_raw.txt"), 'a', encoding='utf-8') as f: 1159 | f.write('{"Test' + str(i) + '": ["' + ', '.join(content_to_write2) + '"]}, \n\n') 1160 | 1161 | print('processing: ' + self.args.output_file, i + 1) 1162 | time.sleep(0.001) -------------------------------------------------------------------------------- /RuleRAG_inference.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaTokenizer 2 | import torch 3 | import re 4 | 5 | import os 6 | from RuleRAG_evaler import Evaler 7 | from RuleRAG_eval_utils import parse_args, read_test_and_divide, read_test_an, read_last_metric, decide_model 8 | import os 9 | torch.backends.cudnn.enabled = False 10 | if __name__ == "__main__": 11 | 12 | args = parse_args() 13 | 14 | eval_txt_path = args.output_file[:-4]+'_metric_results.txt' 15 | 16 | test_ans = read_test_an(args.test_ans_file) 17 | model = decide_model(args) 18 | tokenizer = LlamaTokenizer.from_pretrained("/home/", trust_remote_code=True) 19 | 20 | tests = read_test_and_divide(args.input_file) 21 | 22 | c = read_last_metric(args.last_metric) 23 | 24 | pattern1 = re.compile(r'.*?[\d:@][._](.*?)[\]\[]?([< ].*?)?$') 25 | pattern2 = re.compile(r' .*?[\n]?([A-Z\u00C0-\u00DD\u0388-\u03AB\u0410-\u042F\u0600-\u06FF\u4e00-\u9fa5].*)\]') 26 | pattern3 = re.compile(r' *(.*)\]') 27 | is_with_id = True 28 | if is_with_id: 29 | patterns = [pattern1] 30 | else: 31 | patterns = [pattern1, pattern2, pattern3] 32 | topk= 10 33 | cnt = args.begin 34 | early_stop_chars = [] 35 | obligations = [] 36 | evaler = Evaler(topk, tests, test_ans, eval_txt_path, args, model, tokenizer, patterns, early_stop_chars, obligations) 37 | 38 | path_results = args.path_results 39 | path_results = os.path.normpath(path_results) 40 | 41 | if path_results != '.': 42 | evaler.eval(c, cnt, path_results) 43 | else: 44 | evaler.eval(c, cnt, filter_yes=args.FILTER) -------------------------------------------------------------------------------- /basic.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | import os 4 | import random 5 | from pathlib import Path 6 | import sys 7 | 8 | def flip_dict(original_dict): 9 | return {v: k for k, v in original_dict.items()} 10 | 11 | def str_dict(original_dict): 12 | return {str(k): str(v) for k, v in original_dict.items()} 13 | 14 | def blockPrinting(func): 15 | def func_wrapper(*args, **kwargs): 16 | # block all printing to the console 17 | sys.stdout = open(os.devnull, 'w') 18 | # call the 
method in question 19 | value = func(*args, **kwargs) 20 | # enable all printing to the console 21 | sys.stdout = sys.__stdout__ 22 | # pass the return value of the method back 23 | return value 24 | return func_wrapper 25 | 26 | def get_ins(): #ins for datasets; in json every " should be \" 27 | ins = '''You must be able to correctly predict the next {object} from a given text consisting of multiple sentnences in the form of "At time {time} {subject} {relation} {object}." and the query in the form of "At time {time} what does {subject} {relation} ?" in the end. You must directly generate the missing {object}.\n''' 28 | 29 | return ins 30 | 31 | def get_file_extension(file_path): 32 | _, extension = os.path.splitext(file_path) 33 | return extension 34 | 35 | def read_csv(csv_dir, col=None): 36 | with open(csv_dir, 'r', newline='', encoding='utf-8') as q: 37 | csv_data = csv.reader(q) 38 | if col is None: 39 | return [row for row in csv_data] 40 | else: 41 | return [row[col] for row in csv_data] 42 | 43 | def read_csv_as_dict(path_csv_file): 44 | csv_data = [] 45 | with open(path_csv_file, 'r', encoding='utf-8') as csv_file: 46 | csv_reader = csv.DictReader(csv_file) 47 | for row in csv_reader: 48 | csv_data.append({ 49 | "space": row["space"], 50 | "underscore": row["underscore"] 51 | }) 52 | return csv_data 53 | 54 | def read_json(json_dir): 55 | with open(json_dir, "r", encoding="utf-8") as f: 56 | json_data = json.load(f) 57 | return json_data 58 | 59 | def read_json_as_list(json_dir): 60 | return list(read_json(json_dir).keys()) 61 | 62 | def just_read_txt(path_txt): 63 | with open(path_txt) as file: 64 | content = file.read() 65 | return content 66 | 67 | def read_txt_as_list(path_txt): 68 | with open(path_txt, 'r', encoding='utf-8-sig') as file: 69 | data = file.readlines() 70 | return data 71 | 72 | def read_txt_as_index_dict(path_txt, divider='\t'): 73 | li_corres = [] 74 | li = read_txt_as_list(path_txt) 75 | for line in li: 76 | line_splited = line.strip().split(divider) 77 | li_corres.append({ 78 | "space": line_splited[0], 79 | "underscore": line_splited[1] 80 | }) 81 | return li_corres 82 | 83 | def write_txt(txt_dir, out_list, head='\t'): 84 | with open(txt_dir, 'w', encoding='utf-8') as txtfile: 85 | for sublist in out_list: 86 | txtfile.write(head.join(map(str, sublist)) + '\n') 87 | 88 | def write_dict(txt_dir, out_dict): 89 | with open(txt_dir, 'w', encoding='utf-8') as txtfile: 90 | for key, value in out_dict.items(): 91 | txtfile.write(f"{key}\t{value}\n") 92 | 93 | def write_csv(data, out_dict): 94 | with open(out_dict, 'w', encoding='utf-8') as out_dict: 95 | writer = csv.writer(out_dict) 96 | writer.writerow(['Column1','Column2','Column3','Column4']) 97 | for entry in data: 98 | out_dict.write(','.join(map(str, entry))+ '\n') 99 | 100 | def just_write_json(data, out, indent=4): 101 | with open(out, 'w', encoding='utf-8') as json_file: 102 | json.dump(data, json_file, indent=indent) 103 | 104 | def sample_dataset(dir_dataset, num_sample, is_json): 105 | path = Path(dir_dataset) 106 | with open(path, 'r', encoding='utf-8') as input_file: 107 | input_file = list(json.load(input_file)) if is_json else list(input_file) 108 | output_data = random.sample(input_file, num_sample) 109 | sampled_dir_name = str(path.parent / path.stem) + "_" + str(num_sample) + path.suffix 110 | with open(sampled_dir_name, 'w', encoding='utf-8') as output_file: 111 | if is_json: 112 | json.dump(output_data, output_file, indent=4) 113 | else: 114 | output_file.writelines([f"{item}" for item in 
output_data]) 115 | print("sampled", sampled_dir_name) 116 | return sampled_dir_name 117 | 118 | def create_folder_for_file(file_path): 119 | directory = os.path.dirname(file_path) 120 | os.makedirs(directory, exist_ok=True) -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description="Config of Llama2-lora") 5 | parser.add_argument("--MICRO_BATCH_SIZE", type=int, default=1, help="Per device train batch size") 6 | parser.add_argument("--BATCH_SIZE", type=int, default=2, help="batch size") 7 | parser.add_argument('--EPOCHS', type=int, default=5, help='Training epochs') 8 | parser.add_argument('--WARMUP_STEPS', type=int, default=100, help='Warmup steps') 9 | parser.add_argument('--LEARNING_RATE', type=float, default= 3e-4 , help='Training learning rate') 10 | parser.add_argument('--CONTEXT_LEN', type=int, default=1000, help='Truncation length of context (in json)') 11 | parser.add_argument('--TARGET_LEN', type=int, default=36, help='Truncation length of target (in json)') 12 | parser.add_argument('--TEXT_LEN', type=int, default=256, help='Truncation length of text (in txt)') 13 | parser.add_argument('--LORA_R', type=int, default=4, help='Lora low rank') 14 | parser.add_argument('--LORA_ALPHA', type=int, default=16, help='Lora Alpha') 15 | parser.add_argument('--LORA_DROPOUT', type=float, default=0.05, help='Lora dropout') 16 | parser.add_argument('--MODEL_NAME', type=str, default="", help='Model name') 17 | parser.add_argument('--LOGGING_STEPS', type=int, default=10, help='Logging steps in training') 18 | parser.add_argument('--LOAD_BEST_MODEL_AT_END', type=int, default=0, help='set 1 to save the best checkpoint') 19 | parser.add_argument('--OUTPUT_DIR', type=str, default="", help='Output dir') 20 | parser.add_argument('--DATA_PATH', type=str, default="", help='Input dir of trainset') 21 | parser.add_argument('--DATA_TYPE', type=str, choices= ["json" , "txt"], default="json", help='Input trainsetfile type') 22 | 23 | parser.add_argument('--EVAL_STRATEGY', type=str, default="no", help='eval by the history fact file') 24 | parser.add_argument('--EVAL_BY_HF', type=int, default=1, help='set 1 to eval by the history fact file') 25 | parser.add_argument('--EVAL_PATH', type=str, default=None, help='Input dir of evalset') 26 | parser.add_argument('--EVAL_TYPE', type=str, choices= ["json" , "txt"], default="txt", help='Input evalset file type') 27 | parser.add_argument('--EVAL_STEPS', type=int, default=10, help='Eval the model according to steps') 28 | 29 | parser.add_argument('--SAVE_STEPS', type=int, default=1000, help='Save the model according to steps') 30 | parser.add_argument('--SAVE_TOTAL_LIMIT', type=int, default=None, help='The number of the checkpoint you will save (Excluding the final one)') 31 | parser.add_argument('--BIT_8', default=False, action="store_true", help='Use 8-bit') 32 | parser.add_argument('--BIT_4', default=False, action="store_true", help='Use 4-bit') 33 | parser.add_argument('--REPORT_TO', type=str, default=None, help='logging to e.g. wandb') 34 | parser.add_argument('--PROJ_NAME', type=str, default=None, help='Project name for e.g. wandb') 35 | parser.add_argument('--RUN_NAME', type=str, default=None, help='Run name for e.g. 
wandb') 36 | 37 | parser.add_argument('--W_RESUME', type=int, default=0, help='set 1 to enable WANDB_RESUME') 38 | parser.add_argument('--W_ID', type=str, default=0, help='set 1 to enable WANDB_RESUME') 39 | 40 | return parser.parse_args() -------------------------------------------------------------------------------- /data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhongwu20/RuleRAG_ICL_FT/79a4a7742af9e172516044a79a75a417c1089f91/data.png -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhongwu20/RuleRAG_ICL_FT/79a4a7742af9e172516044a79a75a417c1089f91/framework.png -------------------------------------------------------------------------------- /lab/constants.py: -------------------------------------------------------------------------------- 1 | class Args20b: 2 | vocab_size = 50432 3 | hidden_size = 6144 4 | num_attention_heads = 64 5 | rotary_pct = 0.25 6 | rotary_emb_base = 10000 7 | layernorm_epsilon = 1e-5 8 | num_layers = 44 9 | 10 | 11 | class ArgsDummy: 12 | vocab_size = 50432 13 | hidden_size = 64 14 | num_attention_heads = 4 15 | rotary_pct = 0.25 16 | rotary_emb_base = 10000 17 | layernorm_epsilon = 1e-5 18 | num_layers = 2 19 | -------------------------------------------------------------------------------- /lab/create.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import auto as tqdm_lib 3 | 4 | import torch 5 | import tokenizers 6 | 7 | import lab.model as model20b 8 | from lab.constants import Args20b, ArgsDummy 9 | 10 | 11 | def create_model(checkpoint_path, use_cache=False, device=torch.device("cuda:0")): 12 | 13 | pbar = tqdm_lib.tqdm(total=48) 14 | pbar.set_description("Instantiating model (~1 min)") 15 | model = model20b.NeoX20BModel(Args20b, use_cache=use_cache, device="meta") 16 | model = model.half().to_empty(device=device) # 17 | 18 | 19 | pbar.update(1) 20 | 21 | for layer_i in range(Args20b.num_layers): 22 | pbar.set_description(f"Loading layer {layer_i}") 23 | filename_tp1 = f"layer_{layer_i + 2:02d}-model_00-model_states.pt" 24 | filename_tp2 = f"layer_{layer_i + 2:02d}-model_01-model_states.pt" 25 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, filename_tp1)) 26 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, filename_tp2)) 27 | state_dict = {} 28 | for key in [ 29 | "attention.dense.weight", 30 | "mlp.dense_4h_to_h.weight", 31 | ]: 32 | state_dict[key] = torch.cat([loaded_tp1[key], loaded_tp2[key]], dim=1) 33 | state_dict["input_layernorm.weight"] = ( 34 | loaded_tp1["input_layernorm.weight"] + loaded_tp2["input_layernorm.weight"]) / 2 35 | state_dict["input_layernorm.bias"] = ( 36 | loaded_tp1["input_layernorm.bias"] + loaded_tp2["input_layernorm.bias"]) / 2 37 | state_dict["post_attention_layernorm.weight"] = ( 38 | loaded_tp1["post_attention_layernorm.weight"] + loaded_tp2["post_attention_layernorm.weight"]) / 2 39 | state_dict["post_attention_layernorm.bias"] = ( 40 | loaded_tp1["post_attention_layernorm.bias"] + loaded_tp2["post_attention_layernorm.bias"]) / 2 41 | # LinearWithTPMerge 42 | state_dict["mlp.dense_h_to_4h.weight"] = torch.cat([ 43 | loaded_tp1["mlp.dense_h_to_4h.weight"], 44 | loaded_tp2["mlp.dense_h_to_4h.weight"], 45 | ], dim=0) 46 | state_dict["mlp.dense_h_to_4h.bias"] = torch.cat([ 47 | 
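        # The two checkpoint files per layer are tensor-parallel shards of the same layer. The merge
        # rules implemented in this loop: column-parallel weights and biases (mlp.dense_h_to_4h,
        # attention.query_key_value) are concatenated along dim 0; row-parallel weights
        # (attention.dense, mlp.dense_4h_to_h) are concatenated along dim 1 and their biases summed;
        # duplicated tensors (the two layernorms) are averaged; attention.rotary_emb.inv_freq is
        # identical in both shards, so a single copy is kept.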
loaded_tp1["mlp.dense_h_to_4h.bias"], 48 | loaded_tp2["mlp.dense_h_to_4h.bias"], 49 | ], dim=0) 50 | state_dict["attention.query_key_value.weight"] = torch.cat([ 51 | loaded_tp1["attention.query_key_value.weight"], 52 | loaded_tp2["attention.query_key_value.weight"], 53 | ], dim=0) 54 | state_dict["attention.query_key_value.bias"] = torch.cat([ 55 | loaded_tp1["attention.query_key_value.bias"], 56 | loaded_tp2["attention.query_key_value.bias"], 57 | ], dim=0) 58 | # LinearWithTPSplitBias 59 | state_dict["mlp.dense_4h_to_h.bias"] = ( 60 | loaded_tp1["mlp.dense_4h_to_h.bias"] 61 | + loaded_tp2["mlp.dense_4h_to_h.bias"] 62 | ) 63 | state_dict["attention.dense.bias"] = ( 64 | loaded_tp1["attention.dense.bias"] 65 | + loaded_tp2["attention.dense.bias"] 66 | ) 67 | # Just take one 68 | state_dict["attention.rotary_emb.inv_freq"] = loaded_tp1["attention.rotary_emb.inv_freq"] 69 | model.layer_list[layer_i].load_state_dict(state_dict) 70 | # model.module.layer_list[layer_i].load_state_dict(state_dict) 71 | del loaded_tp1 72 | del loaded_tp2 73 | pbar.update(1) 74 | 75 | # Load input embedding 76 | pbar.set_description(f"Loading input embedding") 77 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt")) 78 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt")) 79 | model.embed_in.load_state_dict({"weight": torch.cat([ 80 | loaded_tp1["word_embeddings.weight"], 81 | loaded_tp2["word_embeddings.weight"], 82 | ], dim=0)}) # default no.module 83 | del loaded_tp1 84 | del loaded_tp2 85 | pbar.update(1) 86 | 87 | # Load final layer norm 88 | pbar.set_description(f"Loading final layer norm") 89 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_47-model_00-model_states.pt")) 90 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_47-model_01-model_states.pt")) 91 | model.final_layer_norm.load_state_dict({ 92 | "weight": (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"])/2, 93 | "bias": (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"])/2, 94 | }) 95 | del loaded_tp1 96 | del loaded_tp2 97 | pbar.update(1) 98 | 99 | # Load output embedding 100 | pbar.set_description(f"Loading output embedding") 101 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_48-model_00-model_states.pt")) 102 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_48-model_01-model_states.pt")) 103 | model.logits_out.load_state_dict({ 104 | "weight": torch.cat([ 105 | loaded_tp1["final_linear.weight"], 106 | loaded_tp2["final_linear.weight"], 107 | ], dim=0), 108 | }) 109 | del loaded_tp1 110 | del loaded_tp2 111 | pbar.update(1) 112 | pbar.set_description("Done.") 113 | 114 | return model 115 | 116 | 117 | def create_dummy_model(use_cache=False, device=torch.device("cpu")): 118 | model = model20b.NeoX20BModel(ArgsDummy, use_cache=use_cache).half().to(device) 119 | return model 120 | 121 | 122 | def create_tokenizer(tokenizer_path): 123 | return tokenizers.Tokenizer.from_file(tokenizer_path) 124 | -------------------------------------------------------------------------------- /lab/generate.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from tqdm import auto as tqdm_lib 5 | 6 | 7 | 8 | class BreakOuterLoop(Exception): 9 | pass 10 | 11 | def greedy_generate(m, model: nn.Module, input_ids: torch.Tensor, max_seq_len: int, 12 | verbose=True): 13 | 14 | initial_input_length = input_ids.shape[1] 15 | current_input_ids = input_ids 
16 | max_seq_len = initial_input_length + max_seq_len # It is enough to output only 30 more tokens 17 | layer_past = None 18 | layer_past_length = 0 19 | all_token_ids = input_ids.tolist() 20 | batch_size = len(all_token_ids) 21 | 22 | trange = range(initial_input_length, max_seq_len) 23 | 24 | input_length = current_input_ids.shape[1] 25 | model_out, layer_past = model( 26 | current_input_ids, 27 | layer_past=layer_past, 28 | ) 29 | 30 | top_10_indices = torch.topk(model_out[:, -1], k=10, dim=-1).indices 31 | greedy_predicted_token_ids = top_10_indices[:, m] # 32 | current_input_ids = greedy_predicted_token_ids[:, None] 33 | l = [] 34 | l.append(greedy_predicted_token_ids.item()) 35 | 36 | try: 37 | should_break = False #Initialize the flag variable to False 38 | for _ in trange: # Specify the iteration range appropriately 39 | input_length = current_input_ids.shape[1] 40 | model_out, layer_past = model( 41 | current_input_ids, 42 | layer_past=layer_past, 43 | ) 44 | 45 | greedy_predicted_token_ids = model_out[:, -1].argmax(-1) 46 | 47 | current_input_ids = greedy_predicted_token_ids[:, None] 48 | layer_past_length += input_length 49 | 50 | for i in range(batch_size): 51 | if greedy_predicted_token_ids[i].item() == 187: 52 | should_break = True 53 | raise BreakOuterLoop 54 | 55 | l.append(greedy_predicted_token_ids[i]) 56 | 57 | if should_break: 58 | break 59 | 60 | except BreakOuterLoop: 61 | pass 62 | 63 | return l 64 | 65 | 66 | def greedy_generate_text(m, model: nn.Module, 67 | tokenizer, 68 | initial_str: str, 69 | max_seq_len: int, 70 | device=torch.device("cuda:0"), 71 | verbose=True): 72 | 73 | tokenized = tokenizer.encode(initial_str) 74 | if len(tokenized.ids) > 2020: 75 | input_ids = torch.LongTensor([tokenized.ids[-2020:]]).to(device) 76 | else: 77 | input_ids = torch.LongTensor([tokenized.ids]).to(device) 78 | 79 | try: 80 | all_token_ids = greedy_generate(m, model=model, input_ids=input_ids, max_seq_len=max_seq_len, verbose=verbose) 81 | except BreakOuterLoop: 82 | pass 83 | 84 | decoded_str = tokenizer.decode(all_token_ids) 85 | if len(decoded_str)< 2: 86 | return '"#'+str(m)+'"' 87 | elif decoded_str[1].isdigit(): 88 | return tokenizer.decode(all_token_ids) 89 | else: 90 | return '"#'+str(m)+'"' 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /lab/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | import lab.rotary as rotary 6 | 7 | 8 | class NeoX20BModel(nn.Module): 9 | def __init__(self, args, use_cache=False, device=None): 10 | super().__init__() 11 | self.use_cache = use_cache 12 | self.embed_in = nn.Embedding(args.vocab_size, args.hidden_size, device=device) 13 | self.layer_list = nn.ModuleList([]) 14 | for layer_i in range(args.num_layers): 15 | self.layer_list.append(TransformerLayer(args, use_cache, device=device)) 16 | self.final_layer_norm = nn.LayerNorm( 17 | args.hidden_size, 18 | eps=args.layernorm_epsilon, 19 | device=device, 20 | ) 21 | self.logits_out = nn.Linear( 22 | args.hidden_size, 23 | args.vocab_size, 24 | bias=False, 25 | device=device, 26 | ) 27 | 28 | def forward(self, x, attention_mask=None, layer_past=None): 29 | if attention_mask is None: 30 | attention_mask = generate_mask(x.shape[1]).to(x.device) 31 | if self.use_cache: 32 | if layer_past is None: 33 | kv_length = x.shape[1] 34 | else: 35 | kv_length = layer_past[0].shape[1] + 1 36 | 
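                # With an active KV cache only the newly fed token(s) have query rows, so the causal
                # mask is cropped to x.shape[1] rows and at most kv_length key columns; for one-token
                # decoding steps this degenerates to a broadcastable all-True mask, i.e. the new token
                # may attend to every cached position.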
attention_mask = attention_mask[..., :x.shape[1], :kv_length] 37 | 38 | if layer_past is None: 39 | layer_past = [None] * len(self.layer_list) 40 | kv_cache_list = [] 41 | hidden_states = self.embed_in(x) 42 | hidden_states = self.pre_transformer_transpose(hidden_states) 43 | 44 | for layer_i, layer in enumerate(self.layer_list): 45 | hidden_states, kv_cache = layer( 46 | x=hidden_states, 47 | attention_mask=attention_mask, 48 | layer_past=layer_past[layer_i], 49 | ) 50 | kv_cache_list.append(kv_cache) 51 | hidden_states = self.post_transformer_transpose(hidden_states) 52 | hidden_states = self.final_layer_norm(hidden_states) 53 | logits = self.logits_out(hidden_states) 54 | if self.use_cache: 55 | return logits, kv_cache_list 56 | else: 57 | return logits 58 | 59 | @classmethod 60 | def pre_transformer_transpose(cls, x): 61 | return x.transpose(0, 1).contiguous() 62 | 63 | @classmethod 64 | def post_transformer_transpose(cls, x): 65 | return x.transpose(0, 1).contiguous() 66 | 67 | 68 | class TransformerLayer(nn.Module): 69 | def __init__(self, args, use_cache, device=None): 70 | super().__init__() 71 | self.use_cache = use_cache 72 | self.input_layernorm = nn.LayerNorm( 73 | args.hidden_size, 74 | eps=args.layernorm_epsilon, 75 | device=device, 76 | ) 77 | self.post_attention_layernorm = nn.LayerNorm( 78 | args.hidden_size, 79 | eps=args.layernorm_epsilon, 80 | device=device, 81 | ) 82 | self.attention = SelfAttention(args, self.use_cache, device=device) 83 | self.mlp = MLP(args) 84 | 85 | def forward(self, x, attention_mask, layer_past=None): 86 | residual = x 87 | ln_output = self.input_layernorm(x) 88 | attention_output, kv_cache = self.attention( 89 | ln_output, 90 | attention_mask, 91 | layer_past=layer_past, 92 | ) 93 | post_attn_ln = self.post_attention_layernorm(x) 94 | mlp_output = self.mlp(hidden_states=post_attn_ln) 95 | output = residual + mlp_output + attention_output 96 | return output, kv_cache 97 | 98 | 99 | class SelfAttention(nn.Module): 100 | def __init__(self, args, use_cache=False, device=None): 101 | super().__init__() 102 | self.hidden_size = args.hidden_size 103 | self.use_cache = use_cache 104 | self.num_attention_heads = args.num_attention_heads 105 | self.hidden_size_per_attention_head = args.hidden_size // args.num_attention_heads 106 | self.rotary_ndims = int(self.hidden_size_per_attention_head * args.rotary_pct) 107 | self.rotary_emb = rotary.RotaryEmbedding( 108 | self.rotary_ndims, 109 | base=args.rotary_emb_base, 110 | device=device, 111 | ) 112 | self.query_key_value = nn.Linear( 113 | args.hidden_size, 114 | 3 * args.hidden_size, 115 | device=device, 116 | ) 117 | self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) 118 | self.dense = nn.Linear( 119 | args.hidden_size, 120 | args.hidden_size, 121 | device=device, 122 | ) 123 | 124 | def forward(self, hidden_states, attention_mask, layer_past=None): 125 | has_layer_past = layer_past is not None and layer_past.numel() > 0 126 | 127 | # Compute QKV 128 | # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] 129 | qkv = self.query_key_value(hidden_states) 130 | 131 | # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] 132 | new_qkv_shape = qkv.size()[:-1] + ( 133 | self.num_attention_heads, 134 | 3 * self.hidden_size_per_attention_head, 135 | ) 136 | qkv = qkv.view(*new_qkv_shape) 137 | 138 | # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] 139 | query_layer = qkv[..., :self.hidden_size_per_attention_head] 140 | key_layer = qkv[..., self.hidden_size_per_attention_head: 2 * 
self.hidden_size_per_attention_head] 141 | value_layer = qkv[..., 2 * self.hidden_size_per_attention_head:] 142 | 143 | # Compute rotary embeddings 144 | query_rot, query_pass = ( 145 | query_layer[..., : self.rotary_ndims], 146 | query_layer[..., self.rotary_ndims:], 147 | ) 148 | key_rot, key_pass = ( 149 | key_layer[..., : self.rotary_ndims], 150 | key_layer[..., self.rotary_ndims:], 151 | ) 152 | seq_len = key_layer.shape[0] 153 | offset = 0 154 | if has_layer_past: 155 | offset = layer_past[0].shape[0] 156 | seq_len += offset 157 | cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) 158 | query_layer, key_layer = rotary.apply_rotary_pos_emb( 159 | query_rot, key_rot, cos, sin, offset=offset, 160 | ) 161 | query_layer = torch.cat((query_layer, query_pass), dim=-1) 162 | key_layer = torch.cat((key_layer, key_pass), dim=-1) 163 | 164 | # Cache QKV values 165 | if has_layer_past: 166 | past_key, past_value = layer_past 167 | key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) 168 | value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) 169 | if self.use_cache: 170 | kv_cache = torch.stack((key_layer, value_layer)) 171 | else: 172 | kv_cache = None 173 | 174 | # Compute attention 175 | # noinspection PyTypeChecker 176 | context_layer = self.attention( 177 | query_layer, key_layer, value_layer, attention_mask 178 | ) 179 | 180 | # Reshape outputs 181 | # [b, np, sq, hn] --> [sq, b, np, hn] 182 | context_layer = context_layer.permute(2, 0, 1, 3).contiguous() 183 | 184 | # [sq, b, np, hn] --> [sq, b, hp] 185 | new_context_layer_shape = context_layer.size()[:-2] + ( 186 | self.hidden_size, 187 | ) 188 | context_layer = context_layer.view(*new_context_layer_shape) 189 | 190 | # ================= 191 | # Output. [sq, b, h] 192 | # ================= 193 | output = self.dense(context_layer) 194 | 195 | return output, kv_cache 196 | 197 | def attention(self, query_layer, key_layer, value_layer, attention_mask): 198 | # =================================== 199 | # Raw attention scores. [b, np, s, s] 200 | # =================================== 201 | 202 | # [b, np, sq, sk] 203 | output_size = ( 204 | query_layer.size(1), 205 | query_layer.size(2), 206 | query_layer.size(0), 207 | key_layer.size(0), 208 | ) 209 | 210 | # [sq, b, np, hn] -> [sq, b * np, hn] 211 | query_layer = query_layer.view( 212 | output_size[2], output_size[0] * output_size[1], -1 213 | ) 214 | key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) 215 | 216 | # preallocating result tensor: [b * np, sq, sk] 217 | matmul_result = torch.empty( 218 | output_size[0] * output_size[1], 219 | output_size[2], 220 | output_size[3], 221 | dtype=query_layer.dtype, 222 | device=query_layer.device, 223 | ) 224 | 225 | # Raw attention scores. [b * np, sq, sk] 226 | matmul_result = torch.baddbmm( 227 | matmul_result, 228 | query_layer.transpose(0, 1), # [b * np, sq, hn] 229 | key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] 230 | beta=0.0, 231 | alpha=(1.0 / self.norm_factor), 232 | ) 233 | 234 | # change view to [b, np, sq, sk] 235 | attention_scores = matmul_result.view(*output_size) 236 | 237 | # ================================================== 238 | # Update attention mask for inference. 
[b, np, sq, sk] 239 | # ================================================== 240 | 241 | # =========================== 242 | # Attention probs and dropout 243 | # =========================== 244 | 245 | # attention scores and attention mask [b, np, sq, sk] 246 | masked_scores = attention_mask_func(attention_scores, attention_mask) \ 247 | if attention_mask is not None else attention_scores 248 | attention_probs = torch.nn.Softmax(dim=-1)(masked_scores) 249 | 250 | # # This is actually dropping out entire tokens to attend to, which might 251 | # # seem a bit unusual, but is taken from the original Transformer paper. 252 | # attention_probs = self.attention_dropout(attention_probs) 253 | 254 | # ========================= 255 | # Context layer. [sq, b, hp] 256 | # ========================= 257 | 258 | # value_layer -> context layer. 259 | # [sk, b, np, hn] --> [b, np, sq, hn] 260 | 261 | # context layer shape: [b, np, sq, hn] 262 | output_size = ( 263 | value_layer.size(1), 264 | value_layer.size(2), 265 | query_layer.size(0), 266 | value_layer.size(3), 267 | ) 268 | 269 | # change view [sk, b * np, hn] 270 | value_layer = value_layer.view( 271 | value_layer.size(0), output_size[0] * output_size[1], -1 272 | ) 273 | 274 | # change view [b * np, sq, sk] 275 | attention_probs = attention_probs.view( 276 | output_size[0] * output_size[1], output_size[2], -1 277 | ) 278 | 279 | # matmul: [b * np, sq, hn] 280 | context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) 281 | 282 | # change view [b, np, sq, hn] 283 | context_layer = context_layer.view(*output_size) 284 | return context_layer 285 | 286 | 287 | class MLP(nn.Module): 288 | def __init__(self, args, device=None): 289 | super().__init__() 290 | ff_dim = 4 * args.hidden_size 291 | self.dense_h_to_4h = nn.Linear(args.hidden_size, ff_dim, device=device) 292 | self.dense_4h_to_h = nn.Linear(ff_dim, args.hidden_size, device=device) 293 | 294 | def forward(self, hidden_states): 295 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 296 | intermediate_parallel = bias_gelu_impl(intermediate_parallel) 297 | output = self.dense_4h_to_h(intermediate_parallel) 298 | return output 299 | 300 | 301 | # noinspection PyAbstractClass 302 | class GeLUFunction(torch.autograd.Function): 303 | # noinspection PyMethodOverriding 304 | @staticmethod 305 | # bias is an optional argument 306 | def forward(ctx, inputs): 307 | ctx.save_for_backward(inputs) 308 | return gelu(inputs) 309 | 310 | # noinspection PyMethodOverriding 311 | @staticmethod 312 | def backward(ctx, grad_output): 313 | inputs = ctx.saved_tensors 314 | tmp = gelu_back(grad_output, inputs) 315 | return tmp, tmp 316 | 317 | 318 | bias_gelu_impl = GeLUFunction.apply 319 | 320 | 321 | def generate_mask(seq_len): 322 | return torch.tril(torch.ones((1, 1, seq_len, seq_len), dtype=torch.bool)) 323 | 324 | 325 | def attention_mask_func(attention_scores, ltor_mask): 326 | """Assign -10000.0 to False cells in ltor_mask""" 327 | attention_scores.masked_fill_(~ltor_mask, -10000.0) 328 | return attention_scores 329 | 330 | 331 | @torch.jit.script 332 | def gelu(x): 333 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 334 | 335 | 336 | # gradient of tanh approximation of gelu 337 | # gradient of actual gelu is: 338 | # 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) 339 | @torch.jit.script 340 | def gelu_back(g, x): 341 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 342 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 343 | ff = 0.5 * x * ( 344 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) 345 | ) + 0.5 * (1 + tanh_out) 346 | return ff * g 347 | -------------------------------------------------------------------------------- /lab/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class RotaryEmbedding(torch.nn.Module): 5 | 6 | def __init__(self, dim, base=10000, device=None): 7 | super().__init__() 8 | inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 9 | self.register_buffer('inv_freq', inv_freq) 10 | # Delay initialization until first forward call, because initial model on the 'meta' device 11 | self.cos_cached = None 12 | self.sin_cached = None 13 | 14 | def forward(self, x, seq_dim=1, seq_len=None): 15 | if seq_len is None: 16 | seq_len = x.shape[seq_dim] 17 | if self.cos_cached is None: 18 | t = torch.arange(2048, device=x.device, dtype=self.inv_freq.dtype) 19 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 20 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 21 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 22 | # [sx, 1 (b * np), hn] 23 | self.cos_cached = emb.cos()[:, None, None, :] 24 | self.sin_cached = emb.sin()[:, None, None, :] 25 | return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] 26 | 27 | 28 | def rotate_half(x): 29 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 30 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 31 | 32 | 33 | # @torch.jit.script 34 | def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): 35 | cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] 
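    # Rotary position embedding in the "rotate-half" layout: for each feature pair (x1, x2) the
    # result is (x1 * cos - x2 * sin, x2 * cos + x1 * sin), a rotation by the position-dependent
    # angle; `offset` equals the number of cached positions, so during incremental decoding the new
    # query/key receive their correct absolute position index.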
36 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 37 | -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | from config import parse_args 2 | from utils import get_lora_config, get_model_and_tokenizer, process_data, get_trainer, get_dataset 3 | import wandb 4 | import re 5 | from datetime import datetime 6 | import os 7 | import os 8 | import os 9 | os.environ["WANDB_DISABLED"]="true" 10 | 11 | def set_wandb_para(proj_name): 12 | os.environ["WANDB_API_KEY"] = '' 13 | os.environ["WANDB_PROJECT"] = proj_name 14 | os.environ["WANDB_LOG_MODEL"] = "checkpoint" 15 | os.environ["WANDB_RESUME"] = "allow" 16 | os.environ["WANDB__SERVICE_WAIT"] = "300" 17 | os.environ["WANDB_MODE"] = "offline" 18 | 19 | def prepair_data(args, tokenizer): 20 | if args.EVAL_PATH is not None and args.EVAL_STRATEGY != "no": 21 | train_dataset = get_dataset(args.DATA_TYPE, args.DATA_PATH) 22 | if args.EVAL_BY_HF: 23 | with open(args.EVAL_PATH, 'r', encoding='utf-8') as file: 24 | content = file.read() 25 | eval_dataset = content.split('\n\n') 26 | eval_dataset = get_dataset(args.EVAL_TYPE, eval_dataset) 27 | else: 28 | eval_dataset = get_dataset(args.EVAL_TYPE, args.EVAL_PATH) 29 | train_data = process_data(args, tokenizer, args.DATA_TYPE, train_dataset) 30 | eval_data = process_data(args, tokenizer, args.EVAL_TYPE, eval_dataset) 31 | data = {"train": train_data, "test": eval_data} 32 | else: 33 | dataset = get_dataset(args.DATA_TYPE, args.DATA_PATH) 34 | data = process_data(args, tokenizer, args.DATA_TYPE, dataset) 35 | return data 36 | 37 | def wandb_resume(id, outdir): 38 | run = wandb.init(entity="tkg_forecaster", project=os.environ["WANDB_PROJECT"], id=id, resume="must") 39 | trainer = get_trainer(args, model, data, tokenizer) 40 | run_name = run.name 41 | ckpt_name = f"checkpoint-{run_name}:latest" 42 | ckpt_artifact = run.use_artifact(ckpt_name) 43 | ckpt_dir = ckpt_artifact.download() #get ckpt from server wandb 44 | trainer.train(resume_from_checkpoint=ckpt_dir) 45 | print("Resumed wandb run: ", str(run.id)) 46 | trainer.save_model(outdir + "/model_final") 47 | 48 | 49 | def generate_run_name(outdir, time=None): 50 | match = re.search(r'([^\/]+)$', outdir) 51 | filename = match.group(1) if match else "" 52 | run_name = filename + "_" 53 | if time is None: 54 | run_name = run_name + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 55 | else: 56 | run_name = run_name + time 57 | return run_name 58 | 59 | if __name__ == "__main__": 60 | args = parse_args() 61 | 62 | print("load_best_model_at_end is set as", args.LOAD_BEST_MODEL_AT_END) 63 | 64 | lora_config = get_lora_config(args) 65 | model, tokenizer = get_model_and_tokenizer(args, lora_config) 66 | 67 | data = prepair_data(args, tokenizer) 68 | if args.W_RESUME == 1: 69 | wandb_resume(args.W_ID, args.OUTPUT_DIR) 70 | else: 71 | trainer = get_trainer(args, model, data, tokenizer) 72 | trainer.train(resume_from_checkpoint=False) 73 | trainer.save_model(args.OUTPUT_DIR + "/model_final") 74 | -------------------------------------------------------------------------------- /name.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhongwu20/RuleRAG_ICL_FT/79a4a7742af9e172516044a79a75a417c1089f91/name.png -------------------------------------------------------------------------------- /neox.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | from lab import create_model, create_tokenizer 5 | import glob 6 | 7 | class BreakOuterLoop(Exception): 8 | pass 9 | 10 | def init_neox(path_model, 11 | ): 12 | model_meta_directory = os.path.dirname(path_model) 13 | path_tokenizer = glob.glob(model_meta_directory+'/*tokenizer.json')[0] 14 | model = create_model(path_model, use_cache=True,device='cuda:0',) 15 | tokenizer = create_tokenizer(path_tokenizer) 16 | return model, tokenizer 17 | 18 | 19 | def greedy_generate(m, k, model: nn.Module, input_ids: torch.Tensor, max_seq_len: int, 20 | verbose=True): 21 | 22 | initial_input_length = input_ids.shape[1] 23 | current_input_ids = input_ids 24 | max_seq_len = initial_input_length + max_seq_len # It is enough to output only 30 more tokens 25 | layer_past = None 26 | layer_past_length = 0 27 | all_token_ids = input_ids.tolist() 28 | batch_size = len(all_token_ids) 29 | 30 | trange = range(initial_input_length, max_seq_len) 31 | 32 | input_length = current_input_ids.shape[1] 33 | model_out, layer_past = model( 34 | current_input_ids, 35 | layer_past=layer_past, 36 | ) 37 | 38 | top_10_indices = torch.topk(model_out[:, -1], k=k, dim=-1).indices 39 | greedy_predicted_token_ids = top_10_indices[:, m] # 40 | current_input_ids = greedy_predicted_token_ids[:, None] 41 | l = [] 42 | l.append(greedy_predicted_token_ids.item()) 43 | 44 | try: 45 | should_break = False # Initialize flag variable to False 46 | for _ in trange: # Specify the iteration range appropriately 47 | input_length = current_input_ids.shape[1] 48 | model_out, layer_past = model( 49 | current_input_ids, 50 | layer_past=layer_past, 51 | ) 52 | 53 | greedy_predicted_token_ids = model_out[:, -1].argmax(-1) 54 | 55 | current_input_ids = greedy_predicted_token_ids[:, None] 56 | layer_past_length += input_length 57 | 58 | for i in range(batch_size): 59 | if greedy_predicted_token_ids[i].item() == 187: 60 | should_break = True 61 | raise BreakOuterLoop 62 | l.append(greedy_predicted_token_ids[i]) 63 | 64 | if should_break: 65 | break 66 | 67 | except BreakOuterLoop: 68 | pass 69 | 70 | return l 71 | 72 | 73 | def text_generation(m, k, model: nn.Module, 74 | tokenizer, 75 | initial_str: str, 76 | max_seq_len: int, 77 | device=torch.device("cuda:0"), # 0. 
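                            # Note: the prompt is clipped below to roughly its last 2000 tokens because
                            # the rotary cache in lab/rotary.py is precomputed for 2048 positions, and
                            # the generation budget (max_seq_len extra tokens) must fit in the remainder.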
78 | verbose=True): 79 | 80 | tokenized = tokenizer.encode(initial_str) 81 | if len(tokenized.ids) > 1919: 82 | input_ids = torch.LongTensor([tokenized.ids[-2009:]]).to(device) 83 | else: 84 | input_ids = torch.LongTensor([tokenized.ids]).to(device) 85 | 86 | try: 87 | all_token_ids = greedy_generate(m, k, model=model, input_ids=input_ids, max_seq_len=max_seq_len, verbose=verbose) 88 | except BreakOuterLoop: 89 | pass 90 | 91 | decoded_str = tokenizer.decode(all_token_ids) 92 | 93 | if len(decoded_str)< 2: 94 | return '"#'+str(m)+'"' 95 | elif decoded_str[1].isdigit(): 96 | return decoded_str 97 | else: 98 | return '"#'+str(m)+'"' -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenzhongwu20/RuleRAG_ICL_FT/79a4a7742af9e172516044a79a75a417c1089f91/paper.pdf -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, TaskType 2 | from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig 3 | import transformers 4 | from datasets import load_dataset 5 | import torch 6 | 7 | 8 | def get_model_and_tokenizer(args, config): 9 | if args.BIT_8: 10 | model = LlamaForCausalLM.from_pretrained( 11 | args.MODEL_NAME, 12 | load_in_8bit=True, 13 | device_map="auto", 14 | trust_remote_code=True, 15 | ) 16 | elif args.BIT_4: 17 | quant_config = BitsAndBytesConfig( 18 | load_in_4bit=True, 19 | bnb_4bit_use_double_quant=True, 20 | bnb_4bit_quant_type="nf4", 21 | bnb_4bit_compute_dtype=torch.bfloat16 22 | ) 23 | model = LlamaForCausalLM.from_pretrained( 24 | args.MODEL_NAME, 25 | quantization_config=quant_config, 26 | device_map="auto", 27 | trust_remote_code=True, 28 | ) 29 | else: 30 | model = LlamaForCausalLM.from_pretrained( 31 | "/home/", 32 | device_map="auto", 33 | trust_remote_code=True, 34 | ) 35 | 36 | tokenizer = LlamaTokenizer.from_pretrained( 37 | "/home/", 38 | trust_remote_code=True, 39 | pad_token="" 40 | ) 41 | 42 | model = prepare_model_for_kbit_training(model) 43 | model = get_peft_model(model, config) 44 | model.config.use_cache = False 45 | 46 | return model, tokenizer 47 | 48 | 49 | def llama2_tokenizer(args, tokenizer, data_type, data_point): 50 | if data_type == "json": 51 | data_slice_source = tokenizer( 52 | data_point["context"], 53 | max_length=args.CONTEXT_LEN, 54 | padding="max_length", 55 | truncation=True 56 | ) 57 | data_slice_target = tokenizer( 58 | data_point["target"], 59 | max_length=args.TARGET_LEN, 60 | padding=False, 61 | truncation=True 62 | ) 63 | 64 | data_slice = {} 65 | data_slice['input_ids'] = data_slice_source['input_ids'] + data_slice_target['input_ids'] + [ 66 | tokenizer.eos_token_id] + [2] * (args.TARGET_LEN - len(data_slice_target['input_ids'])) 67 | data_slice['attention_mask'] = data_slice_source['attention_mask'] + data_slice_target['attention_mask'] + [ 68 | 1] + [0] * (args.TARGET_LEN - len(data_slice_target['input_ids'])) 69 | data_slice['labels'] = [-100] * args.CONTEXT_LEN + data_slice_target['input_ids'] + [ 70 | tokenizer.eos_token_id] + [-100] * (args.TARGET_LEN - len(data_slice_target['input_ids'])) 71 | 72 | 73 | elif data_type == "txt": 74 | data_slice = tokenizer( 75 | data_point["text"], 76 | max_length=args.TEXT_LEN, 77 | padding="max_length", 78 | truncation=True 79 | ) 80 
| data_slice['input_ids'].extend([tokenizer.eos_token_id])  # extend in place; assigning the result would store None 81 | data_slice['attention_mask'].extend([1]) 82 | 83 | return data_slice 84 | 85 | 86 | def process_data(args, tokenizer, data_type, dataset): 87 | data = dataset.shuffle().map( 88 | lambda data_point: llama2_tokenizer( 89 | args, 90 | tokenizer, 91 | data_type, 92 | data_point 93 | ) 94 | ) 95 | 96 | return data 97 | 98 | 99 | def get_lora_config(args): 100 | config = LoraConfig( 101 | r=args.LORA_R, 102 | lora_alpha=args.LORA_ALPHA, 103 | lora_dropout=args.LORA_DROPOUT, 104 | task_type=TaskType.CAUSAL_LM, 105 | target_modules=["q_proj", "v_proj"] 106 | ) 107 | 108 | return config 109 | 110 | class llama2_trainer(transformers.Trainer): 111 | def compute_loss(self, model, inputs, return_outputs=False): 112 | 113 | return model( 114 | input_ids=inputs["input_ids"], 115 | labels=inputs["labels"], 116 | ).loss 117 | 118 | def get_trainer(args, model, data, tokenizer): 119 | GRADIENT_ACCUMULATION_STEPS = args.BATCH_SIZE // args.MICRO_BATCH_SIZE 120 | LOAD_BEST_MODEL_AT_END = False 121 | if args.LOAD_BEST_MODEL_AT_END == 1: 122 | LOAD_BEST_MODEL_AT_END = True 123 | trainer = llama2_trainer( 124 | model=model, 125 | train_dataset=data['train'], 126 | args=transformers.TrainingArguments( 127 | per_device_train_batch_size=args.MICRO_BATCH_SIZE, 128 | gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, 129 | warmup_steps=args.WARMUP_STEPS, 130 | num_train_epochs=args.EPOCHS, 131 | learning_rate=args.LEARNING_RATE, 132 | save_strategy="steps", 133 | save_steps=args.SAVE_STEPS, 134 | eval_steps=args.EVAL_STEPS, 135 | output_dir=args.OUTPUT_DIR, 136 | overwrite_output_dir=True, 137 | save_total_limit=args.SAVE_TOTAL_LIMIT, 138 | evaluation_strategy=args.EVAL_STRATEGY, 139 | report_to=None, # disabled here; set to e.g. "wandb" to enable logging 140 | run_name=None, # name of the W&B run (optional) 141 | load_best_model_at_end=LOAD_BEST_MODEL_AT_END, 142 | logging_steps=args.LOGGING_STEPS, 143 | bf16=False, #True, 144 | adam_beta1= 0.9, #adjust adam 145 | adam_beta2= 0.95, 146 | ), 147 | data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), 148 | ) 149 | 150 | return trainer 151 | 152 | 153 | def get_dataset(data_type, data_path): 154 | if data_type == "json": 155 | dataset = load_dataset("json", data_files=data_path) 156 | elif data_type == "txt": 157 | dataset = load_dataset("text", data_files=data_path) 158 | 159 | return dataset 160 | --------------------------------------------------------------------------------
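
A note on the fine-tuning inputs built by llama2_tokenizer() in utils.py (json branch): the context block is padded or truncated to CONTEXT_LEN and masked out of the loss with -100, so gradients flow only through the target tokens and the closing EOS. The stand-alone sketch below illustrates that layout with invented token ids and toy lengths; it is not part of the original repository.

# Toy illustration of the context/target packing in utils.llama2_tokenizer (json branch).
# All token ids and lengths here are made up for the example.
CONTEXT_LEN = 6          # stands in for args.CONTEXT_LEN
TARGET_LEN = 4           # stands in for args.TARGET_LEN
EOS_ID = 2               # stands in for tokenizer.eos_token_id

context_ids = [101, 102, 103, 0, 0, 0]   # already padded to CONTEXT_LEN by the tokenizer
target_ids = [201, 202]                  # truncated to at most TARGET_LEN tokens

input_ids = context_ids + target_ids + [EOS_ID] + [2] * (TARGET_LEN - len(target_ids))
labels = [-100] * CONTEXT_LEN + target_ids + [EOS_ID] + [-100] * (TARGET_LEN - len(target_ids))

assert len(input_ids) == len(labels) == CONTEXT_LEN + TARGET_LEN + 1
print(input_ids)   # [101, 102, 103, 0, 0, 0, 201, 202, 2, 2, 2]
print(labels)      # [-100, -100, -100, -100, -100, -100, 201, 202, 2, -100, -100]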