├── LICENSE ├── README.md ├── dataset.py ├── datasets ├── addition │ ├── 1hole_(50, 50)_10_441_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ └── train.json │ ├── 1hole_(50, 50)_10_441_0-100_cot │ │ ├── test.json │ │ └── train.json │ ├── 1hole_(50, 50)_8_289_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ └── train.json │ ├── 1hole_(50, 50)_8_289_0-100_cot │ │ ├── test.json │ │ └── train.json │ ├── 3hole_1411_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_2155_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_2363_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── cot_gen.py │ ├── create.py │ └── random_split_0.7_7000_3000_0-100 │ │ ├── test.json │ │ └── train.json ├── base_addition │ ├── 1hole_(50, 50)_10_441_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ └── train.json │ ├── 1hole_(50, 50)_10_441_0-100_cot │ │ ├── test.json │ │ └── train.json │ ├── 3hole_1899_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_1907_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_2459_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── cot_gen.py │ ├── create.py │ ├── random_split_0.7_7000_3000_0-100 │ │ ├── test.json │ │ └── train.json │ └── test.ipynb ├── linear_regression │ ├── 1hole_(50, 50)_10_441_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ └── train.json │ ├── 3hole_1931_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_2243_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_2339_0-100 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── create.py │ └── random_split_0.7_7000_3000_0-100 │ │ ├── test.json │ │ └── train.json ├── mod_addition │ ├── 1hole_(56, 56)_10_441_0-113 │ │ ├── data_split.png │ │ ├── test.json │ │ └── train.json │ ├── 1hole_(56, 56)_10_441_0-113_cot │ │ ├── test.json │ │ └── train.json │ ├── 3hole_2803_0-113 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_2811_0-113 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── 3hole_3187_0-113 │ │ ├── data_split.png │ │ ├── test.json │ │ ├── test_squares.txt │ │ └── train.json │ ├── cot_gen.py │ ├── create.py │ ├── random_split_0.7_8938_3831_0-113 │ │ ├── test.json │ │ └── train.json │ └── test.ipynb └── rabbits_and_chickens │ ├── 1hole_(70, 50)_10_441_0-100 │ ├── data_split.png │ ├── test.json │ └── train.json │ ├── 1hole_(75, 50)_10_441_0-100 │ ├── data_split.png │ ├── test.json │ └── train.json │ ├── 3hole_363_0-100 │ ├── data_split.png │ ├── test.json │ ├── test_squares.txt │ └── train.json │ ├── 3hole_459_0-100 │ ├── data_split.png │ ├── test.json │ ├── test_squares.txt │ └── train.json │ ├── 3hole_507_0-100 │ ├── data_split.png │ ├── test.json │ ├── test_squares.txt │ └── train.json │ ├── create.py │ ├── rabbits_and_chickens.json │ ├── rabbits_and_chickens_test.json │ ├── rabbits_and_chickens_train.json │ ├── random_split_0.7_3535_1515_0-100 │ ├── test.json │ └── train.json │ ├── split.py │ ├── tmux-client-538840.log │ └── tmux-server-538842.log ├── icl ├── addition_icl.py ├── base_addition_icl.py ├── icl_learning.py ├── icl_learning_base.py ├── prompt_base10.py ├── prompt_base9.py └── 
readme.md ├── llama ├── README.md ├── fashchat │ ├── .DS_Store │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── llm_judge │ │ ├── .DS_Store │ │ └── gen_model_answer.py │ ├── model │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── apply_lora.py │ │ ├── compression.py │ │ ├── convert_fp16.py │ │ ├── llama_condense_monkey_patch.py │ │ ├── make_delta.py │ │ ├── model_adapter.py │ │ ├── model_chatglm.py │ │ ├── model_codet5p.py │ │ ├── model_exllama.py │ │ ├── model_falcon.py │ │ ├── model_registry.py │ │ ├── model_xfastertransformer.py │ │ ├── monkey_patch_non_inplace.py │ │ ├── rwkv_model.py │ │ └── upload_hub.py │ ├── modules │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── awq.py │ │ ├── exllama.py │ │ ├── gptq.py │ │ └── xfastertransformer.py │ ├── train │ │ ├── .DS_Store │ │ ├── llama2_flash_attn_monkey_patch.py │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llama_xformers_attn_monkey_patch.py │ │ ├── train.py │ │ ├── train_concat.py │ │ ├── train_from_scratch.py │ │ ├── train_lora.py │ │ ├── train_lora_concat.py │ │ ├── train_lora_scratch.py │ │ ├── train_mem.py │ │ └── train_scratch.py │ └── utils.py ├── playground │ ├── .DS_Store │ ├── deepspeed_config_s2.json │ ├── deepspeed_config_s3.json │ └── test_embedding │ │ ├── README.md │ │ ├── test_classification.py │ │ ├── test_semantic_search.py │ │ └── test_sentence_similarity.py ├── pyproject.toml └── train.sh ├── plot_1hole.ipynb ├── plot_1hole_rec.ipynb ├── plot_3hole.ipynb ├── plot_ablation.ipynb ├── plot_scratch.ipynb ├── requirements.txt └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yi Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Case-Based or Rule-Based: How Do Transformers Do the Math? 2 | 3 | We explore whether LLMs perform case-based or rule-based reasoning in this work. 4 | 5 | :star: Official code for [Case-Based or Rule-Based: How Do Transformers Do the Math?](https://arxiv.org/abs/2402.17709). 6 | 7 | ## Requirements 8 | A tested combination of Python packages that can successfully run the code is listed in [requirements.txt](/requirements.txt). You can run the following command to install them. 
9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Replication of Leave-Square-Out 15 | 16 | To replicate our main Leave-Square-Out experiments, you need to download the GPT-2 or GPT-2 Medium models and put them in `./pretrained_models`. Then, you can run the script [train.py](/train.py) to fine-tune the pre-trained models. 17 | 18 | ## Datasets 19 | 20 | We provide the datasets for our main experiments in `./datasets`. In each dataset directory, we provide a figure, `data_split.png`, showing the train-test split. 21 | 22 | 23 | ## Llama 24 | We adopt the FastChat framework in `./llama` to fine-tune Llama-7B. 25 | 26 | ## Citation 27 | If you use this code in your research, please cite our paper: 28 | ```bibtex 29 | @misc{hu2024casebased, 30 | title={Case-Based or Rule-Based: How Do Transformers Do the Math?}, 31 | author={Yi Hu and Xiaojuan Tang and Haotong Yang and Muhan Zhang}, 32 | year={2024}, 33 | eprint={2402.17709}, 34 | archivePrefix={arXiv}, 35 | primaryClass={cs.AI} 36 | } 37 | ``` -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset 4 | from transformers.models.gpt2 import GPT2Tokenizer 5 | import json 6 | import numpy as np 7 | import random 8 | 9 | class GPT2Dataset(Dataset): 10 | def __init__(self, file_path: str, max_length: int, eda_aug: bool = False): 11 | self.file_path = file_path 12 | self.max_length = max_length 13 | self.tokenizer = GPT2Tokenizer.from_pretrained("pretrained_models/gpt2") 14 | # self.tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf') 15 | self.tokenizer.pad_token = self.tokenizer.eos_token 16 | # self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 17 | self.eda_aug = eda_aug 18 | 19 | 20 | with open(self.file_path, "r") as f: 21 | self.json_file = json.load(f) # a list 22 | if 'answers' in self.json_file[0].keys(): 23 | self.num_answer = [len(q['answers']) for q in self.json_file] 24 | elif 'answers_with_replace' in self.json_file[0].keys(): 25 | self.num_answer = [len(q['answers_with_replace']) for q in self.json_file] 26 | else: 27 | raise ValueError("The json file must have the key 'answers' or 'answers_with_replace'.") 28 | self.cum_answer = np.cumsum(self.num_answer) 29 | 30 | def __len__(self): 31 | return self.cum_answer[-1] 32 | 33 | def __getitem__(self, idx): 34 | # first find the question this flat idx belongs to: question_idx is the first position in self.cum_answer whose value is larger than idx. 
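        # Illustrative worked example of the indexing below (comment added for clarity):
        # with num_answer = [1, 2, 1] we get cum_answer = [1, 3, 4]; the flat idx 2 belongs to
        # question 1, because np.searchsorted([1, 3, 4], 2, side='right') == 1, and it maps to
        # answer_idx = 2 - cum_answer[0] = 1, i.e. the second answer of that question.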
35 | question_idx = np.searchsorted(self.cum_answer, idx, side='right') 36 | answer_idx = idx - self.cum_answer[question_idx-1] if question_idx > 0 else idx 37 | question = self.json_file[question_idx]['question'] 38 | answer = self.json_file[question_idx]['answers'][answer_idx] 39 | text = f"{question}{answer}" 40 | # if self.eda_aug: 41 | # from eda import eda 42 | # text = eda(text, alpha_sr=0.2, alpha_ri=0.1, alpha_rs=0, p_rd=0.1, num_aug=1)[0] 43 | encoding = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") 44 | return encoding.input_ids.squeeze(), torch.tensor(0) 45 | 46 | class GPT2DatasetReplace(GPT2Dataset): 47 | def __init__(self, file_path: str, max_length: int, replace_prob:float, random_seed: int, eda_aug: bool = False): 48 | super().__init__(file_path, max_length, eda_aug=eda_aug) 49 | self.replace_prob = replace_prob 50 | self.random_seed = random_seed 51 | np.random.seed(self.random_seed) 52 | torch.manual_seed(self.random_seed) 53 | random.seed(self.random_seed) 54 | 55 | def __getitem__(self, idx): 56 | # first find the idx of answer, the first number in the self.cum_answer is larger than the idx. 57 | question_idx = np.searchsorted(self.cum_answer, idx, side='right') 58 | answer_idx = idx - self.cum_answer[question_idx-1] if question_idx > 0 else idx 59 | if np.random.rand() > self.replace_prob: 60 | question = self.json_file[question_idx]['question'] 61 | answer = self.json_file[question_idx]['answers_with_replace'][answer_idx]['answer'] 62 | text = question + answer 63 | else: 64 | # select one replacement in the self.json_file[question_idx]['answers_with_replace'][answer_idx]['replaced_qa'], which is a list 65 | text = random.choice(self.json_file[question_idx]['answers_with_replace'][answer_idx]['replaced_qa']) 66 | # if self.eda_aug: 67 | # from eda import eda 68 | # text = eda(text, alpha_sr=0.2, alpha_ri=0.1, alpha_rs=0, p_rd=0.1, num_aug=1)[0] 69 | encoding = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") 70 | return encoding.input_ids.squeeze(), torch.tensor(0) 71 | 72 | class TestDataset(Dataset): 73 | def __init__(self, file_path: str): 74 | self.file_path = file_path 75 | self.tokenizer = GPT2Tokenizer.from_pretrained("pretrained_models/gpt2") 76 | # self.tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf') 77 | 78 | with open(self.file_path, "r") as f: 79 | self.json_file = json.load(f) # a list 80 | 81 | def __len__(self): 82 | return len(self.json_file) 83 | def __getitem__(self, idx): 84 | return self.tokenizer(self.json_file[idx]['question'], return_tensors='pt').input_ids, self.json_file[idx]['answers'][0] -------------------------------------------------------------------------------- /datasets/addition/1hole_(50, 50)_10_441_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/addition/1hole_(50, 50)_10_441_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/addition/1hole_(50, 50)_8_289_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/addition/1hole_(50, 50)_8_289_0-100/data_split.png 
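The README's Leave-Square-Out instructions and the `dataset.py` classes above fit together roughly as follows. This is a minimal usage sketch, not a file from the repository: it assumes GPT-2 has already been saved under `pretrained_models/gpt2` (the path hard-coded in `GPT2Dataset`), and the split path, `max_length`, and batch size are illustrative choices.

```python
# Minimal usage sketch (not part of the repository); assumes the GPT-2 weights and
# tokenizer files already live in pretrained_models/gpt2 as described in the README.
from torch.utils.data import DataLoader

from dataset import GPT2Dataset, TestDataset

train_set = GPT2Dataset(
    "datasets/addition/1hole_(50, 50)_10_441_0-100/train.json",
    max_length=32,  # illustrative; just needs to cover the tokenized "a+b=answer" string
)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
input_ids, _ = next(iter(train_loader))  # (batch, max_length) GPT-2 token ids

test_set = TestDataset("datasets/addition/1hole_(50, 50)_10_441_0-100/test.json")
prompt_ids, answer = test_set[0]  # tokenized "a+b=" prompt and its ground-truth answer string
```

Note that each training item is an `(input_ids, dummy_label)` pair: the answer is baked into the tokenized `question+answer` text rather than returned as a separate target.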
-------------------------------------------------------------------------------- /datasets/addition/3hole_1411_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/addition/3hole_1411_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/addition/3hole_1411_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(53, 33), (73, 78), (13, 31)] 2 | lengths:[10, 10, 11] -------------------------------------------------------------------------------- /datasets/addition/3hole_2155_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/addition/3hole_2155_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/addition/3hole_2155_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(66, 14), (38, 37), (26, 85)] 2 | lengths:[10, 16, 12] -------------------------------------------------------------------------------- /datasets/addition/3hole_2363_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/addition/3hole_2363_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/addition/3hole_2363_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(20, 27), (81, 55), (45, 66)] 2 | lengths:[10, 15, 15] -------------------------------------------------------------------------------- /datasets/addition/cot_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def extract_answer(rationale): 6 | answer = rationale.split("\n")[-1].split(",")[1] 7 | return answer 8 | 9 | def count_c(a_digit, b_digit, c): 10 | a_digit = int(a_digit) 11 | b_digit = int(b_digit) 12 | c = int(c) 13 | if a_digit + b_digit + c>= 10: 14 | return 1 15 | else: 16 | return 0 17 | 18 | def gen_cot_rationale(a, b, gt, cot="scratch"): 19 | ''' 20 | return the cot rationale for the question 21 | ''' 22 | if cot == "scratch": 23 | a_digits = [digit for digit in str(a)] 24 | b_digits = [digit for digit in str(b)] 25 | rationale = "" 26 | answer = "" 27 | c = 0 28 | for _ in range(len(str(int(a)+int(b)))+1): 29 | line = f"{''.join(a_digits)}+{''.join(b_digits)},{answer},C:{c}\n" 30 | rationale += line 31 | if a_digits and b_digits: 32 | answer = str(int(a_digits[-1]) + int(b_digits[-1]) + c)[-1] + answer 33 | c = count_c(a_digits[-1], b_digits[-1], c) 34 | a_digits.pop() 35 | b_digits.pop() 36 | elif a_digits: 37 | answer = str(int(a_digits[-1]) + c)[-1] + answer 38 | c = count_c(a_digits[-1], 0, c) 39 | a_digits.pop() 40 | elif b_digits: 41 | answer = str(int(b_digits[-1]) + c)[-1] + answer 42 | c = count_c(0, b_digits[-1], c) 43 | b_digits.pop() 44 | else: 45 | if c: 46 | answer = str(c) + answer 47 | c = 0 48 | rationale = rationale.strip() 49 | assert int(extract_answer(rationale)) == int(gt) 50 | return f"{rationale}\n{gt}" 51 | 52 | 53 | if 
__name__ == "__main__": 54 | title = "1hole_(50, 50)_8_289_0-100" 55 | with open(f"{title}/train.json", "r") as f: 56 | train_samples = json.load(f) 57 | 58 | with open(f"{title}/test.json", "r") as f: 59 | test_samples = json.load(f) 60 | 61 | cot_train_samples = [] 62 | for train_sample in train_samples: 63 | gt = train_sample['answers'][0] 64 | train_sample['answers'] = [gen_cot_rationale(train_sample["a"], train_sample["b"], gt)] 65 | train_sample["gt"] = gt 66 | cot_train_samples.append(train_sample) 67 | 68 | cot_test_samples = [] 69 | for test_sample in test_samples: 70 | gt = test_sample['answers'][0] 71 | test_sample['answers'] = [gen_cot_rationale(test_sample["a"], test_sample["b"], gt)] 72 | test_sample["gt"] = gt 73 | cot_test_samples.append(test_sample) 74 | 75 | os.mkdir(f"{title}_cot") 76 | with open(f"{title}_cot/train.json", "w") as f: 77 | json.dump(cot_train_samples, f) 78 | 79 | with open(f"{title}_cot/test.json", "w") as f: 80 | json.dump(cot_test_samples, f) 81 | -------------------------------------------------------------------------------- /datasets/addition/create.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The script is to create the addition dataset. 3 | 4 | Here's an example: 5 | 1+1=2 6 | ''' 7 | 8 | import json 9 | import math 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import random_split 13 | import os 14 | import random 15 | import interval 16 | from tqdm import tqdm 17 | import matplotlib.pyplot as plt 18 | 19 | def generate_dataset(a_range, b_range): 20 | idx = 0 21 | samples = [] 22 | for a in a_range: 23 | for b in b_range: 24 | a = int(a) 25 | b = int(b) 26 | samples.append({ 27 | "id": idx, 28 | "question": f"{a}+{b}=", 29 | "a": a, 30 | "b": b, 31 | "answers": [str(a + b)] 32 | }) 33 | idx += 1 34 | print(len(samples)) 35 | return samples 36 | 37 | def generate_test(a_range, b_range, length_range, square_num): 38 | generated_square = 0 39 | centers = [] 40 | lengths = [] 41 | a_start = a_range[0] 42 | a_end = a_range[-1] 43 | b_start = b_range[0] 44 | b_end = b_range[-1] 45 | a_intervals = [] 46 | b_intervals = [] 47 | test_num = 0 48 | while generated_square < square_num: 49 | # generate centers 50 | center = (random.randint(a_start, a_end), 51 | random.randint(b_start, b_end)) 52 | # for each center, generate length, make sure that center +- length is still in the range 53 | length = random.randint(length_range[0], length_range[-1]) 54 | a_interval = interval.Interval(center[0] - length, center[0] + length) 55 | b_interval = interval.Interval(center[1] - length, center[1] + length) 56 | if a_interval in interval.Interval(a_start, a_end) and b_interval in interval.Interval(b_start, b_end): 57 | overlap = False 58 | # make sure the squares do not overlap 59 | for idx in range(generated_square): 60 | if a_interval.overlaps(a_intervals[idx]) and b_interval.overlaps(b_intervals[idx]): 61 | overlap = True 62 | break 63 | if not overlap: 64 | centers.append(center) 65 | lengths.append(length) 66 | a_intervals.append(a_interval) 67 | b_intervals.append(b_interval) 68 | generated_square += 1 69 | test_num += (2 * length + 1) ** 2 70 | return centers, lengths, test_num 71 | 72 | 73 | if __name__ == "__main__": 74 | start = 0 75 | end = 100 76 | dataset = generate_dataset(np.arange(start,end), np.arange(start,end)) 77 | status = "random_split" # choose from "random_split" or "one_hole" or "multi_holes" or "column" 78 | status = "one_hole" 79 | train_ratio = 0.7 80 | hole_num = 3 81 
| center_a = round((start + end)/2) 82 | center_b = round((start + end)/2) 83 | length = 8 84 | column_idx = random.sample(range(start, end), 21) 85 | row_range = [40, 60] 86 | 87 | if status == "random_split": 88 | train_num = int(len(dataset) * train_ratio) 89 | test_num = len(dataset) - train_num 90 | train_set, test_set = random_split(dataset=dataset, lengths=[train_num, test_num], generator=torch.Generator().manual_seed(42)) 91 | title = f"random_split_{train_ratio}_{train_num}_{test_num}_{start}-{end}" 92 | os.mkdir(f"{title}") 93 | with open(f"{title}/train.json", "w") as f: 94 | json.dump(list(train_set), f) 95 | with open(f"{title}/test.json", "w") as f: 96 | json.dump(list(test_set), f) 97 | 98 | elif status == "one_hole": 99 | center = (center_a,center_b) 100 | assert length * 2 + 1 <= np.sqrt(len(dataset) * (1-train_ratio)) 101 | 102 | title = f"1hole_{center}_{length}_{(length * 2 + 1)**2}_{start}-{end}" 103 | test_set = [] 104 | train_set = [] 105 | img = np.zeros((end-start, end-start)) 106 | for sample in tqdm(dataset): 107 | test = False 108 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 109 | test_set.append(sample) 110 | test = True 111 | img[sample["a"], sample["b"]] = 2 112 | else: 113 | train_set.append(sample) 114 | img[sample["a"], sample["b"]] = 1 115 | os.mkdir(f"{title}") 116 | with open(f"{title}/train.json", "w") as f: 117 | json.dump(list(train_set), f) 118 | with open(f"{title}/test.json", "w") as f: 119 | json.dump(list(test_set), f) 120 | plt.imshow(img) 121 | plt.savefig(f"{title}/data_split") 122 | 123 | elif status == "multi_holes": 124 | centers, lengths, test_num = generate_test(np.arange(start, end), np.arange(start, end), np.arange(10,20), hole_num) 125 | title = f"{hole_num}hole_{test_num}_{start}-{end}" 126 | test_set = [] 127 | train_set = [] 128 | img = np.zeros((end-start, end-start)) 129 | for sample in tqdm(dataset): 130 | test = False 131 | for idx, center in enumerate(centers): 132 | length = lengths[idx] 133 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 134 | test_set.append(sample) 135 | test = True 136 | img[sample["a"], sample["b"]] = 1 137 | break 138 | if test:pass 139 | else: 140 | train_set.append(sample) 141 | img[sample["a"], sample["b"]] = 0 142 | os.mkdir(f"{title}") 143 | with open(f"{title}/train.json", "w") as f: 144 | json.dump(list(train_set), f) 145 | with open(f"{title}/test.json", "w") as f: 146 | json.dump(list(test_set), f) 147 | with open(f"{title}/test_squares.txt", "w") as f: 148 | f.write("centers:{}\nlengths:{}".format(centers, lengths)) 149 | plt.imshow(img) 150 | plt.savefig(f"{title}/data_split") 151 | 152 | elif status == "column": 153 | title = f"column_{len(column_idx)}_{row_range}_{int(len(column_idx) * (1 + row_range[1]-row_range[0]))}" 154 | 155 | test_set = [] 156 | train_set = [] 157 | img = np.zeros((end-start, end-start)) 158 | 159 | for sample in tqdm(dataset): 160 | test = False 161 | if row_range[0] <= sample["a"] <= row_range[1] and sample["b"] in column_idx: 162 | test_set.append(sample) 163 | test = True 164 | img[sample["a"], sample["b"]] = 2 165 | else: 166 | train_set.append(sample) 167 | img[sample["a"], sample["b"]] = 1 168 | os.mkdir(f"{title}") 169 | with open(f"{title}/train.json", "w") as f: 170 | json.dump(list(train_set), f) 171 | with open(f"{title}/test.json", "w") as f: 172 | json.dump(list(test_set), f) 173 | plt.imshow(img) 174 | 
plt.savefig(f"{title}/data_split") 175 | with open(f"{title}/columns.txt", "w") as f: 176 | f.write("columns:{}\nrow_range:{}".format(column_idx, row_range)) 177 | 178 | elif status == "row": 179 | title = f"row_{len(column_idx)}_{row_range}_{int(len(column_idx) * (1 + row_range[1]-row_range[0]))}" 180 | 181 | test_set = [] 182 | train_set = [] 183 | img = np.zeros((end-start, end-start)) 184 | 185 | for sample in tqdm(dataset): 186 | test = False 187 | if row_range[0] <= sample["b"] <= row_range[1] and sample["a"] in column_idx: 188 | test_set.append(sample) 189 | test = True 190 | img[sample["a"], sample["b"]] = 2 191 | else: 192 | train_set.append(sample) 193 | img[sample["a"], sample["b"]] = 1 194 | os.mkdir(f"{title}") 195 | with open(f"{title}/train.json", "w") as f: 196 | json.dump(list(train_set), f) 197 | with open(f"{title}/test.json", "w") as f: 198 | json.dump(list(test_set), f) 199 | plt.imshow(img) 200 | plt.savefig(f"{title}/data_split") 201 | with open(f"{title}/rows.txt", "w") as f: 202 | f.write("columns:{}\nrow_range:{}".format(column_idx, row_range)) 203 | -------------------------------------------------------------------------------- /datasets/base_addition/1hole_(50, 50)_10_441_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/base_addition/1hole_(50, 50)_10_441_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/base_addition/3hole_1899_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/base_addition/3hole_1899_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/base_addition/3hole_1899_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(41, 41), (73, 63), (75, 27)] 2 | lengths:[13, 13, 10] -------------------------------------------------------------------------------- /datasets/base_addition/3hole_1907_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/base_addition/3hole_1907_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/base_addition/3hole_1907_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(76, 41), (10, 21), (84, 78)] 2 | lengths:[12, 10, 14] -------------------------------------------------------------------------------- /datasets/base_addition/3hole_2459_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/base_addition/3hole_2459_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/base_addition/3hole_2459_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(46, 31), (73, 47), (29, 70)] 2 | lengths:[14, 11, 16] -------------------------------------------------------------------------------- 
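Before the `base_addition` scripts that follow, here is a small self-contained sanity check of the base-9 format those splits use — an illustrative sketch, not a repository file. The helper `to_base9` is a made-up stand-in for the `convert_to_base(n, base)` function defined in `cot_gen.py` and `create.py` below.

```python
# Illustrative check of the base-9 encoding used by the base_addition datasets;
# to_base9 is a hypothetical stand-in for convert_to_base(n, 9) from create.py.
def to_base9(n: int) -> str:
    digits = ""
    while True:
        n, d = divmod(n, 9)
        digits = str(d) + digits
        if n == 0:
            return digits

a, b = int("76", 9), int("14", 9)  # base-9 "76" and "14" are 69 and 13 in base 10
assert a + b == 82
assert to_base9(a + b) == "101"    # matches the docstring example "76+14=101"
```

In the stored JSON, `a` and `b` stay in base 10 while `question` and `answers` hold the base-9 strings, which is why `create.py` applies `convert_to_base` to both operands and to their sum.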
/datasets/base_addition/cot_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def extract_answer(rationale): 6 | answer = rationale.split("\n")[-1].split(",")[1] 7 | return answer 8 | 9 | def count_add(a_digit, b_digit, c, base): 10 | a_digit = int(a_digit) 11 | b_digit = int(b_digit) 12 | c = int(c) 13 | r = a_digit + b_digit + c 14 | if r >= base: 15 | c = 1 16 | r = r - base 17 | else: 18 | c = 0 19 | return r, c 20 | 21 | def convert_to_base(n, base): 22 | s = "" 23 | n = int(n) 24 | while True: 25 | if n < base: 26 | s = str(n) + s 27 | break 28 | else: 29 | d = n % base 30 | n = n // base 31 | s = str(d) + s 32 | return s 33 | 34 | def gen_cot_rationale(a, b, gt, base, cot="scratch"): 35 | ''' 36 | return the cot rationale for the question 37 | ''' 38 | if cot == "scratch": 39 | a_digits = [digit for digit in str(a)] 40 | b_digits = [digit for digit in str(b)] 41 | rationale = "" 42 | answer = "" 43 | c = 0 44 | for _ in range(len(convert_to_base(int(a)+int(b), base))+1): 45 | line = f"{''.join(a_digits)}+{''.join(b_digits)},{answer},C:{c}\n" 46 | rationale += line 47 | if a_digits and b_digits: 48 | r, c = count_add(a_digits[-1], b_digits[-1], c, base) 49 | answer = str(r) + answer 50 | a_digits.pop() 51 | b_digits.pop() 52 | elif a_digits: 53 | r, c = count_add(a_digits[-1], 0, c, base) 54 | answer = str(r) + answer 55 | a_digits.pop() 56 | elif b_digits: 57 | r, c = count_add(0, b_digits[-1], c, base) 58 | answer = str(r) + answer 59 | b_digits.pop() 60 | else: 61 | if c: 62 | r, c = count_add(0, 0, c, base) 63 | answer = str(r) + answer 64 | c = 0 65 | rationale = rationale.strip() 66 | try: 67 | assert int(extract_answer(rationale)) == int(gt) 68 | except: 69 | print(rationale) 70 | print(gt) 71 | raise AssertionError 72 | return f"{rationale}\n{gt}" 73 | 74 | 75 | if __name__ == "__main__": 76 | title = "1hole_(50, 50)_10_441_0-100" 77 | B = 9 78 | with open(f"{title}/train.json", "r") as f: 79 | train_samples = json.load(f) 80 | 81 | with open(f"{title}/test.json", "r") as f: 82 | test_samples = json.load(f) 83 | 84 | cot_train_samples = [] 85 | for train_sample in train_samples: 86 | gt = train_sample['answers'][0] 87 | a = convert_to_base(train_sample["a"], B) 88 | b = convert_to_base(train_sample["b"], B) 89 | train_sample['answers'] = [gen_cot_rationale(a, b, gt, B)] 90 | train_sample["gt"] = gt 91 | cot_train_samples.append(train_sample) 92 | 93 | cot_test_samples = [] 94 | for test_sample in test_samples: 95 | gt = test_sample['answers'][0] 96 | a = convert_to_base(test_sample["a"], B) 97 | b = convert_to_base(test_sample["b"], B) 98 | test_sample['answers'] = [gen_cot_rationale(a, b, gt, B)] 99 | test_sample["gt"] = gt 100 | cot_test_samples.append(test_sample) 101 | 102 | os.mkdir(f"{title}_cot") 103 | with open(f"{title}_cot/train.json", "w") as f: 104 | json.dump(cot_train_samples, f) 105 | 106 | with open(f"{title}_cot/test.json", "w") as f: 107 | json.dump(cot_test_samples, f) 108 | -------------------------------------------------------------------------------- /datasets/base_addition/create.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The script is to create the base addition dataset (in base 9). 
3 | 4 | Here's an example: 5 | 76+14=101 6 | ''' 7 | 8 | import json 9 | import math 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import random_split 13 | import os 14 | import random 15 | import interval 16 | from tqdm import tqdm 17 | import matplotlib.pyplot as plt 18 | 19 | def convert_to_base(n, base): 20 | s = "" 21 | while True: 22 | if n < base: 23 | s = str(n) + s 24 | break 25 | else: 26 | d = n % base 27 | n = n // base 28 | s = str(d) + s 29 | return s 30 | 31 | def generate_dataset(a_range, b_range, base=9): 32 | idx = 0 33 | samples = [] 34 | for a in a_range: 35 | for b in b_range: 36 | a = int(a) 37 | b = int(b) 38 | a_base = convert_to_base(a, base) 39 | b_base = convert_to_base(b, base) 40 | samples.append({ 41 | "id": idx, 42 | "question": f"{a_base}+{b_base}=", 43 | "a": a, # in base-10 44 | "b": b, # in base-10 45 | "answers": [convert_to_base(a+b, base)] 46 | }) 47 | idx += 1 48 | print(len(samples)) 49 | return samples 50 | 51 | def generate_test(a_range, b_range, length_range, square_num): 52 | generated_square = 0 53 | centers = [] 54 | lengths = [] 55 | a_start = a_range[0] 56 | a_end = a_range[-1] 57 | b_start = b_range[0] 58 | b_end = b_range[-1] 59 | a_intervals = [] 60 | b_intervals = [] 61 | test_num = 0 62 | while generated_square < square_num: 63 | # generate centers 64 | center = (random.randint(a_start, a_end), 65 | random.randint(b_start, b_end)) 66 | # for each center, generate length, make sure that center +- length is still in the range 67 | length = random.randint(length_range[0], length_range[-1]) 68 | a_interval = interval.Interval(center[0] - length, center[0] + length) 69 | b_interval = interval.Interval(center[1] - length, center[1] + length) 70 | if a_interval in interval.Interval(a_start, a_end) and b_interval in interval.Interval(b_start, b_end): 71 | overlap = False 72 | # make sure the squares do not overlap 73 | for idx in range(generated_square): 74 | if a_interval.overlaps(a_intervals[idx]) and b_interval.overlaps(b_intervals[idx]): 75 | overlap = True 76 | break 77 | if not overlap: 78 | centers.append(center) 79 | lengths.append(length) 80 | a_intervals.append(a_interval) 81 | b_intervals.append(b_interval) 82 | generated_square += 1 83 | test_num += (2 * length + 1) ** 2 84 | return centers, lengths, test_num 85 | 86 | 87 | if __name__ == "__main__": 88 | start = 0 89 | end = 100 90 | dataset = generate_dataset(np.arange(start,end), np.arange(start,end)) 91 | status = "random_split" # choose from "random_split" or "one_hole" or "multi_holes" 92 | status = "one_hole" 93 | status = "multi_holes" 94 | train_ratio = 0.7 95 | hole_num = 3 96 | center_a = round((start + end)/2) 97 | center_b = round((start + end)/2) 98 | length = 10 99 | 100 | if status == "random_split": 101 | train_num = int(len(dataset) * train_ratio) 102 | test_num = len(dataset) - train_num 103 | train_set, test_set = random_split(dataset=dataset, lengths=[train_num, test_num], generator=torch.Generator().manual_seed(42)) 104 | title = f"random_split_{train_ratio}_{train_num}_{test_num}_{start}-{end}" 105 | os.mkdir(f"{title}") 106 | with open(f"{title}/train.json", "w") as f: 107 | json.dump(list(train_set), f) 108 | with open(f"{title}/test.json", "w") as f: 109 | json.dump(list(test_set), f) 110 | 111 | elif status == "one_hole": 112 | center = (center_a,center_b) 113 | assert length * 2 + 1 <= np.sqrt(len(dataset) * (1-train_ratio)) 114 | 115 | title = f"1hole_{center}_{length}_{(length * 2 + 1)**2}_{start}-{end}" 116 | test_set = [] 117 | 
train_set = [] 118 | img = np.zeros((end-start, end-start)) 119 | for sample in tqdm(dataset): 120 | test = False 121 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 122 | test_set.append(sample) 123 | test = True 124 | img[sample["a"], sample["b"]] = 2 125 | else: 126 | train_set.append(sample) 127 | img[sample["a"], sample["b"]] = 1 128 | os.mkdir(f"{title}") 129 | with open(f"{title}/train.json", "w") as f: 130 | json.dump(list(train_set), f) 131 | with open(f"{title}/test.json", "w") as f: 132 | json.dump(list(test_set), f) 133 | plt.imshow(img) 134 | plt.savefig(f"{title}/data_split") 135 | 136 | elif status == "multi_holes": 137 | centers, lengths, test_num = generate_test(np.arange(start, end), np.arange(start, end), np.arange(10,20), hole_num) 138 | title = f"{hole_num}hole_{test_num}_{start}-{end}" 139 | test_set = [] 140 | train_set = [] 141 | img = np.zeros((end-start, end-start)) 142 | for sample in tqdm(dataset): 143 | test = False 144 | for idx, center in enumerate(centers): 145 | length = lengths[idx] 146 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 147 | test_set.append(sample) 148 | test = True 149 | img[sample["a"], sample["b"]] = 1 150 | break 151 | if test:pass 152 | else: 153 | train_set.append(sample) 154 | img[sample["a"], sample["b"]] = 0 155 | os.mkdir(f"{title}") 156 | with open(f"{title}/train.json", "w") as f: 157 | json.dump(list(train_set), f) 158 | with open(f"{title}/test.json", "w") as f: 159 | json.dump(list(test_set), f) 160 | with open(f"{title}/test_squares.txt", "w") as f: 161 | f.write("centers:{}\nlengths:{}".format(centers, lengths)) 162 | plt.imshow(img) 163 | plt.savefig(f"{title}/data_split") 164 | -------------------------------------------------------------------------------- /datasets/base_addition/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 14, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "'231'" 12 | ] 13 | }, 14 | "execution_count": 14, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "def convert_to_base(n, base):\n", 21 | " s = \"\"\n", 22 | " while True:\n", 23 | " if n < base:\n", 24 | " s = str(n) + s\n", 25 | " break\n", 26 | " else:\n", 27 | " d = n % base\n", 28 | " n = n // base\n", 29 | " s = str(d) + s\n", 30 | " return s\n", 31 | "\n", 32 | "convert_to_base(190,9)\n", 33 | "# 100//9" 34 | ] 35 | } 36 | ], 37 | "metadata": { 38 | "kernelspec": { 39 | "display_name": "Python 3.9.12 ('base': conda)", 40 | "language": "python", 41 | "name": "python3" 42 | }, 43 | "language_info": { 44 | "codemirror_mode": { 45 | "name": "ipython", 46 | "version": 3 47 | }, 48 | "file_extension": ".py", 49 | "mimetype": "text/x-python", 50 | "name": "python", 51 | "nbconvert_exporter": "python", 52 | "pygments_lexer": "ipython3", 53 | "version": "3.9.12" 54 | }, 55 | "orig_nbformat": 4, 56 | "vscode": { 57 | "interpreter": { 58 | "hash": "2344b6d4cf75e2fe63d7adea2acd8b07cf02ecdef8a7e7834a9c3ab9d9f0906f" 59 | } 60 | } 61 | }, 62 | "nbformat": 4, 63 | "nbformat_minor": 2 64 | } 65 | -------------------------------------------------------------------------------- /datasets/linear_regression/1hole_(50, 50)_10_441_0-100/data_split.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/linear_regression/1hole_(50, 50)_10_441_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/linear_regression/3hole_1931_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/linear_regression/3hole_1931_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/linear_regression/3hole_1931_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(41, 36), (11, 41), (13, 76)] 2 | lengths:[15, 10, 11] -------------------------------------------------------------------------------- /datasets/linear_regression/3hole_2243_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/linear_regression/3hole_2243_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/linear_regression/3hole_2243_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(59, 37), (29, 26), (27, 69)] 2 | lengths:[15, 10, 14] -------------------------------------------------------------------------------- /datasets/linear_regression/3hole_2339_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/linear_regression/3hole_2339_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/linear_regression/3hole_2339_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(63, 38), (46, 84), (16, 86)] 2 | lengths:[16, 12, 12] -------------------------------------------------------------------------------- /datasets/linear_regression/create.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The script is to create a linear regression dataset. 
(x+2y+3) 3 | 4 | Here's an example: 5 | (1,2)=8 6 | ''' 7 | 8 | import json 9 | import math 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import random_split 13 | import os 14 | import random 15 | import interval 16 | from tqdm import tqdm 17 | import matplotlib.pyplot as plt 18 | 19 | def generate_dataset(a_range, b_range): 20 | idx = 0 21 | samples = [] 22 | for a in a_range: 23 | for b in b_range: 24 | a = int(a) 25 | b = int(b) 26 | samples.append({ 27 | "id": idx, 28 | "question": f"({a},{b})=", 29 | "a": a, 30 | "b": b, 31 | "answers": [str(a+2*b+3)] 32 | }) 33 | idx += 1 34 | print(len(samples)) 35 | return samples 36 | 37 | def generate_test(a_range, b_range, length_range, square_num): 38 | generated_square = 0 39 | centers = [] 40 | lengths = [] 41 | a_start = a_range[0] 42 | a_end = a_range[-1] 43 | b_start = b_range[0] 44 | b_end = b_range[-1] 45 | a_intervals = [] 46 | b_intervals = [] 47 | test_num = 0 48 | while generated_square < square_num: 49 | # generate centers 50 | center = (random.randint(a_start, a_end), 51 | random.randint(b_start, b_end)) 52 | # for each center, generate length, make sure that center +- length is still in the range 53 | length = random.randint(length_range[0], length_range[-1]) 54 | a_interval = interval.Interval(center[0] - length, center[0] + length) 55 | b_interval = interval.Interval(center[1] - length, center[1] + length) 56 | if a_interval in interval.Interval(a_start, a_end) and b_interval in interval.Interval(b_start, b_end): 57 | overlap = False 58 | # make sure the squares do not overlap 59 | for idx in range(generated_square): 60 | if a_interval.overlaps(a_intervals[idx]) and b_interval.overlaps(b_intervals[idx]): 61 | overlap = True 62 | break 63 | if not overlap: 64 | centers.append(center) 65 | lengths.append(length) 66 | a_intervals.append(a_interval) 67 | b_intervals.append(b_interval) 68 | generated_square += 1 69 | test_num += (2 * length + 1) ** 2 70 | return centers, lengths, test_num 71 | 72 | 73 | if __name__ == "__main__": 74 | start = 0 75 | end = 100 76 | dataset = generate_dataset(np.arange(start,end), np.arange(start,end)) 77 | status = "random_split" # choose from "random_split" or "one_hole" or "multi_holes" 78 | status = "multi_holes" 79 | train_ratio = 0.7 80 | hole_num = 3 81 | center_a = round((start + end)/2) 82 | center_b = round((start + end)/2) 83 | length = 10 84 | 85 | if status == "random_split": 86 | train_num = int(len(dataset) * train_ratio) 87 | test_num = len(dataset) - train_num 88 | train_set, test_set = random_split(dataset=dataset, lengths=[train_num, test_num], generator=torch.Generator().manual_seed(42)) 89 | title = f"random_split_{train_ratio}_{train_num}_{test_num}_{start}-{end}" 90 | os.mkdir(f"{title}") 91 | with open(f"{title}/train.json", "w") as f: 92 | json.dump(list(train_set), f) 93 | with open(f"{title}/test.json", "w") as f: 94 | json.dump(list(test_set), f) 95 | 96 | elif status == "one_hole": 97 | center = (center_a,center_b) 98 | assert length * 2 + 1 <= np.sqrt(len(dataset) * (1-train_ratio)) 99 | 100 | title = f"1hole_{center}_{length}_{(length * 2 + 1)**2}_{start}-{end}" 101 | test_set = [] 102 | train_set = [] 103 | img = np.zeros((end-start, end-start)) 104 | for sample in tqdm(dataset): 105 | test = False 106 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 107 | test_set.append(sample) 108 | test = True 109 | img[sample["a"], sample["b"]] = 2 110 | else: 111 | 
train_set.append(sample) 112 | img[sample["a"], sample["b"]] = 1 113 | os.mkdir(f"{title}") 114 | with open(f"{title}/train.json", "w") as f: 115 | json.dump(list(train_set), f) 116 | with open(f"{title}/test.json", "w") as f: 117 | json.dump(list(test_set), f) 118 | plt.imshow(img) 119 | plt.savefig(f"{title}/data_split") 120 | 121 | elif status == "multi_holes": 122 | centers, lengths, test_num = generate_test(np.arange(start, end), np.arange(start, end), np.arange(10,20), hole_num) 123 | title = f"{hole_num}hole_{test_num}_{start}-{end}" 124 | test_set = [] 125 | train_set = [] 126 | img = np.zeros((end-start, end-start)) 127 | for sample in tqdm(dataset): 128 | test = False 129 | for idx, center in enumerate(centers): 130 | length = lengths[idx] 131 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 132 | test_set.append(sample) 133 | test = True 134 | img[sample["a"], sample["b"]] = 1 135 | break 136 | if test:pass 137 | else: 138 | train_set.append(sample) 139 | img[sample["a"], sample["b"]] = 0 140 | os.mkdir(f"{title}") 141 | with open(f"{title}/train.json", "w") as f: 142 | json.dump(list(train_set), f) 143 | with open(f"{title}/test.json", "w") as f: 144 | json.dump(list(test_set), f) 145 | with open(f"{title}/test_squares.txt", "w") as f: 146 | f.write("centers:{}\nlengths:{}".format(centers, lengths)) 147 | plt.imshow(img) 148 | plt.savefig(f"{title}/data_split") 149 | -------------------------------------------------------------------------------- /datasets/mod_addition/1hole_(56, 56)_10_441_0-113/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/mod_addition/1hole_(56, 56)_10_441_0-113/data_split.png -------------------------------------------------------------------------------- /datasets/mod_addition/3hole_2803_0-113/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/mod_addition/3hole_2803_0-113/data_split.png -------------------------------------------------------------------------------- /datasets/mod_addition/3hole_2803_0-113/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(79, 73), (23, 43), (44, 81)] 2 | lengths:[10, 19, 14] -------------------------------------------------------------------------------- /datasets/mod_addition/3hole_2811_0-113/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/mod_addition/3hole_2811_0-113/data_split.png -------------------------------------------------------------------------------- /datasets/mod_addition/3hole_2811_0-113/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(21, 43), (80, 93), (49, 88)] 2 | lengths:[17, 15, 12] -------------------------------------------------------------------------------- /datasets/mod_addition/3hole_3187_0-113/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/mod_addition/3hole_3187_0-113/data_split.png 
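For the `mod_addition` splits listed here, a quick illustrative check of the label convention — a sketch, not a repository file: `create.py` below stores the answer as `(a + b) mod 113`, as in its docstring example `100+20=7`, and the scratchpad trace in `test.ipynb` reduces `102+98` to `87` the same way.

```python
# Illustrative check of the mod_addition answer convention: answers are (a + b) mod P.
P = 113  # the modulus used throughout datasets/mod_addition

def mod_add_answer(a: int, b: int) -> int:
    return (a + b) % P

assert mod_add_answer(100, 20) == 7   # the docstring example in create.py: 100+20=7
assert mod_add_answer(102, 98) == 87  # matches the scratchpad trace printed in test.ipynb
```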
-------------------------------------------------------------------------------- /datasets/mod_addition/3hole_3187_0-113/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(79, 81), (29, 56), (14, 86)] 2 | lengths:[19, 17, 10] -------------------------------------------------------------------------------- /datasets/mod_addition/cot_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | 5 | 6 | def extract_answer(rationale): 7 | answer = rationale.split("\n")[-1].split(",")[1] 8 | return answer 9 | 10 | def count_c(a_digit, b_digit, c, mode): 11 | a_digit = int(a_digit) 12 | b_digit = int(b_digit) 13 | c = int(c) 14 | if mode == "add": 15 | if a_digit + b_digit + c>= 10: 16 | return 1 17 | else: 18 | return 0 19 | elif mode == "subtract": 20 | if a_digit - b_digit + c < 0: 21 | return -1 22 | else: 23 | return 0 24 | 25 | def count_sub(a_digit, b_digit, c): 26 | r = a_digit - b_digit + c 27 | if r < 0: 28 | return str(10 + r) 29 | else: 30 | return str(r) 31 | 32 | def gen_cot_rationale(a, b, mode="add"): 33 | ''' 34 | return the cot rationale for the question 35 | ''' 36 | a_digits = [digit for digit in str(a)] 37 | b_digits = [digit for digit in str(b)] 38 | rationale = "" 39 | answer = "" 40 | c = 0 41 | if mode == "add": 42 | gt = int(a) + int(b) 43 | for _ in range(len(str(int(a)+int(b)))+1): 44 | line = f"{''.join(a_digits)}+{''.join(b_digits)},{answer},C:{c}\n" 45 | rationale += line 46 | if a_digits and b_digits: 47 | answer = str(int(a_digits[-1]) + int(b_digits[-1]) + c)[-1] + answer 48 | c = count_c(a_digits[-1], b_digits[-1], c, mode) 49 | a_digits.pop() 50 | b_digits.pop() 51 | elif a_digits: 52 | answer = str(int(a_digits[-1]) + c)[-1] + answer 53 | c = count_c(a_digits[-1], 0, c, mode) 54 | a_digits.pop() 55 | elif b_digits: 56 | answer = str(int(b_digits[-1]) + c)[-1] + answer 57 | c = count_c(0, b_digits[-1], c, mode) 58 | b_digits.pop() 59 | else: 60 | if c: 61 | answer = str(c) + answer 62 | c = 0 63 | elif mode == "subtract": 64 | gt = int(a) - int(b) 65 | for _ in range(max(len(str(a)), len(str(b)))+1): 66 | line = f"{''.join(a_digits)}-{''.join(b_digits)},{answer},C:{c}\n" 67 | rationale += line 68 | if a_digits and b_digits: 69 | answer = count_sub(int(a_digits[-1]),int(b_digits[-1]),c)[-1] + answer 70 | c = count_c(a_digits[-1], b_digits[-1], c, mode) 71 | a_digits.pop() 72 | b_digits.pop() 73 | elif a_digits: 74 | answer = count_sub(int(a_digits[-1]),0,c)[-1] + answer 75 | c = count_c(a_digits[-1], 0, c, mode) 76 | a_digits.pop() 77 | elif b_digits: 78 | answer = count_sub(0,int(b_digits[-1]),c)[-1] + answer 79 | c = count_c(0, b_digits[-1], c, mode) 80 | b_digits.pop() 81 | else: 82 | if c: 83 | answer = count_sub(0,0,c) + answer 84 | c = 0 85 | rationale = rationale.strip() 86 | assert int(extract_answer(rationale)) == int(gt) 87 | return f"{rationale}\n{gt}" 88 | 89 | def gen_mod_add_cot_rationale(a, b, P): 90 | rationale = gen_cot_rationale(a, b, mode="add") 91 | if a + b >= P: 92 | rationale += f"\n{a+b}>={P}\n{a+b}-{P}=\n" 93 | rationale += gen_cot_rationale(a+b, P, mode="subtract") 94 | else: 95 | rationale += f"\n{a+b}<{P}\n{a+b}" 96 | try: 97 | assert str((a+b)%P) == re.findall(r'[0-9]+\.?[0-9]*', rationale)[-1] 98 | except: 99 | print(rationale) 100 | print((a+b)%P) 101 | raise AssertionError() 102 | return rationale 103 | 104 | 105 | if __name__ == "__main__": 106 | title = "1hole_(56, 56)_10_441_0-113" 107 | P = 
113 108 | with open(f"{title}/train.json", "r") as f: 109 | train_samples = json.load(f) 110 | 111 | with open(f"{title}/test.json", "r") as f: 112 | test_samples = json.load(f) 113 | 114 | cot_train_samples = [] 115 | for train_sample in train_samples: 116 | gt = train_sample['answers'][0] 117 | train_sample['answers'] = [gen_mod_add_cot_rationale(train_sample["a"], train_sample["b"], P)] 118 | train_sample["gt"] = gt 119 | cot_train_samples.append(train_sample) 120 | 121 | cot_test_samples = [] 122 | for test_sample in test_samples: 123 | gt = test_sample['answers'][0] 124 | test_sample['answers'] = [gen_mod_add_cot_rationale(test_sample["a"], test_sample["b"], P)] 125 | test_sample["gt"] = gt 126 | cot_test_samples.append(test_sample) 127 | 128 | os.mkdir(f"{title}_cot") 129 | with open(f"{title}_cot/train.json", "w") as f: 130 | json.dump(cot_train_samples, f) 131 | 132 | with open(f"{title}_cot/test.json", "w") as f: 133 | json.dump(cot_test_samples, f) 134 | -------------------------------------------------------------------------------- /datasets/mod_addition/create.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The script is to create the mod addition dataset (mod 113). 3 | 4 | Here's an example: 5 | 100+20=7 6 | ''' 7 | 8 | import json 9 | import math 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import random_split 13 | import os 14 | import random 15 | import interval 16 | from tqdm import tqdm 17 | import matplotlib.pyplot as plt 18 | 19 | P = 113 20 | def generate_dataset(a_range, b_range): 21 | idx = 0 22 | samples = [] 23 | for a in a_range: 24 | for b in b_range: 25 | a = int(a) 26 | b = int(b) 27 | samples.append({ 28 | "id": idx, 29 | "question": f"{a}+{b}=", 30 | "a": a, 31 | "b": b, 32 | "answers": [int(np.mod(a + b, P))] 33 | }) 34 | idx += 1 35 | print(len(samples)) 36 | return samples 37 | 38 | def generate_test(a_range, b_range, length_range, square_num): 39 | generated_square = 0 40 | centers = [] 41 | lengths = [] 42 | a_start = a_range[0] 43 | a_end = a_range[-1] 44 | b_start = b_range[0] 45 | b_end = b_range[-1] 46 | a_intervals = [] 47 | b_intervals = [] 48 | test_num = 0 49 | while generated_square < square_num: 50 | # generate centers 51 | center = (random.randint(a_start, a_end), 52 | random.randint(b_start, b_end)) 53 | # for each center, generate length, make sure that center +- length is still in the range 54 | length = random.randint(length_range[0], length_range[-1]) 55 | a_interval = interval.Interval(center[0] - length, center[0] + length) 56 | b_interval = interval.Interval(center[1] - length, center[1] + length) 57 | if a_interval in interval.Interval(a_start, a_end) and b_interval in interval.Interval(b_start, b_end): 58 | overlap = False 59 | # make sure the squares do not overlap 60 | for idx in range(generated_square): 61 | if a_interval.overlaps(a_intervals[idx]) and b_interval.overlaps(b_intervals[idx]): 62 | overlap = True 63 | break 64 | if not overlap: 65 | centers.append(center) 66 | lengths.append(length) 67 | a_intervals.append(a_interval) 68 | b_intervals.append(b_interval) 69 | generated_square += 1 70 | test_num += (2 * length + 1) ** 2 71 | return centers, lengths, test_num 72 | 73 | 74 | if __name__ == "__main__": 75 | start = 0 76 | end = P 77 | dataset = generate_dataset(np.arange(start,end), np.arange(start,end)) 78 | status = "random_split" # choose from "random_split" or "one_hole" or "multi_holes" 79 | status = "one_hole" 80 | status = "multi_holes" 81 | 
train_ratio = 0.7 82 | hole_num = 3 83 | center_a = round((start + end)/2) 84 | center_b = round((start + end)/2) 85 | length = 10 86 | 87 | if status == "random_split": 88 | train_num = int(len(dataset) * train_ratio) 89 | test_num = len(dataset) - train_num 90 | train_set, test_set = random_split(dataset=dataset, lengths=[train_num, test_num], generator=torch.Generator().manual_seed(42)) 91 | title = f"random_split_{train_ratio}_{train_num}_{test_num}_{start}-{end}" 92 | os.mkdir(f"{title}") 93 | with open(f"{title}/train.json", "w") as f: 94 | json.dump(list(train_set), f) 95 | with open(f"{title}/test.json", "w") as f: 96 | json.dump(list(test_set), f) 97 | 98 | elif status == "one_hole": 99 | center = (center_a,center_b) 100 | assert length * 2 + 1 <= np.sqrt(len(dataset) * (1-train_ratio)) 101 | 102 | title = f"1hole_{center}_{length}_{(length * 2 + 1)**2}_{start}-{end}" 103 | test_set = [] 104 | train_set = [] 105 | img = np.zeros((end-start, end-start)) 106 | for sample in tqdm(dataset): 107 | test = False 108 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 109 | test_set.append(sample) 110 | test = True 111 | img[sample["a"], sample["b"]] = 2 112 | else: 113 | train_set.append(sample) 114 | img[sample["a"], sample["b"]] = 1 115 | os.mkdir(f"{title}") 116 | with open(f"{title}/train.json", "w") as f: 117 | json.dump(list(train_set), f) 118 | with open(f"{title}/test.json", "w") as f: 119 | json.dump(list(test_set), f) 120 | plt.imshow(img) 121 | plt.savefig(f"{title}/data_split") 122 | 123 | elif status == "multi_holes": 124 | centers, lengths, test_num = generate_test(np.arange(start, end), np.arange(start, end), np.arange(10,20), hole_num) 125 | title = f"{hole_num}hole_{test_num}_{start}-{end}" 126 | test_set = [] 127 | train_set = [] 128 | img = np.zeros((end-start, end-start)) 129 | for sample in tqdm(dataset): 130 | test = False 131 | for idx, center in enumerate(centers): 132 | length = lengths[idx] 133 | if center[0] - length <= sample["a"] <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 134 | test_set.append(sample) 135 | test = True 136 | img[sample["a"], sample["b"]] = 1 137 | break 138 | if test:pass 139 | else: 140 | train_set.append(sample) 141 | img[sample["a"], sample["b"]] = 0 142 | os.mkdir(f"{title}") 143 | with open(f"{title}/train.json", "w") as f: 144 | json.dump(list(train_set), f) 145 | with open(f"{title}/test.json", "w") as f: 146 | json.dump(list(test_set), f) 147 | with open(f"{title}/test_squares.txt", "w") as f: 148 | f.write("centers:{}\nlengths:{}".format(centers, lengths)) 149 | plt.imshow(img) 150 | plt.savefig(f"{title}/data_split") 151 | -------------------------------------------------------------------------------- /datasets/mod_addition/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 39, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "def extract_answer(rationale):\n", 10 | " answer = str(int(rationale.split(\"\\n\")[-1].split(\",\")[1]))\n", 11 | " return answer\n", 12 | "\n", 13 | "def count_c(a_digit, b_digit, c, mode):\n", 14 | " a_digit = int(a_digit)\n", 15 | " b_digit = int(b_digit)\n", 16 | " c = int(c)\n", 17 | " if mode == \"add\":\n", 18 | " if a_digit + b_digit + c>= 10:\n", 19 | " return 1\n", 20 | " else:\n", 21 | " return 0\n", 22 | " elif mode == \"subtract\":\n", 23 | " if a_digit - b_digit 
+ c < 0:\n", 24 | " return -1\n", 25 | " else:\n", 26 | " return 0\n", 27 | "\n", 28 | "def count_sub(a_digit, b_digit, c):\n", 29 | " r = a_digit - b_digit + c\n", 30 | " if r < 0:\n", 31 | " return str(10 + r)\n", 32 | " else:\n", 33 | " return str(r)\n", 34 | "\n", 35 | "def gen_cot_rationale(a, b, mode=\"add\"):\n", 36 | " '''\n", 37 | " return the cot rationale for the question\n", 38 | " '''\n", 39 | " a_digits = [digit for digit in str(a)]\n", 40 | " b_digits = [digit for digit in str(b)]\n", 41 | " rationale = \"\"\n", 42 | " answer = \"\"\n", 43 | " c = 0\n", 44 | " if mode == \"add\":\n", 45 | " gt = int(a) + int(b)\n", 46 | " for _ in range(len(str(int(a)+int(b)))+1):\n", 47 | " line = f\"{''.join(a_digits)}+{''.join(b_digits)},{answer},C:{c}\\n\"\n", 48 | " rationale += line\n", 49 | " if a_digits and b_digits:\n", 50 | " answer = str(int(a_digits[-1]) + int(b_digits[-1]) + c)[-1] + answer\n", 51 | " c = count_c(a_digits[-1], b_digits[-1], c, mode)\n", 52 | " a_digits.pop()\n", 53 | " b_digits.pop()\n", 54 | " elif a_digits:\n", 55 | " answer = str(int(a_digits[-1]) + c)[-1] + answer\n", 56 | " c = count_c(a_digits[-1], 0, c, mode)\n", 57 | " a_digits.pop()\n", 58 | " elif b_digits:\n", 59 | " answer = str(int(b_digits[-1]) + c)[-1] + answer\n", 60 | " c = count_c(0, b_digits[-1], c, mode)\n", 61 | " b_digits.pop()\n", 62 | " else:\n", 63 | " if c:\n", 64 | " answer = str(c) + answer\n", 65 | " c = 0\n", 66 | " elif mode == \"subtract\":\n", 67 | " gt = int(a) - int(b)\n", 68 | " for _ in range(max(len(str(a)), len(str(b)))+1):\n", 69 | " line = f\"{''.join(a_digits)}-{''.join(b_digits)},{answer},C:{c}\\n\"\n", 70 | " rationale += line\n", 71 | " if a_digits and b_digits:\n", 72 | " answer = count_sub(int(a_digits[-1]),int(b_digits[-1]),c)[-1] + answer\n", 73 | " c = count_c(a_digits[-1], b_digits[-1], c, mode)\n", 74 | " a_digits.pop()\n", 75 | " b_digits.pop()\n", 76 | " elif a_digits:\n", 77 | " answer = count_sub(int(a_digits[-1]),0,c)[-1] + answer\n", 78 | " c = count_c(a_digits[-1], 0, c, mode)\n", 79 | " a_digits.pop()\n", 80 | " elif b_digits:\n", 81 | " answer = count_sub(0,int(b_digits[-1]),c)[-1] + answer\n", 82 | " c = count_c(0, b_digits[-1], c, mode)\n", 83 | " b_digits.pop()\n", 84 | " else:\n", 85 | " if c:\n", 86 | " answer = count_sub(0,0,c) + answer\n", 87 | " c = 0\n", 88 | " rationale = rationale.strip()\n", 89 | " assert int(extract_answer(rationale)) == int(gt)\n", 90 | " return f\"{rationale}\\n{gt}\"\n", 91 | "\n", 92 | "def gen_mod_add_cot_rationale(a, b, P):\n", 93 | " rationale = gen_cot_rationale(a, b, mode=\"add\")\n", 94 | " if a + b >= P:\n", 95 | " rationale += f\"\\n{a+b}>={P}\\n\"\n", 96 | " rationale += gen_cot_rationale(a+b, P, mode=\"subtract\")\n", 97 | " else:\n", 98 | " rationale += f\"\\n{a+b}<{P}\\n{a+b}\"\n", 99 | " return rationale" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 40, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "102+98,,C:0\n", 112 | "10+9,0,C:1\n", 113 | "1+,00,C:1\n", 114 | "+,200,C:0\n", 115 | "200\n", 116 | "200>=113\n", 117 | "200-113,,C:0\n", 118 | "20-11,7,C:-1\n", 119 | "2-1,87,C:-1\n", 120 | "-,087,C:0\n", 121 | "87\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "print(gen_mod_add_cot_rationale(102, 98, 113))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 41, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "for a in range(113,226):\n", 136 | " 
gen_cot_rationale(a,113)" 137 | ] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3.9.12 ('base': conda)", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.9.12" 157 | }, 158 | "orig_nbformat": 4, 159 | "vscode": { 160 | "interpreter": { 161 | "hash": "2344b6d4cf75e2fe63d7adea2acd8b07cf02ecdef8a7e7834a9c3ab9d9f0906f" 162 | } 163 | } 164 | }, 165 | "nbformat": 4, 166 | "nbformat_minor": 2 167 | } 168 | -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/1hole_(70, 50)_10_441_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/rabbits_and_chickens/1hole_(70, 50)_10_441_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/1hole_(75, 50)_10_441_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/rabbits_and_chickens/1hole_(75, 50)_10_441_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/3hole_363_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/rabbits_and_chickens/3hole_363_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/3hole_363_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(130, 82), (88, 56), (107, 89)] 2 | lengths:[5, 5, 5] -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/3hole_459_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/rabbits_and_chickens/3hole_459_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/3hole_459_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(124, 71), (68, 49), (124, 84)] 2 | lengths:[6, 6, 5] -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/3hole_507_0-100/data_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/datasets/rabbits_and_chickens/3hole_507_0-100/data_split.png -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/3hole_507_0-100/test_squares.txt: -------------------------------------------------------------------------------- 1 | centers:[(132, 80), (104, 
74), (66, 47)] 2 | lengths:[6, 6, 6] -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/create.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The script is to create the rabbits_and_chickens dataset. 3 | 4 | Here's an example: 5 | Q: Rabbits have 4 legs and 1 head. Chickens have 2 legs and 1 head. There are <> legs and <> heads on the farm. How many rabbits and chickens are there? 6 | A: There are <<(a-2b)/2>> rabbits and <<(4b-a)/2>> chickens. 7 | ''' 8 | 9 | import json 10 | import math 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import random_split 14 | import os 15 | import random 16 | import interval 17 | from tqdm import tqdm 18 | import matplotlib.pyplot as plt 19 | 20 | def generate_dataset(min_num_head, max_num_head): 21 | samples = [] 22 | idx = 0 23 | question = "Rabbits have 4 legs and 1 head. Chickens have 2 legs and 1 head. There are {} legs and {} heads on the farm. How many rabbits and chickens are there?" 24 | answer = "There are {} rabbits and {} chickens." 25 | for heads in range(min_num_head, max_num_head): 26 | for legs in range(2 * heads, 4 * heads + 1, 2): 27 | samples.append({ 28 | "id": idx, 29 | "question": question.format(legs, heads), 30 | "answers": [answer.format(int((legs - 2 * heads)/2), int((4 * heads - legs)/2))], 31 | "a": legs, 32 | "b": heads 33 | }) 34 | idx += 1 35 | print(len(samples)) 36 | return samples 37 | 38 | def generate_test(a_range, b_range, length_range, square_num): 39 | generated_square = 0 40 | centers = [] 41 | lengths = [] 42 | a_start = a_range[0] 43 | a_end = a_range[-1] 44 | b_start = b_range[0] 45 | b_end = b_range[-1] 46 | a_intervals = [] 47 | b_intervals = [] 48 | test_num = 0 49 | while generated_square < square_num: 50 | # generate centers 51 | b_center = random.randint(b_start, b_end) # head 52 | a_center = random.randint(b_center, b_center * 2) 53 | center = (a_center, b_center) 54 | # for each center, generate length, make sure that center +- length is still in the range 55 | lb = length_range[0] 56 | ub = min(length_range[1], int((a_center-b_center)/2), int((2*b_center-a_center)/3)) 57 | if lb <= ub: 58 | length = random.randint(lb,ub) 59 | a_interval = interval.Interval(center[0] - length, center[0] + length) 60 | b_interval = interval.Interval(center[1] - length, center[1] + length) 61 | if a_interval in interval.Interval(a_start, a_end) and b_interval in interval.Interval(b_start, b_end): 62 | overlap = False 63 | # make sure the squares do not overlap 64 | for idx in range(generated_square): 65 | if a_interval.overlaps(a_intervals[idx]) and b_interval.overlaps(b_intervals[idx]): 66 | overlap = True 67 | break 68 | if not overlap: 69 | centers.append(center) 70 | lengths.append(length) 71 | a_intervals.append(a_interval) 72 | b_intervals.append(b_interval) 73 | generated_square += 1 74 | test_num += (length * 2 + 1) ** 2 75 | return centers, lengths, test_num 76 | 77 | 78 | if __name__ == "__main__": 79 | start = 0 80 | end = 100 81 | dataset = generate_dataset(start, end) 82 | # status = "random_split" # choose from "random_split" or "one_hole" or "multi_holes" 83 | status = "multi_holes" 84 | train_ratio = 0.7 85 | hole_num = 3 86 | center_head = int((start + end)/2) 87 | center_leg = round(center_head * 3 / 2) 88 | center_leg = 70 89 | length = 10 90 | 91 | if status == "random_split": 92 | train_num = int(len(dataset) * train_ratio) 93 | test_num = len(dataset) - train_num 94 | 
train_set, test_set = random_split(dataset=dataset, lengths=[train_num, test_num], generator=torch.Generator().manual_seed(42)) 95 | title = f"random_split_{train_ratio}_{train_num}_{test_num}_{start}-{end}" 96 | os.mkdir(f"{title}") 97 | with open(f"{title}/train.json", "w") as f: 98 | json.dump(list(train_set), f) 99 | with open(f"{title}/test.json", "w") as f: 100 | json.dump(list(test_set), f) 101 | 102 | elif status == "one_hole": 103 | center = (center_leg,center_head) 104 | assert length * 2 + 1 <= np.sqrt(len(dataset) * (1-train_ratio)) 105 | 106 | title = f"1hole_{center}_{length}_{(length * 2 + 1)**2}_{start}-{end}" 107 | test_set = [] 108 | train_set = [] 109 | img = np.zeros((2*(end-start), end-start)) # range of (leg_num/2,head_num) 110 | for sample in tqdm(dataset): 111 | test = False 112 | if center[0] - length <= sample["a"]/2 <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 113 | test_set.append(sample) 114 | test = True 115 | img[round(sample["a"]/2), sample["b"]] = 2 116 | else: 117 | train_set.append(sample) 118 | img[round(sample["a"]/2), sample["b"]] = 1 119 | os.mkdir(f"{title}") 120 | with open(f"{title}/train.json", "w") as f: 121 | json.dump(list(train_set), f) 122 | with open(f"{title}/test.json", "w") as f: 123 | json.dump(list(test_set), f) 124 | plt.imshow(img) 125 | plt.savefig(f"{title}/data_split") 126 | 127 | elif status == "multi_holes": 128 | centers, lengths, test_num = generate_test(np.arange(2 * start, 2 * end), np.arange(start, end), np.arange(5,15), hole_num) 129 | title = f"{hole_num}hole_{test_num}_{start}-{end}" 130 | test_set = [] 131 | train_set = [] 132 | img = np.zeros((2*(end-start), end-start)) # range of (leg_num/2,head_num) 133 | for sample in tqdm(dataset): 134 | test = False 135 | for idx, center in enumerate(centers): 136 | length = lengths[idx] 137 | if center[0] - length <= sample["a"]/2 <= center[0] + length and center[1] - length <= sample["b"] <= center[1] + length: 138 | test_set.append(sample) 139 | test = True 140 | img[round(sample["a"]/2), sample["b"]] = 2 141 | break 142 | if test:pass 143 | else: 144 | train_set.append(sample) 145 | img[round(sample["a"]/2), sample["b"]] = 1 146 | os.mkdir(f"{title}") 147 | with open(f"{title}/train.json", "w") as f: 148 | json.dump(list(train_set), f) 149 | with open(f"{title}/test.json", "w") as f: 150 | json.dump(list(test_set), f) 151 | with open(f"{title}/test_squares.txt", "w") as f: 152 | f.write("centers:{}\nlengths:{}".format(centers, lengths)) 153 | plt.imshow(img) 154 | plt.savefig(f"{title}/data_split") 155 | -------------------------------------------------------------------------------- /datasets/rabbits_and_chickens/split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import random_split 3 | import json 4 | 5 | if __name__ == "__main__": 6 | task = "rabbits_and_chickens" 7 | 8 | with open(f"{task}.json", "r") as f: 9 | dataset = json.load(f) 10 | 11 | train_set, test_set = random_split(dataset=dataset, lengths=[3030, 2020], generator=torch.Generator().manual_seed(0)) 12 | 13 | with open(f"{task}_train.json", "w") as f: 14 | json.dump(list(train_set), f) 15 | 16 | with open(f"{task}_test.json", "w") as f: 17 | json.dump(list(test_set), f) 18 | -------------------------------------------------------------------------------- /icl/icl_learning.py: -------------------------------------------------------------------------------- 1 | from addition_icl import BaseInt, 
test_a_b, is_easy_add 2 | from tqdm import tqdm 3 | import pickle 4 | import random 5 | 6 | NUM_TEST = 100 7 | max_length = 10 8 | task = 'rf' 9 | 10 | few_shot_digit: list[tuple[int, int]] = [(2,2), (3,3), (5,5), (2,3), (3,5)] 11 | accu_list = [] 12 | 13 | few_shot = [] 14 | for length in few_shot_digit: 15 | while True: 16 | a = BaseInt.random_generate_length(length[0]) 17 | b = BaseInt.random_generate_length(length[1]) 18 | if random.random() < 0.5: 19 | a, b = b, a 20 | if not is_easy_add(a, b): 21 | few_shot.append((a, b)) 22 | break 23 | random.shuffle(few_shot) 24 | tqdm.write(f'Few shot examples: {few_shot}\n\n') 25 | 26 | for longer_length in tqdm(range(1, max_length+1)): 27 | accu = 0 28 | for i in tqdm(range(NUM_TEST)): 29 | a = BaseInt.random_generate_length(longer_length) 30 | shorter_length = random.randint(1, longer_length) 31 | b = BaseInt.random_generate_length(shorter_length) 32 | if random.random() < 0.5: 33 | a, b = b, a 34 | accu += test_a_b(a, b, few_shot, test_num=10, detail=False, type=task) 35 | accu_list.append(accu / NUM_TEST) 36 | tqdm.write(f"Length {longer_length} Accuracy: {accu / NUM_TEST}") 37 | 38 | with open(f'icl_learning_{task}_addition.pkl', 'wb') as f: 39 | pickle.dump(accu_list, f) 40 | -------------------------------------------------------------------------------- /icl/icl_learning_base.py: -------------------------------------------------------------------------------- 1 | from base_addition_icl import BaseInt, test_a_b, is_easy_add 2 | from tqdm import tqdm 3 | import pickle 4 | import random 5 | 6 | NUM_TEST = 100 7 | max_length = 10 8 | task = 'rf' 9 | 10 | few_shot_digit: list[tuple[int, int]] = [(2,2), (3,3), (5,5), (2,3), (3,5)] 11 | accu_list = [] 12 | 13 | few_shot = [] 14 | for length in few_shot_digit: 15 | while True: 16 | a = BaseInt.random_generate_length(length[0]) 17 | b = BaseInt.random_generate_length(length[1]) 18 | if random.random() < 0.5: 19 | a, b = b, a 20 | if not is_easy_add(a, b): 21 | few_shot.append((a, b)) 22 | break 23 | random.shuffle(few_shot) 24 | tqdm.write(f'Few shot examples: {few_shot}\n\n') 25 | 26 | for longer_length in tqdm(range(1, max_length+1)): 27 | accu = 0 28 | for i in tqdm(range(NUM_TEST)): 29 | a = BaseInt.random_generate_length(longer_length) 30 | shorter_length = random.randint(1, longer_length) 31 | b = BaseInt.random_generate_length(shorter_length) 32 | if random.random() < 0.5: 33 | a, b = b, a 34 | accu += test_a_b(a, b, few_shot, test_num=10, detail=False, type=task) 35 | accu_list.append(accu / NUM_TEST) 36 | tqdm.write(f"Length {longer_length} Accuracy: {accu / NUM_TEST}") 37 | 38 | with open(f'icl_learning_{task}_base.pkl', 'wb') as f: 39 | pickle.dump(accu_list, f) 40 | -------------------------------------------------------------------------------- /icl/prompt_base10.py: -------------------------------------------------------------------------------- 1 | QUESTION='''Follow the code step by step to answer the question: 2 | {}+{}=''' 3 | 4 | CODE=''' 5 | def sum_digit_by_digit(num1, num2): 6 | # Initialize the result list and carry 7 | result=[] 8 | carry=0 9 | # Loop through each digit 10 | while num1 or num2: 11 | # Get the current digits, defaulting to 0 if one number is shorter 12 | digit1=num1.pop() if num1 else 0 13 | digit2=num2.pop() if num2 else 0 14 | # Calculate the sum of the current digits and the carry 15 | total=digit1+digit2+carry 16 | # Insert the last digit of total to the beginning of the result and update carry 17 | result.insert(0,total%10) 18 | carry=total//10 19 | # 
If there's a remaining carry, insert it to the beginning of the result 20 | if carry: 21 | result.insert(0, carry) 22 | # Return the result 23 | return result''' 24 | 25 | NUM=''' 26 | num1={} 27 | num2={}''' 28 | 29 | INITIALIZE=''' 30 | 1. Initialize Result and Carry 31 | result=[] 32 | carry=0 33 | 34 | 2. Loop Through Each Digit''' 35 | 36 | CHECK_THE_STOP_CRITERION_2_1_ENTER=''' 37 | ``` 38 | while num1 or num2: 39 | ``` 40 | 2.1 check the stop criterion 41 | num1={} 42 | num2={} 43 | bool(num1)={} 44 | bool(num2)={} 45 | num1 or num2={} 46 | enter the loop''' 47 | 48 | CHECK_THE_STOP_CRITERION_2_1_END=''' 49 | ``` 50 | while num1 or num2: 51 | ``` 52 | 2.1 check the stop criterion 53 | num1={} 54 | num2={} 55 | bool(num1)={} 56 | bool(num2)={} 57 | num1 or num2={} 58 | end the loop''' 59 | 60 | 61 | ONE_ITERATION_2_2=''' 62 | 2.2 one iteration''' 63 | 64 | POP_DIGIT=''' 65 | ``` 66 | digit{0}=num{0}.pop() if num{0} else 0 67 | ``` 68 | num{0}={1} 69 | bool(num{0})={2} 70 | num{0}.pop() 71 | num{0}={3} 72 | digit{0}={4}''' 73 | 74 | NO_POP_DIGIT=''' 75 | ``` 76 | digit{0}=num{0}.pop() if num{0} else 0 77 | ``` 78 | num{0}=[] 79 | bool(num{0})=False 80 | num{0}=[] 81 | digit{0}=0''' 82 | 83 | TOTAL_RESULT_CARRY= ''' 84 | ``` 85 | total=digit1+digit2+carry 86 | ``` 87 | total=digit1+digit2+carry={}+{}+{}={} 88 | ``` 89 | result.insert(0,total%10) 90 | ``` 91 | result={} 92 | total%10={}%10={} 93 | result={} 94 | ``` 95 | carry=total//10 96 | ``` 97 | carry={}//10={} 98 | 2.3 back to the start of the loop''' 99 | 100 | CHECK_REMAINING_CARRY_FALSE=''' 101 | 3. Check Remaining Carry 102 | ``` 103 | if carry: 104 | result.insert(0, carry) 105 | ``` 106 | result={0} 107 | carry=0 108 | bool(carry)=False 109 | pass 110 | result={0}''' 111 | 112 | CHECK_REMAINING_CARRY_TRUE=''' 113 | 3. Check Remaining Carry 114 | ``` 115 | if carry: 116 | result.insert(0, carry) 117 | ``` 118 | result={} 119 | carry=1 120 | bool(carry)=True 121 | result={}''' 122 | 123 | RETURN_THE_RESULT=''' 124 | 4. Return Result 125 | ``` 126 | return result 127 | ``` 128 | result={} 129 | ''' -------------------------------------------------------------------------------- /icl/prompt_base9.py: -------------------------------------------------------------------------------- 1 | QUESTION='''Follow the code step by step to answer the question: 2 | {}+{}=''' 3 | 4 | CODE=''' 5 | def sum_digit_by_digit(num1, num2): 6 | # Initialize the result list and carry 7 | result=[] 8 | carry=0 9 | # Loop through each digit 10 | while num1 or num2: 11 | # Get the current digits, defaulting to 0 if one number is shorter 12 | digit1=num1.pop() if num1 else 0 13 | digit2=num2.pop() if num2 else 0 14 | # Calculate the sum of the current digits and the carry 15 | total=digit1+digit2+carry 16 | # Insert the last digit of total to the beginning of the result and update carry 17 | result.insert(0,total%9) 18 | carry=total//9 19 | # If there's a remaining carry, insert it to the beginning of the result 20 | if carry: 21 | result.insert(0, carry) 22 | # Return the result 23 | return result''' 24 | 25 | NUM=''' 26 | num1={} 27 | num2={}''' 28 | 29 | INITIALIZE=''' 30 | 1. Initialize Result and Carry 31 | result=[] 32 | carry=0 33 | 34 | 2. 
Loop Through Each Digit''' 35 | 36 | CHECK_THE_STOP_CRITERION_2_1_ENTER=''' 37 | ``` 38 | while num1 or num2: 39 | ``` 40 | 2.1 check the stop criterion 41 | num1={} 42 | num2={} 43 | bool(num1)={} 44 | bool(num2)={} 45 | num1 or num2={} 46 | enter the loop''' 47 | 48 | CHECK_THE_STOP_CRITERION_2_1_END=''' 49 | ``` 50 | while num1 or num2: 51 | ``` 52 | 2.1 check the stop criterion 53 | num1={} 54 | num2={} 55 | bool(num1)={} 56 | bool(num2)={} 57 | num1 or num2={} 58 | end the loop''' 59 | 60 | 61 | ONE_ITERATION_2_2=''' 62 | 2.2 one iteration''' 63 | 64 | POP_DIGIT=''' 65 | ``` 66 | digit{0}=num{0}.pop() if num{0} else 0 67 | ``` 68 | num{0}={1} 69 | bool(num{0})={2} 70 | num{0}.pop() 71 | num{0}={3} 72 | digit{0}={4}''' 73 | 74 | NO_POP_DIGIT=''' 75 | ``` 76 | digit{0}=num{0}.pop() if num{0} else 0 77 | ``` 78 | num{0}=[] 79 | bool(num{0})=False 80 | num{0}=[] 81 | digit{0}=0''' 82 | 83 | TOTAL_RESULT_CARRY= ''' 84 | ``` 85 | total=digit1+digit2+carry 86 | ``` 87 | total=digit1+digit2+carry={}+{}+{}={} 88 | ``` 89 | result.insert(0,total%9) 90 | ``` 91 | result={} 92 | total%9={}%9={} 93 | result={} 94 | ``` 95 | carry=total//9 96 | ``` 97 | carry={}//9={} 98 | 2.3 back to the start of the loop''' 99 | 100 | CHECK_REMAINING_CARRY_FALSE=''' 101 | 3. Check Remaining Carry 102 | ``` 103 | if carry: 104 | result.insert(0, carry) 105 | ``` 106 | result={0} 107 | carry=0 108 | bool(carry)=False 109 | pass 110 | result={0}''' 111 | 112 | CHECK_REMAINING_CARRY_TRUE=''' 113 | 3. Check Remaining Carry 114 | ``` 115 | if carry: 116 | result.insert(0, carry) 117 | ``` 118 | result={} 119 | carry=1 120 | bool(carry)=True 121 | result={}''' 122 | 123 | RETURN_THE_RESULT=''' 124 | 4. Return Result 125 | ``` 126 | return result 127 | ``` 128 | result={} 129 | ''' -------------------------------------------------------------------------------- /icl/readme.md: -------------------------------------------------------------------------------- 1 | # In-context learning 2 | 3 | To use the OpenAI API, you should first copy your api_key into `base_addition_icl.py` and `addition_icl.py`. 4 | 5 | `addition_icl.py` and `base_addition_icl.py` provide the code that filters out some in-context examples and records the contribution of these examples in `.plt` files. To run this code, you should first call the function `generate_test_sample` to generate enough candidate test samples. 6 | 7 | To compare the in-context learning performance of `direct`, `scratchpad`, and `rule-following` prompting, you can run the code in `icl_learning.py` and set `task` to `direct`, `scratchpad`, or `rf`. -------------------------------------------------------------------------------- /llama/README.md: -------------------------------------------------------------------------------- 1 | # Fine-tune LLaMA-2 2 | We use FastChat to fine-tune Llama models. 3 | 4 | 5 | ## Install 6 | 7 | ### Method 1: With pip 8 | 9 | ```bash 10 | pip3 install "fschat[model_worker,webui]" 11 | ``` 12 | 13 | ### Method 2: From source 14 | 15 | 1. Clone this repository and navigate to the FastChat folder 16 | ```bash 17 | git clone https://github.com/lm-sys/FastChat.git 18 | cd FastChat 19 | ``` 20 | 21 | If you are running on Mac: 22 | ```bash 23 | brew install rust cmake 24 | ``` 25 | 26 | 2. 
Install Package 27 | ```bash 28 | pip3 install --upgrade pip # enable PEP 660 support 29 | pip3 install -e ".[model_worker,webui]" 30 | ``` 31 | 32 | ## Usage 33 | 34 | You can modify the file `train.sh` and execute the following command: 35 | ```bash 36 | bash train.sh 37 | ``` 38 | 39 | -------------------------------------------------------------------------------- /llama/fashchat/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/llama/fashchat/.DS_Store -------------------------------------------------------------------------------- /llama/fashchat/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.34" 2 | -------------------------------------------------------------------------------- /llama/fashchat/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global constants. 3 | """ 4 | 5 | from enum import IntEnum 6 | import os 7 | 8 | REPO_PATH = os.path.dirname(os.path.dirname(__file__)) 9 | 10 | ##### For the gradio web server 11 | SERVER_ERROR_MSG = ( 12 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 13 | ) 14 | MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES." 15 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." 16 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." 17 | SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds." 18 | RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR TRY OTHER MODELS.**" 19 | # Maximum input length 20 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000)) 21 | # Maximum conversation turns 22 | CONVERSATION_TURN_LIMIT = 50 23 | # Session expiration time 24 | SESSION_EXPIRATION_TIME = 3600 25 | # The output dir of log files 26 | LOGDIR = os.getenv("LOGDIR", ".") 27 | # CPU Instruction Set Architecture 28 | CPU_ISA = os.getenv("CPU_ISA") 29 | 30 | 31 | ##### For the controller and workers (could be overwritten through ENV variables.) 
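# Heartbeat and timeout settings used to coordinate the controller and model workers; each value can be overridden through the FASTCHAT_* environment variables referenced below.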
32 | CONTROLLER_HEART_BEAT_EXPIRATION = int( 33 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) 34 | ) 35 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) 36 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) 37 | WORKER_API_EMBEDDING_BATCH_SIZE = int( 38 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) 39 | ) 40 | 41 | 42 | class ErrorCode(IntEnum): 43 | """ 44 | https://platform.openai.com/docs/guides/error-codes/api-errors 45 | """ 46 | 47 | VALIDATION_TYPE_ERROR = 40001 48 | 49 | INVALID_AUTH_KEY = 40101 50 | INCORRECT_AUTH_KEY = 40102 51 | NO_PERMISSION = 40103 52 | 53 | INVALID_MODEL = 40301 54 | PARAM_OUT_OF_RANGE = 40302 55 | CONTEXT_OVERFLOW = 40303 56 | 57 | RATE_LIMIT = 42901 58 | QUOTA_EXCEEDED = 42902 59 | ENGINE_OVERLOADED = 42903 60 | 61 | INTERNAL_ERROR = 50001 62 | CUDA_OUT_OF_MEMORY = 50002 63 | GRADIO_REQUEST_ERROR = 50003 64 | GRADIO_STREAM_UNKNOWN_ERROR = 50004 65 | CONTROLLER_NO_WORKER = 50005 66 | CONTROLLER_WORKER_TIMEOUT = 50006 67 | -------------------------------------------------------------------------------- /llama/fashchat/llm_judge/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/llama/fashchat/llm_judge/.DS_Store -------------------------------------------------------------------------------- /llama/fashchat/model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/llama/fashchat/model/.DS_Store -------------------------------------------------------------------------------- /llama/fashchat/model/__init__.py: -------------------------------------------------------------------------------- 1 | from fastchat.model.model_adapter import ( 2 | load_model, 3 | get_conversation_template, 4 | add_model_args, 5 | ) 6 | -------------------------------------------------------------------------------- /llama/fashchat/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the delta weights on top of a base model. 
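The target checkpoint is reconstructed tensor by tensor as target = base + delta, the inverse of make_delta.py, which stores target - base.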
3 | 4 | Usage: 5 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 6 | """ 7 | import argparse 8 | import gc 9 | import glob 10 | import json 11 | import os 12 | import shutil 13 | import tempfile 14 | 15 | from huggingface_hub import snapshot_download 16 | import torch 17 | from torch import nn 18 | from tqdm import tqdm 19 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 20 | 21 | 22 | GB = 1 << 30 23 | 24 | 25 | def split_files(model_path, tmp_path, split_size): 26 | if not os.path.exists(model_path): 27 | model_path = snapshot_download(repo_id=model_path) 28 | if not os.path.exists(tmp_path): 29 | os.makedirs(tmp_path) 30 | 31 | file_pattern = os.path.join(model_path, "pytorch_model-*.bin") 32 | files = glob.glob(file_pattern) 33 | 34 | part = 0 35 | try: 36 | for file_path in tqdm(files): 37 | state_dict = torch.load(file_path) 38 | new_state_dict = {} 39 | 40 | current_size = 0 41 | for name, param in state_dict.items(): 42 | param_size = param.numel() * param.element_size() 43 | 44 | if current_size + param_size > split_size: 45 | new_file_name = f"pytorch_model-{part}.bin" 46 | new_file_path = os.path.join(tmp_path, new_file_name) 47 | torch.save(new_state_dict, new_file_path) 48 | current_size = 0 49 | new_state_dict = None 50 | gc.collect() 51 | new_state_dict = {} 52 | part += 1 53 | 54 | new_state_dict[name] = param 55 | current_size += param_size 56 | 57 | new_file_name = f"pytorch_model-{part}.bin" 58 | new_file_path = os.path.join(tmp_path, new_file_name) 59 | torch.save(new_state_dict, new_file_path) 60 | new_state_dict = None 61 | gc.collect() 62 | new_state_dict = {} 63 | part += 1 64 | except Exception as e: 65 | print(f"An error occurred during split_files: {e}") 66 | shutil.rmtree(tmp_path) 67 | raise 68 | 69 | 70 | def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): 71 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 72 | delta_config = AutoConfig.from_pretrained(delta_path) 73 | 74 | if os.path.exists(target_model_path): 75 | shutil.rmtree(target_model_path) 76 | os.makedirs(target_model_path) 77 | 78 | split_size = 4 * GB 79 | 80 | with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: 81 | print(f"Split files for the base model to {tmp_base_path}") 82 | split_files(base_model_path, tmp_base_path, split_size) 83 | print(f"Split files for the delta weights to {tmp_delta_path}") 84 | split_files(delta_path, tmp_delta_path, split_size) 85 | 86 | base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") 87 | base_files = glob.glob(base_pattern) 88 | delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") 89 | delta_files = glob.glob(delta_pattern) 90 | delta_state_dict = torch.load(delta_files[0]) 91 | 92 | print("Applying the delta") 93 | weight_map = {} 94 | total_size = 0 95 | 96 | for i, base_file in tqdm(enumerate(base_files)): 97 | state_dict = torch.load(base_file) 98 | file_name = f"pytorch_model-{i}.bin" 99 | for name, param in state_dict.items(): 100 | if name not in delta_state_dict: 101 | for delta_file in delta_files: 102 | delta_state_dict = torch.load(delta_file) 103 | gc.collect() 104 | if name in delta_state_dict: 105 | break 106 | 107 | state_dict[name] += delta_state_dict[name] 108 | weight_map[name] = file_name 109 | total_size += param.numel() * param.element_size() 110 | gc.collect() 111 | torch.save(state_dict, 
os.path.join(target_model_path, file_name)) 112 | 113 | with open( 114 | os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" 115 | ) as f: 116 | json.dump( 117 | {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f 118 | ) 119 | 120 | print(f"Saving the target model to {target_model_path}") 121 | delta_tokenizer.save_pretrained(target_model_path) 122 | delta_config.save_pretrained(target_model_path) 123 | 124 | 125 | def apply_delta(base_model_path, target_model_path, delta_path): 126 | print(f"Loading the delta weights from {delta_path}") 127 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 128 | delta = AutoModelForCausalLM.from_pretrained( 129 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 130 | ) 131 | 132 | print(f"Loading the base model from {base_model_path}") 133 | base = AutoModelForCausalLM.from_pretrained( 134 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 135 | ) 136 | 137 | print("Applying the delta") 138 | for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): 139 | assert name in delta.state_dict() 140 | param.data += delta.state_dict()[name] 141 | 142 | print(f"Saving the target model to {target_model_path}") 143 | base.save_pretrained(target_model_path) 144 | delta_tokenizer.save_pretrained(target_model_path) 145 | 146 | 147 | if __name__ == "__main__": 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument("--base-model-path", type=str, required=True) 150 | parser.add_argument("--target-model-path", type=str, required=True) 151 | parser.add_argument("--delta-path", type=str, required=True) 152 | parser.add_argument( 153 | "--low-cpu-mem", 154 | action="store_true", 155 | help="Lower the cpu memory usage. This will split large files and use " 156 | "disk as swap to reduce the memory usage below 10GB.", 157 | ) 158 | args = parser.parse_args() 159 | 160 | if args.low_cpu_mem: 161 | apply_delta_low_cpu_mem( 162 | args.base_model_path, args.target_model_path, args.delta_path 163 | ) 164 | else: 165 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 166 | -------------------------------------------------------------------------------- /llama/fashchat/model/apply_lora.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the LoRA weights on top of a base model. 
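The adapter is merged into the base weights via peft's merge_and_unload(), producing a standalone checkpoint that no longer requires the LoRA files at load time.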
3 | 4 | Usage: 5 | python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B 6 | 7 | Dependency: 8 | pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b 9 | """ 10 | import argparse 11 | 12 | import torch 13 | from peft import PeftModel 14 | from transformers import AutoTokenizer, AutoModelForCausalLM 15 | 16 | 17 | def apply_lora(base_model_path, target_model_path, lora_path): 18 | print(f"Loading the base model from {base_model_path}") 19 | base = AutoModelForCausalLM.from_pretrained( 20 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 21 | ) 22 | base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) 23 | 24 | print(f"Loading the LoRA adapter from {lora_path}") 25 | 26 | lora_model = PeftModel.from_pretrained( 27 | base, 28 | lora_path, 29 | # torch_dtype=torch.float16 30 | ) 31 | 32 | print("Applying the LoRA") 33 | model = lora_model.merge_and_unload() 34 | 35 | print(f"Saving the target model to {target_model_path}") 36 | model.save_pretrained(target_model_path) 37 | base_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--lora-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_lora(args.base_model_path, args.target_model_path, args.lora_path) 49 | -------------------------------------------------------------------------------- /llama/fashchat/model/compression.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import gc 3 | import glob 4 | import os 5 | 6 | from accelerate import init_empty_weights 7 | from accelerate.utils import set_module_tensor_to_device 8 | from huggingface_hub import snapshot_download 9 | import torch 10 | from torch import Tensor 11 | from torch.nn import functional as F 12 | import torch.nn as nn 13 | from tqdm import tqdm 14 | from transformers import ( 15 | AutoConfig, 16 | AutoModelForCausalLM, 17 | AutoTokenizer, 18 | AutoModel, 19 | AutoModelForSeq2SeqLM, 20 | ) 21 | 22 | 23 | @dataclasses.dataclass 24 | class CompressionConfig: 25 | """Group-wise quantization.""" 26 | 27 | num_bits: int 28 | group_size: int 29 | group_dim: int 30 | symmetric: bool 31 | enabled: bool = True 32 | 33 | 34 | default_compression_config = CompressionConfig( 35 | num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True 36 | ) 37 | 38 | 39 | class CLinear(nn.Module): 40 | """Compressed Linear Layer.""" 41 | 42 | def __init__(self, weight=None, bias=None, device=None): 43 | super().__init__() 44 | if weight is None: 45 | self.weight = None 46 | elif isinstance(weight, Tensor): 47 | self.weight = compress(weight.data.to(device), default_compression_config) 48 | else: 49 | self.weight = weight 50 | self.bias = bias 51 | 52 | def forward(self, input: Tensor) -> Tensor: 53 | weight = decompress(self.weight, default_compression_config) 54 | if self.bias is None: 55 | return F.linear(input.to(weight.dtype), weight) 56 | return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype)) 57 | 58 | 59 | def compress_module(module, target_device): 60 | for attr_str in dir(module): 61 | target_attr = getattr(module, attr_str) 62 | if type(target_attr) == 
torch.nn.Linear: 63 | setattr( 64 | module, 65 | attr_str, 66 | CLinear(target_attr.weight, target_attr.bias, target_device), 67 | ) 68 | for name, child in module.named_children(): 69 | compress_module(child, target_device) 70 | 71 | 72 | def get_compressed_list(module, prefix=""): 73 | compressed_list = [] 74 | for attr_str in dir(module): 75 | target_attr = getattr(module, attr_str) 76 | if type(target_attr) == torch.nn.Linear: 77 | full_name = ( 78 | f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" 79 | ) 80 | compressed_list.append(full_name) 81 | for name, child in module.named_children(): 82 | child_prefix = f"{prefix}.{name}" if prefix else name 83 | for each in get_compressed_list(child, child_prefix): 84 | compressed_list.append(each) 85 | return compressed_list 86 | 87 | 88 | def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): 89 | for attr_str in dir(module): 90 | target_attr = getattr(module, attr_str) 91 | if type(target_attr) == torch.nn.Linear: 92 | full_name = ( 93 | f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" 94 | ) 95 | setattr( 96 | module, 97 | attr_str, 98 | CLinear( 99 | compressed_state_dict[full_name], target_attr.bias, target_device 100 | ), 101 | ) 102 | for name, child in module.named_children(): 103 | child_prefix = f"{prefix}.{name}" if prefix else name 104 | apply_compressed_weight( 105 | child, compressed_state_dict, target_device, child_prefix 106 | ) 107 | 108 | 109 | def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"): 110 | # partially load model 111 | # `use_fast=True`` is not supported for some models. 112 | try: 113 | tokenizer = AutoTokenizer.from_pretrained( 114 | model_path, use_fast=use_fast, revision=revision, trust_remote_code=True 115 | ) 116 | except TypeError: 117 | tokenizer = AutoTokenizer.from_pretrained( 118 | model_path, use_fast=~use_fast, revision=revision, trust_remote_code=True 119 | ) 120 | with init_empty_weights(): 121 | # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel 122 | config = AutoConfig.from_pretrained( 123 | model_path, 124 | low_cpu_mem_usage=True, 125 | torch_dtype=torch_dtype, 126 | trust_remote_code=True, 127 | revision=revision, 128 | ) 129 | # some models are loaded by AutoModel but not AutoModelForCausalLM, 130 | # such as chatglm, chatglm2 131 | try: 132 | # google/flan-* models are based on an AutoModelForSeq2SeqLM. 133 | if "T5Config" in str(type(config)): 134 | model = AutoModelForSeq2SeqLM.from_config( 135 | config, trust_remote_code=True 136 | ) 137 | else: 138 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) 139 | except NameError: 140 | model = AutoModel.from_config(config, trust_remote_code=True) 141 | linear_weights = get_compressed_list(model) 142 | if os.path.exists(model_path): 143 | # `model_path` is a local folder 144 | base_pattern = os.path.join(model_path, "pytorch_model*.bin") 145 | else: 146 | # `model_path` is a cached Hugging Face repo 147 | # We don't necessarily need to download the model' repo again if there is a cache. 148 | # So check the default huggingface cache first. 
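        # The cached snapshot is expected under ~/.cache/huggingface/hub/models--<org>--<name>/snapshots/<revision>/.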
149 | model_path_temp = os.path.join( 150 | os.path.expanduser("~"), 151 | ".cache/huggingface/hub", 152 | "models--" + model_path.replace("/", "--"), 153 | "snapshots/", 154 | ) 155 | downloaded = False 156 | if os.path.exists(model_path_temp): 157 | temp_last_dir = os.listdir(model_path_temp)[-1] 158 | model_path_temp = os.path.join(model_path_temp, temp_last_dir) 159 | base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin") 160 | files = glob.glob(base_pattern) 161 | if len(files) > 0: 162 | downloaded = True 163 | 164 | if downloaded: 165 | model_path = model_path_temp 166 | else: 167 | model_path = snapshot_download(model_path, revision=revision) 168 | base_pattern = os.path.join(model_path, "pytorch_model*.bin") 169 | 170 | files = glob.glob(base_pattern) 171 | use_safetensors = False 172 | if len(files) == 0: 173 | base_pattern = os.path.join(model_path, "*.safetensors") 174 | files = glob.glob(base_pattern) 175 | use_safetensors = True 176 | if len(files) == 0: 177 | raise ValueError( 178 | f"Cannot find any model weight files. " 179 | f"Please check your (cached) weight path: {model_path}" 180 | ) 181 | 182 | compressed_state_dict = {} 183 | if use_safetensors: 184 | from safetensors.torch import load_file 185 | for filename in tqdm(files): 186 | if use_safetensors: 187 | tmp_state_dict = load_file(filename) 188 | else: 189 | tmp_state_dict = torch.load( 190 | filename, map_location=lambda storage, loc: storage 191 | ) 192 | for name in tmp_state_dict: 193 | if name in linear_weights: 194 | tensor = tmp_state_dict[name].to(device, dtype=torch_dtype) 195 | compressed_state_dict[name] = compress( 196 | tensor, default_compression_config 197 | ) 198 | else: 199 | compressed_state_dict[name] = tmp_state_dict[name].to( 200 | device, dtype=torch_dtype 201 | ) 202 | tmp_state_dict[name] = None 203 | tensor = None 204 | gc.collect() 205 | torch.cuda.empty_cache() 206 | if device == "xpu": 207 | torch.xpu.empty_cache() 208 | if device == "npu": 209 | torch.npu.empty_cache() 210 | 211 | for name in model.state_dict(): 212 | if name not in linear_weights: 213 | set_module_tensor_to_device( 214 | model, name, device, value=compressed_state_dict[name] 215 | ) 216 | apply_compressed_weight(model, compressed_state_dict, device) 217 | 218 | if torch_dtype == torch.float16: 219 | model.half() 220 | model.to(device) 221 | model.eval() 222 | 223 | return model, tokenizer 224 | 225 | 226 | def compress(tensor, config): 227 | """Simulate group-wise quantization.""" 228 | if not config.enabled: 229 | return tensor 230 | 231 | group_size, num_bits, group_dim, symmetric = ( 232 | config.group_size, 233 | config.num_bits, 234 | config.group_dim, 235 | config.symmetric, 236 | ) 237 | assert num_bits <= 8 238 | 239 | original_shape = tensor.shape 240 | num_groups = (original_shape[group_dim] + group_size - 1) // group_size 241 | new_shape = ( 242 | original_shape[:group_dim] 243 | + (num_groups, group_size) 244 | + original_shape[group_dim + 1 :] 245 | ) 246 | 247 | # Pad 248 | pad_len = (group_size - original_shape[group_dim] % group_size) % group_size 249 | if pad_len != 0: 250 | pad_shape = ( 251 | original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] 252 | ) 253 | tensor = torch.cat( 254 | [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], 255 | dim=group_dim, 256 | ) 257 | data = tensor.view(new_shape) 258 | 259 | # Quantize 260 | if symmetric: 261 | B = 2 ** (num_bits - 1) - 1 262 | scale = B / torch.max(data.abs(), dim=group_dim + 1, 
keepdim=True)[0] 263 | data = data * scale 264 | data = data.clamp_(-B, B).round_().to(torch.int8) 265 | return data, scale, original_shape 266 | else: 267 | B = 2**num_bits - 1 268 | mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] 269 | mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] 270 | 271 | scale = B / (mx - mn) 272 | data = data - mn 273 | data.mul_(scale) 274 | 275 | data = data.clamp_(0, B).round_().to(torch.uint8) 276 | return data, mn, scale, original_shape 277 | 278 | 279 | def decompress(packed_data, config): 280 | """Simulate group-wise dequantization.""" 281 | if not config.enabled: 282 | return packed_data 283 | 284 | group_size, num_bits, group_dim, symmetric = ( 285 | config.group_size, 286 | config.num_bits, 287 | config.group_dim, 288 | config.symmetric, 289 | ) 290 | 291 | # Dequantize 292 | if symmetric: 293 | data, scale, original_shape = packed_data 294 | data = data / scale 295 | else: 296 | data, mn, scale, original_shape = packed_data 297 | data = data / scale 298 | data.add_(mn) 299 | 300 | # Unpad 301 | pad_len = (group_size - original_shape[group_dim] % group_size) % group_size 302 | if pad_len: 303 | padded_original_shape = ( 304 | original_shape[:group_dim] 305 | + (original_shape[group_dim] + pad_len,) 306 | + original_shape[group_dim + 1 :] 307 | ) 308 | data = data.reshape(padded_original_shape) 309 | indices = [slice(0, x) for x in original_shape] 310 | return data[indices].contiguous() 311 | else: 312 | return data.view(original_shape) 313 | -------------------------------------------------------------------------------- /llama/fashchat/model/convert_fp16.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder 4 | """ 5 | import argparse 6 | 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import torch 9 | 10 | 11 | def convert_fp16(in_checkpoint, out_checkpoint): 12 | tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) 13 | model = AutoModelForCausalLM.from_pretrained( 14 | in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True 15 | ) 16 | model.save_pretrained(out_checkpoint) 17 | tokenizer.save_pretrained(out_checkpoint) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--in-checkpoint", type=str, help="Path to the model") 23 | parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") 24 | args = parser.parse_args() 25 | 26 | convert_fp16(args.in_checkpoint, args.out_checkpoint) 27 | -------------------------------------------------------------------------------- /llama/fashchat/model/llama_condense_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import transformers 7 | import transformers.models.llama.modeling_llama 8 | 9 | 10 | class CondenseRotaryEmbedding(torch.nn.Module): 11 | def __init__( 12 | self, dim, ratio, max_position_embeddings=2048, base=10000, device=None 13 | ): 14 | super().__init__() 15 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 16 | self.register_buffer("inv_freq", inv_freq) 17 | 18 | # Build here to make `torch.jit.trace` work. 
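        # Position interpolation: positions are divided by `ratio` before computing the rotary frequencies, so the cached table covers a context window `ratio` times longer than the original max_position_embeddings.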
19 | self.ratio = ratio 20 | max_position_embeddings *= ratio 21 | self.max_seq_len_cached = max_position_embeddings 22 | # print(f"Monkey Patching condense ratio {ratio}") 23 | t = ( 24 | torch.arange( 25 | self.max_seq_len_cached, 26 | device=self.inv_freq.device, 27 | dtype=self.inv_freq.dtype, 28 | ) 29 | / ratio 30 | ) 31 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 32 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 33 | emb = torch.cat((freqs, freqs), dim=-1) 34 | dtype = torch.get_default_dtype() 35 | self.register_buffer( 36 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False 37 | ) 38 | self.register_buffer( 39 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False 40 | ) 41 | 42 | def forward(self, x, seq_len=None): 43 | # x: [bs, num_attention_heads, seq_len, head_size] 44 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 45 | if seq_len > self.max_seq_len_cached: 46 | self.max_seq_len_cached = seq_len 47 | t = ( 48 | torch.arange( 49 | self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype 50 | ) 51 | / self.ratio 52 | ) 53 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 54 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 55 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 56 | self.register_buffer( 57 | "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False 58 | ) 59 | self.register_buffer( 60 | "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False 61 | ) 62 | return ( 63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 65 | ) 66 | 67 | 68 | def replace_llama_with_condense(ratio): 69 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( 70 | CondenseRotaryEmbedding, ratio=ratio 71 | ) 72 | -------------------------------------------------------------------------------- /llama/fashchat/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make the delta weights by subtracting base weights. 
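Each delta tensor is computed as target - base, so apply_delta.py can later recover the target weights from the base model.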
3 | 4 | Usage: 5 | python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 6 | """ 7 | import argparse 8 | 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path): 15 | print(f"Loading the base model from {base_model_path}") 16 | base = AutoModelForCausalLM.from_pretrained( 17 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 18 | ) 19 | 20 | print(f"Loading the target model from {target_model_path}") 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 23 | ) 24 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) 25 | 26 | print("Calculating the delta") 27 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 28 | assert name in base.state_dict() 29 | param.data -= base.state_dict()[name] 30 | 31 | print(f"Saving the delta to {delta_path}") 32 | if args.hub_repo_id: 33 | kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} 34 | else: 35 | kwargs = {} 36 | target.save_pretrained(delta_path, **kwargs) 37 | target_tokenizer.save_pretrained(delta_path, **kwargs) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | parser.add_argument("--hub-repo-id", type=str) 46 | args = parser.parse_args() 47 | 48 | make_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llama/fashchat/model/model_chatglm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference code for ChatGLM. 3 | Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. 
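The generator streams partial responses from ChatGLM's stream_generate interface and post-processes punctuation around Chinese characters in the output.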
4 | """ 5 | import re 6 | 7 | import torch 8 | from transformers.generation.logits_process import LogitsProcessor 9 | 10 | 11 | class InvalidScoreLogitsProcessor(LogitsProcessor): 12 | def __call__( 13 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 14 | ) -> torch.FloatTensor: 15 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 16 | scores.zero_() 17 | scores[..., 5] = 5e4 18 | return scores 19 | 20 | 21 | invalid_score_processor = InvalidScoreLogitsProcessor() 22 | 23 | 24 | def process_response(response): 25 | response = response.strip() 26 | response = response.replace("[[训练时间]]", "2023年") 27 | punkts = [ 28 | [",", ","], 29 | ["!", "!"], 30 | [":", ":"], 31 | [";", ";"], 32 | ["\?", "?"], 33 | ] 34 | for item in punkts: 35 | response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) 36 | response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) 37 | return response 38 | 39 | 40 | @torch.inference_mode() 41 | def generate_stream_chatglm( 42 | model, 43 | tokenizer, 44 | params, 45 | device, 46 | context_len=2048, 47 | stream_interval=2, 48 | judge_sent_end=False, 49 | ): 50 | prompt = params["prompt"] 51 | temperature = float(params.get("temperature", 1.0)) 52 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 53 | top_p = float(params.get("top_p", 1.0)) 54 | max_new_tokens = int(params.get("max_new_tokens", 256)) 55 | echo = params.get("echo", True) 56 | 57 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device) 58 | input_echo_len = len(inputs["input_ids"][0]) 59 | 60 | gen_kwargs = { 61 | "max_length": max_new_tokens + input_echo_len, 62 | "do_sample": True if temperature > 1e-5 else False, 63 | "top_p": top_p, 64 | "repetition_penalty": repetition_penalty, 65 | "logits_processor": [invalid_score_processor], 66 | } 67 | if temperature > 1e-5: 68 | gen_kwargs["temperature"] = temperature 69 | 70 | total_len = 0 71 | for total_ids in model.stream_generate(**inputs, **gen_kwargs): 72 | total_ids = total_ids.tolist()[0] 73 | total_len = len(total_ids) 74 | if echo: 75 | output_ids = total_ids 76 | else: 77 | output_ids = total_ids[input_echo_len:] 78 | response = tokenizer.decode(output_ids) 79 | response = process_response(response) 80 | 81 | yield { 82 | "text": response, 83 | "usage": { 84 | "prompt_tokens": input_echo_len, 85 | "completion_tokens": total_len - input_echo_len, 86 | "total_tokens": total_len, 87 | }, 88 | "finish_reason": None, 89 | } 90 | 91 | # TODO: ChatGLM stop when it reach max length 92 | # Only last stream result contains finish_reason, we set finish_reason as stop 93 | ret = { 94 | "text": response, 95 | "usage": { 96 | "prompt_tokens": input_echo_len, 97 | "completion_tokens": total_len - input_echo_len, 98 | "total_tokens": total_len, 99 | }, 100 | "finish_reason": "stop", 101 | } 102 | yield ret 103 | -------------------------------------------------------------------------------- /llama/fashchat/model/model_codet5p.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | import torch 4 | import transformers 5 | from transformers import ( 6 | GenerationConfig, 7 | StoppingCriteria, 8 | StoppingCriteriaList, 9 | TextIteratorStreamer, 10 | ) 11 | 12 | 13 | @torch.inference_mode() 14 | def generate_stream_codet5p( 15 | model, 16 | tokenizer, 17 | params, 18 | device, 19 | context_len=2048, 20 | stream_interval=2, 21 | judge_sent_end=False, 22 | ): 23 | prompt = params["prompt"] 24 | 
temperature = float(params.get("temperature", 1.0)) 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 26 | top_p = float(params.get("top_p", 1.0)) 27 | top_k = int(params.get("top_k", 50)) # -1 means disable 28 | max_new_tokens = int(params.get("max_new_tokens", 1024)) 29 | stop_token_ids = params.get("stop_token_ids", None) or [] 30 | stop_token_ids.append(tokenizer.eos_token_id) 31 | 32 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 33 | streamer = TextIteratorStreamer(tokenizer, **decode_config) 34 | encoding = tokenizer(prompt, return_tensors="pt").to(device) 35 | input_ids = encoding.input_ids 36 | encoding["decoder_input_ids"] = encoding["input_ids"].clone() 37 | input_echo_len = len(input_ids) 38 | 39 | generation_config = GenerationConfig( 40 | max_new_tokens=max_new_tokens, 41 | do_sample=temperature >= 1e-5, 42 | temperature=temperature, 43 | repetition_penalty=repetition_penalty, 44 | no_repeat_ngram_size=10, 45 | top_p=top_p, 46 | top_k=top_k, 47 | eos_token_id=stop_token_ids, 48 | ) 49 | 50 | class CodeBlockStopper(StoppingCriteria): 51 | def __call__( 52 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 53 | ) -> bool: 54 | # Code-completion is open-end generation. 55 | # We check \n\n to stop at end of a code block. 56 | if list(input_ids[0][-2:]) == [628, 198]: 57 | return True 58 | return False 59 | 60 | gen_kwargs = dict( 61 | **encoding, 62 | streamer=streamer, 63 | generation_config=generation_config, 64 | stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), 65 | ) 66 | thread = Thread(target=model.generate, kwargs=gen_kwargs) 67 | thread.start() 68 | i = 0 69 | output = "" 70 | for new_text in streamer: 71 | i += 1 72 | output += new_text 73 | if i % stream_interval == 0 or i == max_new_tokens - 1: 74 | yield { 75 | "text": output, 76 | "usage": { 77 | "prompt_tokens": input_echo_len, 78 | "completion_tokens": i, 79 | "total_tokens": input_echo_len + i, 80 | }, 81 | "finish_reason": None, 82 | } 83 | if i >= max_new_tokens: 84 | break 85 | 86 | if i >= max_new_tokens: 87 | finish_reason = "length" 88 | else: 89 | finish_reason = "stop" 90 | 91 | yield { 92 | "text": output, 93 | "usage": { 94 | "prompt_tokens": input_echo_len, 95 | "completion_tokens": i, 96 | "total_tokens": input_echo_len + i, 97 | }, 98 | "finish_reason": finish_reason, 99 | } 100 | thread.join() 101 | 102 | # clean 103 | gc.collect() 104 | torch.cuda.empty_cache() 105 | if device == "xpu": 106 | torch.xpu.empty_cache() 107 | if device == "npu": 108 | torch.npu.empty_cache() 109 | -------------------------------------------------------------------------------- /llama/fashchat/model/model_exllama.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import sys 3 | from typing import Dict 4 | 5 | import torch 6 | 7 | 8 | def generate_stream_exllama( 9 | model, 10 | tokenizer, 11 | params: Dict, 12 | device: str, 13 | context_len: int, 14 | stream_interval: int = 2, 15 | judge_sent_end: bool = False, 16 | ): 17 | try: 18 | from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler 19 | except ImportError as e: 20 | print(f"Error: Failed to load Exllamav2. 
{e}") 21 | sys.exit(-1) 22 | 23 | prompt = params["prompt"] 24 | 25 | generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) 26 | settings = ExLlamaV2Sampler.Settings() 27 | 28 | settings.temperature = float(params.get("temperature", 0.85)) 29 | settings.top_k = int(params.get("top_k", 50)) 30 | settings.top_p = float(params.get("top_p", 0.8)) 31 | settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) 32 | settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) 33 | 34 | max_new_tokens = int(params.get("max_new_tokens", 256)) 35 | 36 | generator.set_stop_conditions(params.get("stop_token_ids", None) or []) 37 | echo = bool(params.get("echo", True)) 38 | 39 | input_ids = generator.tokenizer.encode(prompt) 40 | prompt_tokens = input_ids.shape[-1] 41 | generator.begin_stream(input_ids, settings) 42 | 43 | generated_tokens = 0 44 | if echo: 45 | output = prompt 46 | else: 47 | output = "" 48 | while True: 49 | chunk, eos, _ = generator.stream() 50 | output += chunk 51 | generated_tokens += 1 52 | if generated_tokens == max_new_tokens: 53 | finish_reason = "length" 54 | break 55 | elif eos: 56 | finish_reason = "length" 57 | break 58 | yield { 59 | "text": output, 60 | "usage": { 61 | "prompt_tokens": prompt_tokens, 62 | "completion_tokens": generated_tokens, 63 | "total_tokens": prompt_tokens + generated_tokens, 64 | }, 65 | "finish_reason": None, 66 | } 67 | 68 | yield { 69 | "text": output, 70 | "usage": { 71 | "prompt_tokens": prompt_tokens, 72 | "completion_tokens": generated_tokens, 73 | "total_tokens": prompt_tokens + generated_tokens, 74 | }, 75 | "finish_reason": finish_reason, 76 | } 77 | gc.collect() 78 | -------------------------------------------------------------------------------- /llama/fashchat/model/model_falcon.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | from typing import Iterable 4 | 5 | import torch 6 | import transformers 7 | from transformers import TextIteratorStreamer, GenerationConfig 8 | 9 | from fastchat.utils import is_partial_stop 10 | 11 | 12 | @torch.inference_mode() 13 | def generate_stream_falcon( 14 | model, 15 | tokenizer, 16 | params, 17 | device, 18 | context_len=2048, 19 | stream_interval=2, 20 | judge_sent_end=False, 21 | ): 22 | prompt = params["prompt"] 23 | len_prompt = len(prompt) 24 | temperature = float(params.get("temperature", 1.0)) 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 26 | top_p = float(params.get("top_p", 1.0)) 27 | top_k = int(params.get("top_k", 50)) # -1 means disable 28 | max_new_tokens = int(params.get("max_new_tokens", 256)) 29 | stop_str = params.get("stop", None) 30 | echo = bool(params.get("echo", True)) 31 | stop_token_ids = params.get("stop_token_ids", None) or [] 32 | stop_token_ids.append(tokenizer.eos_token_id) 33 | 34 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 35 | input_ids = inputs["input_ids"] 36 | attention_mask = inputs["attention_mask"] 37 | 38 | max_src_len = context_len - max_new_tokens - 8 39 | 40 | input_ids = input_ids[-max_src_len:] # truncate from the left 41 | attention_mask = attention_mask[-max_src_len:] # truncate from the left 42 | input_echo_len = len(input_ids) 43 | 44 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 45 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) 46 | 47 | generation_config = GenerationConfig( 48 | 
max_new_tokens=max_new_tokens, 49 | do_sample=temperature >= 1e-5, 50 | temperature=temperature, 51 | repetition_penalty=repetition_penalty, 52 | no_repeat_ngram_size=10, 53 | top_p=top_p, 54 | top_k=top_k, 55 | eos_token_id=stop_token_ids, 56 | ) 57 | 58 | generation_kwargs = dict( 59 | inputs=input_ids, 60 | attention_mask=attention_mask, 61 | streamer=streamer, 62 | generation_config=generation_config, 63 | ) 64 | 65 | thread = Thread(target=model.generate, kwargs=generation_kwargs) 66 | thread.start() 67 | 68 | if echo: 69 | # means keep the prompt 70 | output = prompt 71 | else: 72 | output = "" 73 | 74 | for i, new_text in enumerate(streamer): 75 | output += new_text 76 | if i % stream_interval == 0: 77 | if echo: 78 | rfind_start = len_prompt 79 | else: 80 | rfind_start = 0 81 | 82 | partially_stopped = False 83 | if stop_str: 84 | if isinstance(stop_str, str): 85 | pos = output.rfind(stop_str, rfind_start) 86 | if pos != -1: 87 | output = output[:pos] 88 | else: 89 | partially_stopped = is_partial_stop(output, stop_str) 90 | elif isinstance(stop_str, Iterable): 91 | for each_stop in stop_str: 92 | pos = output.rfind(each_stop, rfind_start) 93 | if pos != -1: 94 | output = output[:pos] 95 | break 96 | else: 97 | partially_stopped = is_partial_stop(output, each_stop) 98 | if partially_stopped: 99 | break 100 | else: 101 | raise ValueError("Invalid stop field type.") 102 | 103 | # prevent yielding partial stop sequence 104 | if not partially_stopped: 105 | yield { 106 | "text": output, 107 | "usage": { 108 | "prompt_tokens": input_echo_len, 109 | "completion_tokens": i, 110 | "total_tokens": input_echo_len + i, 111 | }, 112 | "finish_reason": None, 113 | } 114 | output = output.strip() 115 | 116 | # finish stream event, which contains finish reason 117 | if i == max_new_tokens - 1: 118 | finish_reason = "length" 119 | elif partially_stopped: 120 | finish_reason = None 121 | else: 122 | finish_reason = "stop" 123 | 124 | yield { 125 | "text": output, 126 | "usage": { 127 | "prompt_tokens": input_echo_len, 128 | "completion_tokens": i, 129 | "total_tokens": input_echo_len + i, 130 | }, 131 | "finish_reason": finish_reason, 132 | } 133 | 134 | # clean 135 | gc.collect() 136 | torch.cuda.empty_cache() 137 | if device == "xpu": 138 | torch.xpu.empty_cache() 139 | if device == "npu": 140 | torch.npu.empty_cache() 141 | -------------------------------------------------------------------------------- /llama/fashchat/model/model_xfastertransformer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from threading import Thread 3 | 4 | import torch 5 | from transformers import TextIteratorStreamer 6 | 7 | 8 | @torch.inference_mode() 9 | def generate_stream_xft( 10 | model, 11 | tokenizer, 12 | params, 13 | device, 14 | context_len=8192, 15 | stream_interval=2, 16 | judge_sent_end=False, 17 | ): 18 | prompt = params["prompt"] 19 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 20 | 21 | # unused now, and placehold for future. 
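# These sampling knobs stay commented out for now; generation below is configured
# from model.config (beam width, early stopping, num_return_sequences) rather than
# per-request sampling parameters.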
22 | # temperature = float(params.get("temperature", 1.0)) 23 | # top_p = float(params.get("top_p", 1.0)) 24 | 25 | max_new_tokens = int(params.get("max_new_tokens", 4096)) 26 | echo = params.get("echo", True) 27 | 28 | inputs = tokenizer( 29 | prompt, return_tensors="pt", padding=model.config.padding 30 | ).input_ids 31 | input_echo_len = len(inputs[0]) 32 | max_len = max_new_tokens + input_echo_len 33 | 34 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) 35 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) 36 | generation_kwargs = { 37 | "input_ids": inputs, 38 | "streamer": streamer, 39 | "max_length": max_len, 40 | "num_beams": model.config.beam_width, 41 | "length_penalty": repetition_penalty, 42 | "num_return_sequences": model.config.num_return_sequences, 43 | "early_stopping": model.config.early_stopping, 44 | "eos_token_id": model.config.eos_token_id, 45 | "pad_token_id": model.config.pad_token_id, 46 | } 47 | 48 | thread = Thread(target=model.model.generate, kwargs=generation_kwargs) 49 | thread.start() 50 | if echo: 51 | # means keep the prompt 52 | output = prompt 53 | else: 54 | output = "" 55 | i = 0 56 | for i, new_text in enumerate(streamer): 57 | output += new_text 58 | yield { 59 | "text": output, 60 | "usage": { 61 | "prompt_tokens": input_echo_len, 62 | "completion_tokens": i, 63 | "total_tokens": input_echo_len + i, 64 | }, 65 | "finish_reason": None, 66 | } 67 | output = output.strip() 68 | if i == max_new_tokens - 1: 69 | finish_reason = "length" 70 | else: 71 | finish_reason = "stop" 72 | yield { 73 | "text": output, 74 | "usage": { 75 | "prompt_tokens": input_echo_len, 76 | "completion_tokens": i, 77 | "total_tokens": input_echo_len + i, 78 | }, 79 | "finish_reason": finish_reason, 80 | } 81 | gc.collect() 82 | -------------------------------------------------------------------------------- /llama/fashchat/model/monkey_patch_non_inplace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monkey patch the llama implementation in the huggingface/transformers library. 3 | Avoid bugs in mps backend by not using in-place operations. 
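The patched forward below mirrors LlamaAttention.forward, but rotate_half clones
its slices so the rotary embedding avoids the in-place operations that misbehave
on the mps backend.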
4 | """ 5 | import math 6 | from typing import List, Optional, Tuple 7 | 8 | import torch 9 | from torch import nn 10 | import transformers 11 | 12 | 13 | def rotate_half(x): 14 | """Rotates half the hidden dims of the input.""" 15 | x1 = x[..., : x.shape[-1] // 2].clone() 16 | x2 = x[..., x.shape[-1] // 2 :].clone() 17 | return torch.cat((-x2, x1), dim=-1) 18 | 19 | 20 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 21 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] 22 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) 23 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 24 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 25 | q_embed = (q * cos) + (rotate_half(q) * sin) 26 | k_embed = (k * cos) + (rotate_half(k) * sin) 27 | return q_embed, k_embed 28 | 29 | 30 | def forward( 31 | self, 32 | hidden_states: torch.Tensor, 33 | attention_mask: Optional[torch.Tensor] = None, 34 | position_ids: Optional[torch.LongTensor] = None, 35 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 36 | output_attentions: bool = False, 37 | use_cache: bool = False, 38 | padding_mask: Optional[torch.LongTensor] = None, 39 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 40 | bsz, q_len, _ = hidden_states.size() 41 | 42 | query_states = ( 43 | self.q_proj(hidden_states) 44 | .view(bsz, q_len, self.num_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) 47 | key_states = ( 48 | self.k_proj(hidden_states) 49 | .view(bsz, q_len, self.num_heads, self.head_dim) 50 | .transpose(1, 2) 51 | ) 52 | value_states = ( 53 | self.v_proj(hidden_states) 54 | .view(bsz, q_len, self.num_heads, self.head_dim) 55 | .transpose(1, 2) 56 | ) 57 | 58 | kv_seq_len = key_states.shape[-2] 59 | if past_key_value is not None: 60 | kv_seq_len += past_key_value[0].shape[-2] 61 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 62 | query_states, key_states = apply_rotary_pos_emb( 63 | query_states, key_states, cos, sin, position_ids 64 | ) 65 | # [bsz, nh, t, hd] 66 | 67 | if past_key_value is not None: 68 | # reuse k, v, self_attention 69 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 70 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 71 | 72 | past_key_value = (key_states, value_states) if use_cache else None 73 | 74 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( 75 | self.head_dim 76 | ) 77 | 78 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 79 | raise ValueError( 80 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 81 | f" {attn_weights.size()}" 82 | ) 83 | 84 | if attention_mask is not None: 85 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 86 | raise ValueError( 87 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 88 | ) 89 | attn_weights = attn_weights + attention_mask 90 | attn_weights = torch.max( 91 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 92 | ) 93 | 94 | # upcast attention to fp32 95 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( 96 | query_states.dtype 97 | ) 98 | attn_output = torch.matmul(attn_weights, value_states) 99 | 100 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 101 | raise ValueError( 102 | f"`attn_output` should be of size {(bsz, self.num_heads, 
q_len, self.head_dim)}, but is" 103 | f" {attn_output.size()}" 104 | ) 105 | 106 | attn_output = attn_output.transpose(1, 2) 107 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 108 | 109 | attn_output = self.o_proj(attn_output) 110 | 111 | if not output_attentions: 112 | attn_weights = None 113 | 114 | return attn_output, attn_weights, past_key_value 115 | 116 | 117 | def replace_llama_attn_with_non_inplace_operations(): 118 | """Avoid bugs in mps backend by not using in-place operations.""" 119 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 120 | -------------------------------------------------------------------------------- /llama/fashchat/model/rwkv_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from types import SimpleNamespace 3 | import warnings 4 | 5 | import torch 6 | 7 | os.environ["RWKV_JIT_ON"] = "1" 8 | os.environ["RWKV_CUDA_ON"] = "1" 9 | 10 | from rwkv.model import RWKV 11 | from rwkv.utils import PIPELINE, PIPELINE_ARGS 12 | 13 | 14 | class RwkvModel: 15 | def __init__(self, model_path): 16 | warnings.warn( 17 | "Experimental support. Please use ChatRWKV if you want to chat with RWKV" 18 | ) 19 | self.config = SimpleNamespace(is_encoder_decoder=False) 20 | self.model = RWKV(model=model_path, strategy="cuda fp16") 21 | # two GPUs 22 | # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") 23 | 24 | self.tokenizer = None 25 | self.model_path = model_path 26 | 27 | def to(self, target): 28 | assert target == "cuda" 29 | 30 | def __call__(self, input_ids, use_cache, past_key_values=None): 31 | assert use_cache == True 32 | input_ids = input_ids[0].detach().cpu().numpy() 33 | # print(input_ids) 34 | logits, state = self.model.forward(input_ids, past_key_values) 35 | # print(logits) 36 | logits = logits.unsqueeze(0).unsqueeze(0) 37 | out = SimpleNamespace(logits=logits, past_key_values=state) 38 | return out 39 | 40 | def generate( 41 | self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 42 | ): 43 | # This function is used by fastchat.llm_judge. 44 | # Because RWKV does not support huggingface generation API, 45 | # we reuse fastchat.serve.inference.generate_stream as a workaround. 46 | from transformers import AutoTokenizer 47 | 48 | from fastchat.serve.inference import generate_stream 49 | from fastchat.conversation import get_conv_template 50 | 51 | if self.tokenizer is None: 52 | self.tokenizer = AutoTokenizer.from_pretrained( 53 | "EleutherAI/pythia-160m", use_fast=True 54 | ) 55 | prompt = self.tokenizer.decode(input_ids[0].tolist()) 56 | conv = get_conv_template("rwkv") 57 | 58 | gen_params = { 59 | "model": self.model_path, 60 | "prompt": prompt, 61 | "temperature": temperature, 62 | "repetition_penalty": repetition_penalty, 63 | "max_new_tokens": max_new_tokens, 64 | "stop": conv.stop_str, 65 | "stop_token_ids": conv.stop_token_ids, 66 | "echo": False, 67 | } 68 | res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") 69 | 70 | for res in res_iter: 71 | pass 72 | 73 | output = res["text"] 74 | output_ids = self.tokenizer.encode(output) 75 | 76 | return [input_ids[0].tolist() + output_ids] 77 | -------------------------------------------------------------------------------- /llama/fashchat/model/upload_hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Upload weights to huggingface. 
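The selected components are re-saved into a temporary directory with push_to_hub
enabled, so no converted copy is kept locally.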
3 | 4 | Usage: 5 | python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3 6 | """ 7 | import argparse 8 | import tempfile 9 | 10 | import torch 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | 13 | 14 | def upload_hub(model_path, hub_repo_id, component, private): 15 | if component == "all": 16 | components = ["model", "tokenizer"] 17 | else: 18 | components = [component] 19 | 20 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} 21 | 22 | if "model" in components: 23 | model = AutoModelForCausalLM.from_pretrained( 24 | model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 25 | ) 26 | with tempfile.TemporaryDirectory() as tmp_path: 27 | model.save_pretrained(tmp_path, **kwargs) 28 | 29 | if "tokenizer" in components: 30 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 31 | with tempfile.TemporaryDirectory() as tmp_path: 32 | tokenizer.save_pretrained(tmp_path, **kwargs) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--model-path", type=str, required=True) 38 | parser.add_argument("--hub-repo-id", type=str, required=True) 39 | parser.add_argument( 40 | "--component", type=str, choices=["all", "model", "tokenizer"], default="all" 41 | ) 42 | parser.add_argument("--private", action="store_true") 43 | args = parser.parse_args() 44 | 45 | upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) 46 | -------------------------------------------------------------------------------- /llama/fashchat/modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/llama/fashchat/modules/.DS_Store -------------------------------------------------------------------------------- /llama/fashchat/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/llama/fashchat/modules/__init__.py -------------------------------------------------------------------------------- /llama/fashchat/modules/awq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | import sys 4 | 5 | import torch 6 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils 7 | 8 | 9 | @dataclass 10 | class AWQConfig: 11 | ckpt: str = field( 12 | default=None, 13 | metadata={ 14 | "help": "Load quantized model. The path to the local AWQ checkpoint." 15 | }, 16 | ) 17 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 18 | groupsize: int = field( 19 | default=-1, 20 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 21 | ) 22 | 23 | 24 | def load_awq_quantized(model_name, awq_config: AWQConfig, device): 25 | print("Loading AWQ quantized model...") 26 | 27 | try: 28 | from tinychat.utils import load_quant 29 | from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp 30 | except ImportError as e: 31 | print(f"Error: Failed to import tinychat. 
{e}") 32 | print("Please double check if you have successfully installed AWQ") 33 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") 34 | sys.exit(-1) 35 | 36 | config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) 37 | tokenizer = AutoTokenizer.from_pretrained( 38 | model_name, use_fast=False, trust_remote_code=True 39 | ) 40 | 41 | def skip(*args, **kwargs): 42 | pass 43 | 44 | torch.nn.init.kaiming_uniform_ = skip 45 | torch.nn.init.kaiming_normal_ = skip 46 | torch.nn.init.uniform_ = skip 47 | torch.nn.init.normal_ = skip 48 | modeling_utils._init_weights = False 49 | 50 | torch.set_default_dtype(torch.half) 51 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) 52 | 53 | if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): 54 | model = load_quant.load_awq_llama_fast( 55 | model, 56 | find_awq_ckpt(awq_config), 57 | awq_config.wbits, 58 | awq_config.groupsize, 59 | device, 60 | ) 61 | make_quant_attn(model, device) 62 | make_quant_norm(model) 63 | make_fused_mlp(model) 64 | else: 65 | model = load_quant.load_awq_model( 66 | model, 67 | find_awq_ckpt(awq_config), 68 | awq_config.wbits, 69 | awq_config.groupsize, 70 | device, 71 | ) 72 | return model, tokenizer 73 | 74 | 75 | def find_awq_ckpt(awq_config: AWQConfig): 76 | if Path(awq_config.ckpt).is_file(): 77 | return awq_config.ckpt 78 | 79 | for ext in ["*.pt", "*.safetensors"]: 80 | matched_result = sorted(Path(awq_config.ckpt).glob(ext)) 81 | if len(matched_result) > 0: 82 | return str(matched_result[-1]) 83 | 84 | print("Error: AWQ checkpoint not found") 85 | sys.exit(1) 86 | -------------------------------------------------------------------------------- /llama/fashchat/modules/exllama.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import sys 3 | 4 | 5 | @dataclass 6 | class ExllamaConfig: 7 | max_seq_len: int 8 | gpu_split: str = None 9 | cache_8bit: bool = False 10 | 11 | 12 | class ExllamaModel: 13 | def __init__(self, exllama_model, exllama_cache): 14 | self.model = exllama_model 15 | self.cache = exllama_cache 16 | self.config = self.model.config 17 | 18 | 19 | def load_exllama_model(model_path, exllama_config: ExllamaConfig): 20 | try: 21 | from exllamav2 import ( 22 | ExLlamaV2Config, 23 | ExLlamaV2Tokenizer, 24 | ExLlamaV2, 25 | ExLlamaV2Cache, 26 | ExLlamaV2Cache_8bit, 27 | ) 28 | except ImportError as e: 29 | print(f"Error: Failed to load Exllamav2. 
{e}") 30 | sys.exit(-1) 31 | 32 | exllamav2_config = ExLlamaV2Config() 33 | exllamav2_config.model_dir = model_path 34 | exllamav2_config.prepare() 35 | exllamav2_config.max_seq_len = exllama_config.max_seq_len 36 | exllamav2_config.cache_8bit = exllama_config.cache_8bit 37 | 38 | exllama_model = ExLlamaV2(exllamav2_config) 39 | tokenizer = ExLlamaV2Tokenizer(exllamav2_config) 40 | 41 | split = None 42 | if exllama_config.gpu_split: 43 | split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] 44 | exllama_model.load(split) 45 | 46 | cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache 47 | exllama_cache = cache_class(exllama_model) 48 | model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) 49 | 50 | return model, tokenizer 51 | -------------------------------------------------------------------------------- /llama/fashchat/modules/gptq.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import os 3 | from os.path import isdir, isfile 4 | from pathlib import Path 5 | import sys 6 | 7 | from transformers import AutoTokenizer 8 | 9 | 10 | @dataclass 11 | class GptqConfig: 12 | ckpt: str = field( 13 | default=None, 14 | metadata={ 15 | "help": "Load quantized model. The path to the local GPTQ checkpoint." 16 | }, 17 | ) 18 | wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) 19 | groupsize: int = field( 20 | default=-1, 21 | metadata={"help": "Groupsize to use for quantization; default uses full row."}, 22 | ) 23 | act_order: bool = field( 24 | default=True, 25 | metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, 26 | ) 27 | 28 | 29 | def load_gptq_quantized(model_name, gptq_config: GptqConfig): 30 | print("Loading GPTQ quantized model...") 31 | 32 | try: 33 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 34 | module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") 35 | 36 | sys.path.insert(0, module_path) 37 | from llama import load_quant 38 | except ImportError as e: 39 | print(f"Error: Failed to load GPTQ-for-LLaMa. 
{e}") 40 | print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") 41 | sys.exit(-1) 42 | 43 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 44 | # only `fastest-inference-4bit` branch cares about `act_order` 45 | if gptq_config.act_order: 46 | model = load_quant( 47 | model_name, 48 | find_gptq_ckpt(gptq_config), 49 | gptq_config.wbits, 50 | gptq_config.groupsize, 51 | act_order=gptq_config.act_order, 52 | ) 53 | else: 54 | # other branches 55 | model = load_quant( 56 | model_name, 57 | find_gptq_ckpt(gptq_config), 58 | gptq_config.wbits, 59 | gptq_config.groupsize, 60 | ) 61 | 62 | return model, tokenizer 63 | 64 | 65 | def find_gptq_ckpt(gptq_config: GptqConfig): 66 | if Path(gptq_config.ckpt).is_file(): 67 | return gptq_config.ckpt 68 | 69 | for ext in ["*.pt", "*.safetensors"]: 70 | matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) 71 | if len(matched_result) > 0: 72 | return str(matched_result[-1]) 73 | 74 | print("Error: gptq checkpoint not found") 75 | sys.exit(1) 76 | -------------------------------------------------------------------------------- /llama/fashchat/modules/xfastertransformer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import sys 3 | 4 | 5 | @dataclass 6 | class XftConfig: 7 | max_seq_len: int = 4096 8 | beam_width: int = 1 9 | eos_token_id: int = -1 10 | pad_token_id: int = -1 11 | num_return_sequences: int = 1 12 | is_encoder_decoder: bool = False 13 | padding: bool = True 14 | early_stopping: bool = False 15 | data_type: str = "bf16_fp16" 16 | 17 | 18 | class XftModel: 19 | def __init__(self, xft_model, xft_config): 20 | self.model = xft_model 21 | self.config = xft_config 22 | 23 | 24 | def load_xft_model(model_path, xft_config: XftConfig): 25 | try: 26 | import xfastertransformer 27 | from transformers import AutoTokenizer 28 | except ImportError as e: 29 | print(f"Error: Failed to load xFasterTransformer. 
{e}") 30 | sys.exit(-1) 31 | 32 | if xft_config.data_type is None or xft_config.data_type == "": 33 | data_type = "bf16_fp16" 34 | else: 35 | data_type = xft_config.data_type 36 | tokenizer = AutoTokenizer.from_pretrained( 37 | model_path, use_fast=False, padding_side="left", trust_remote_code=True 38 | ) 39 | xft_model = xfastertransformer.AutoModel.from_pretrained( 40 | model_path, dtype=data_type 41 | ) 42 | model = XftModel(xft_model=xft_model, xft_config=xft_config) 43 | if model.model.rank > 0: 44 | while True: 45 | model.model.generate() 46 | return model, tokenizer 47 | -------------------------------------------------------------------------------- /llama/fashchat/train/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/llama/fashchat/train/.DS_Store -------------------------------------------------------------------------------- /llama/fashchat/train/llama2_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from flash_attn import __version__ as flash_attn_version 6 | from flash_attn.bert_padding import pad_input, unpad_input 7 | from flash_attn.flash_attn_interface import ( 8 | flash_attn_func, 9 | flash_attn_varlen_kvpacked_func, 10 | ) 11 | from transformers.models.llama.modeling_llama import ( 12 | LlamaAttention, 13 | LlamaModel, 14 | rotate_half, 15 | ) 16 | 17 | 18 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids): 19 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] 20 | gather_indices = gather_indices.repeat( 21 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] 22 | ) 23 | bsz = gather_indices.shape[0] 24 | cos, sin = ( 25 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) 26 | for x in cos_sin 27 | ) 28 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) 29 | return q, k 30 | 31 | 32 | def forward( 33 | self, 34 | hidden_states: torch.Tensor, 35 | attention_mask: Optional[torch.Tensor] = None, 36 | position_ids: Optional[torch.Tensor] = None, 37 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 38 | output_attentions: bool = False, 39 | use_cache: bool = False, 40 | padding_mask: Optional[torch.Tensor] = None, 41 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 42 | if output_attentions: 43 | warnings.warn( 44 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
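# flash-attn kernels never materialize the full attention matrix, so per-head
# attention weights cannot be returned from this patched forward.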
45 | ) 46 | 47 | bsz, q_len, _ = hidden_states.size() 48 | kv_heads = getattr(self, "num_key_value_heads", self.num_heads) 49 | 50 | q, k, v = ( 51 | op(hidden_states).view(bsz, q_len, nh, self.head_dim) 52 | for op, nh in ( 53 | (self.q_proj, self.num_heads), 54 | (self.k_proj, kv_heads), 55 | (self.v_proj, kv_heads), 56 | ) 57 | ) 58 | # shape: (b, s, num_heads, head_dim) 59 | 60 | kv_seq_len = k.shape[1] 61 | past_kv_len = 0 62 | if past_key_value is not None: 63 | past_kv_len = past_key_value[0].shape[2] 64 | kv_seq_len += past_kv_len 65 | 66 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) 67 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) 68 | 69 | if past_key_value is not None: 70 | assert ( 71 | flash_attn_version >= "2.1.0" 72 | ), "past_key_value support requires flash-attn >= 2.1.0" 73 | # reuse k, v 74 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) 75 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) 76 | 77 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None 78 | 79 | if attention_mask is None: 80 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( 81 | bsz, q_len, -1 82 | ) 83 | else: 84 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) 85 | # We can skip concat and call unpad twice but seems better to call unpad only once. 86 | kv, _, cu_k_lens, max_k = unpad_input( 87 | torch.stack((k, v), dim=2), attention_mask 88 | ) 89 | output_unpad = flash_attn_varlen_kvpacked_func( 90 | q, 91 | kv, 92 | cu_q_lens, 93 | cu_k_lens, 94 | max_s, 95 | max_k, 96 | 0.0, 97 | softmax_scale=None, 98 | causal=True, 99 | ) 100 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 101 | output = pad_input(output_unpad, indices, bsz, q_len) 102 | 103 | return self.o_proj(output), None, past_key_value 104 | 105 | 106 | # Disable the transformation of the attention mask in LlamaModel as flash attention 107 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward. 108 | def _prepare_decoder_attention_mask( 109 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 110 | ): 111 | # [bsz, seq_len] 112 | if past_key_values_length > 0 and attention_mask is not None: 113 | attention_mask = torch.cat( 114 | ( 115 | torch.full( 116 | (input_shape[0], past_key_values_length), 117 | True, 118 | dtype=attention_mask.dtype, 119 | device=attention_mask.device, 120 | ), 121 | attention_mask, 122 | ), 123 | dim=-1, 124 | ) 125 | 126 | if attention_mask is not None and torch.all(attention_mask): 127 | return None # This uses the faster call when training with full samples 128 | 129 | return attention_mask 130 | 131 | 132 | def replace_llama_attn_with_flash_attn(): 133 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 134 | if cuda_major < 8: 135 | warnings.warn( 136 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
137 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 138 | ) 139 | 140 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 141 | LlamaAttention.forward = forward 142 | 143 | 144 | def test(): 145 | from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward 146 | from transformers.models.llama.configuration_llama import LlamaConfig 147 | 148 | config = LlamaConfig( 149 | hidden_size=1024, 150 | intermediate_size=128, 151 | num_hidden_layers=1, 152 | num_attention_heads=8, 153 | max_position_embeddings=16, 154 | ) 155 | device = torch.device("cuda") 156 | model = LlamaModel(config) 157 | attn = LlamaAttention(config).to(device).half() 158 | bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings 159 | position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view( 160 | -1, seqlen 161 | ) 162 | 163 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 164 | for i in range(4): 165 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 166 | if i: 167 | mask[0, -i:] = False 168 | mask[1, :i] = False 169 | 170 | lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0) 171 | ref, _, _ = attn.forward( 172 | hidden, attention_mask=lmask, position_ids=position_ids 173 | ) 174 | 175 | fast, _, _ = fastchat_forward( 176 | attn, hidden, attention_mask=mask, position_ids=position_ids 177 | ) 178 | 179 | lmask = _prepare_decoder_attention_mask( 180 | model, mask, hidden.shape[:2], hidden, 0 181 | ) 182 | test, _, _ = forward( 183 | attn, hidden, attention_mask=lmask, position_ids=position_ids 184 | ) 185 | 186 | print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}") 187 | print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}") 188 | print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}") 189 | print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}") 190 | print(f"allclose(fast, test) = {torch.allclose(fast, test)}") 191 | 192 | with torch.no_grad(): 193 | # Also check that past_kv is handled properly 194 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 195 | part_len = seqlen // 4 196 | assert part_len * 4 == seqlen 197 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 198 | mask[0, -2:] = False 199 | lmask = _prepare_decoder_attention_mask( 200 | model, mask, hidden.shape[:2], hidden, 0 201 | ) 202 | oneshot, _, _ = forward( 203 | attn, hidden, attention_mask=lmask, position_ids=position_ids 204 | ) 205 | parts = [] 206 | past_kv, past_kv_len = None, 0 207 | for i in range(4): 208 | start = part_len * i 209 | end = start + part_len 210 | hidden_part = hidden[:, start:end, ...] 
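# Rebuild the padding mask for this chunk only, offset by the tokens already
# cached in past_kv, so the incremental decode can be compared with the one-shot
# pass above.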
211 | lmask = _prepare_decoder_attention_mask( 212 | model, 213 | mask[:, start:end], 214 | hidden_part.shape[:2], 215 | hidden_part, 216 | past_kv_len, 217 | ) 218 | part, _, past_kv = forward( 219 | attn, 220 | hidden_part.clone(), 221 | attention_mask=lmask, 222 | position_ids=position_ids[:, start:end], 223 | past_key_value=past_kv, 224 | use_cache=True, 225 | ) 226 | parts.append(part) 227 | past_kv_len = past_kv[0].shape[2] 228 | 229 | print( 230 | f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}" 231 | ) 232 | print( 233 | f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}" 234 | ) 235 | 236 | 237 | if __name__ == "__main__": 238 | test() 239 | -------------------------------------------------------------------------------- /llama/fashchat/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | from torch import nn 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 10 | from flash_attn.bert_padding import unpad_input, pad_input 11 | 12 | 13 | def forward( 14 | self, 15 | hidden_states: torch.Tensor, 16 | attention_mask: Optional[torch.Tensor] = None, 17 | position_ids: Optional[torch.Tensor] = None, 18 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 19 | output_attentions: bool = False, 20 | use_cache: bool = False, 21 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 22 | if output_attentions: 23 | warnings.warn( 24 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
25 | ) 26 | 27 | bsz, q_len, _ = hidden_states.size() 28 | 29 | query_states = ( 30 | self.q_proj(hidden_states) 31 | .view(bsz, q_len, self.num_heads, self.head_dim) 32 | .transpose(1, 2) 33 | ) 34 | key_states = ( 35 | self.k_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | value_states = ( 40 | self.v_proj(hidden_states) 41 | .view(bsz, q_len, self.num_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) # shape: (b, num_heads, s, head_dim) 44 | 45 | kv_seq_len = key_states.shape[-2] 46 | if past_key_value is not None: 47 | kv_seq_len += past_key_value[0].shape[-2] 48 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 49 | query_states, key_states = apply_rotary_pos_emb( 50 | query_states, key_states, cos, sin, position_ids 51 | ) 52 | 53 | if past_key_value is not None: 54 | # reuse k, v 55 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 56 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 57 | 58 | past_key_value = (key_states, value_states) if use_cache else None 59 | 60 | # Transform the data into the format required by flash attention 61 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 62 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 63 | key_padding_mask = attention_mask 64 | 65 | if key_padding_mask is None: 66 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 67 | cu_q_lens = torch.arange( 68 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 69 | ) 70 | max_s = q_len 71 | output = flash_attn_varlen_qkvpacked_func( 72 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 73 | ) 74 | output = output.view(bsz, q_len, -1) 75 | else: 76 | qkv = qkv.reshape(bsz, q_len, -1) 77 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 78 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 79 | output_unpad = flash_attn_varlen_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 83 | output = pad_input(output_unpad, indices, bsz, q_len) 84 | 85 | return self.o_proj(output), None, past_key_value 86 | 87 | 88 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 89 | # requires the attention mask to be the same as the key_padding_mask 90 | def _prepare_decoder_attention_mask( 91 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 92 | ): 93 | # [bsz, seq_len] 94 | return attention_mask 95 | 96 | 97 | def replace_llama_attn_with_flash_attn(): 98 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 99 | if cuda_major < 8: 100 | warnings.warn( 101 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
102 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 103 | ) 104 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 105 | _prepare_decoder_attention_mask 106 | ) 107 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 108 | -------------------------------------------------------------------------------- /llama/fashchat/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
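# An all-zeros mask means nothing is masked, so no attn_bias is needed; otherwise
# xformers' built-in LowerTriangularMask reproduces the causal mask.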
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /llama/fashchat/train/train_lora.py: -------------------------------------------------------------------------------- 1 | # Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> 2 | 3 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: 4 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
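# A fuller, illustrative invocation (placeholder paths and values; flags other
# than the --pre_process_func argument added below come from the Model/Data/
# Training argument dataclasses):
#   deepspeed train_lora.py \
#       --model_name_or_path <BASE_MODEL> \
#       --data_path <TRAIN_JSON> \
#       --output_dir <OUTPUT_DIR> \
#       --pre_process_func <PROCESS_FN_NAME> \
#       --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>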
17 | 18 | from dataclasses import dataclass, field 19 | import logging 20 | import pathlib 21 | import typing 22 | import os 23 | 24 | from deepspeed import zero 25 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 26 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 27 | import transformers 28 | from transformers import Trainer, BitsAndBytesConfig, deepspeed 29 | import torch 30 | 31 | from fastchat.train.train import ( 32 | DataArguments, 33 | ModelArguments, 34 | make_supervised_data_module, 35 | ) 36 | 37 | from fastchat.train.llama_flash_attn_monkey_patch import ( 38 | replace_llama_attn_with_flash_attn, 39 | ) 40 | 41 | 42 | @dataclass 43 | class TrainingArguments(transformers.TrainingArguments): 44 | cache_dir: typing.Optional[str] = field(default=None) 45 | optim: str = field(default="adamw_torch") 46 | model_max_length: int = field( 47 | default=512, 48 | metadata={ 49 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 50 | }, 51 | ) 52 | flash_attn: bool = False 53 | 54 | 55 | @dataclass 56 | class LoraArguments: 57 | lora_r: int = 8 58 | lora_alpha: int = 16 59 | lora_dropout: float = 0.05 60 | lora_target_modules: typing.List[str] = field( 61 | default_factory=lambda: ["q_proj", "v_proj"] 62 | ) 63 | lora_weight_path: str = "" 64 | lora_bias: str = "none" 65 | q_lora: bool = False 66 | 67 | 68 | def maybe_zero_3(param): 69 | if hasattr(param, "ds_id"): 70 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE 71 | with zero.GatheredParameters([param]): 72 | param = param.data.detach().cpu().clone() 73 | else: 74 | param = param.detach().cpu().clone() 75 | return param 76 | 77 | 78 | # Borrowed from peft.utils.get_peft_model_state_dict 79 | def get_peft_state_maybe_zero_3(named_params, bias): 80 | if bias == "none": 81 | to_return = {k: t for k, t in named_params if "lora_" in k} 82 | elif bias == "all": 83 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 84 | elif bias == "lora_only": 85 | to_return = {} 86 | maybe_lora_bias = {} 87 | lora_bias_names = set() 88 | for k, t in named_params: 89 | if "lora_" in k: 90 | to_return[k] = t 91 | bias_name = k.split("lora_")[0] + "bias" 92 | lora_bias_names.add(bias_name) 93 | elif "bias" in k: 94 | maybe_lora_bias[k] = t 95 | for k, t in maybe_lora_bias: 96 | if bias_name in lora_bias_names: 97 | to_return[bias_name] = t 98 | else: 99 | raise NotImplementedError 100 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} 101 | return to_return 102 | 103 | 104 | def train(): 105 | parser = transformers.HfArgumentParser( 106 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments) 107 | ) 108 | parser.add_argument( 109 | "--pre_process_func", type=str, required=True, help="how to process data" 110 | ) 111 | ( 112 | model_args, 113 | data_args, 114 | training_args, 115 | lora_args, 116 | process_args 117 | ) = parser.parse_args_into_dataclasses() 118 | 119 | if training_args.flash_attn: 120 | replace_llama_attn_with_flash_attn() 121 | 122 | device_map = None 123 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 124 | ddp = world_size != 1 125 | if lora_args.q_lora: 126 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None 127 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): 128 | logging.warning( 129 | "FSDP and ZeRO3 are both currently incompatible with QLoRA." 
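# The 4-bit quantized base weights used by QLoRA are loaded with a fixed device
# map and do not compose with FSDP or ZeRO-3 parameter sharding.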
130 | ) 131 | 132 | compute_dtype = ( 133 | torch.float16 134 | if training_args.fp16 135 | else (torch.bfloat16 if training_args.bf16 else torch.float32) 136 | ) 137 | 138 | model = transformers.AutoModelForCausalLM.from_pretrained( 139 | model_args.model_name_or_path, 140 | cache_dir=training_args.cache_dir, 141 | device_map=device_map, 142 | quantization_config=BitsAndBytesConfig( 143 | load_in_4bit=True, 144 | bnb_4bit_use_double_quant=True, 145 | bnb_4bit_quant_type="nf4", 146 | bnb_4bit_compute_dtype=compute_dtype, 147 | ) 148 | if lora_args.q_lora 149 | else None, 150 | ) 151 | lora_config = LoraConfig( 152 | r=lora_args.lora_r, 153 | lora_alpha=lora_args.lora_alpha, 154 | target_modules=lora_args.lora_target_modules, 155 | lora_dropout=lora_args.lora_dropout, 156 | bias=lora_args.lora_bias, 157 | task_type="CAUSAL_LM", 158 | ) 159 | 160 | if lora_args.q_lora: 161 | model = prepare_model_for_kbit_training( 162 | model, use_gradient_checkpointing=training_args.gradient_checkpointing 163 | ) 164 | if not ddp and torch.cuda.device_count() > 1: 165 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 166 | model.is_parallelizable = True 167 | model.model_parallel = True 168 | 169 | model = get_peft_model(model, lora_config) 170 | if training_args.flash_attn: 171 | for name, module in model.named_modules(): 172 | if "norm" in name: 173 | module = module.to(compute_dtype) 174 | if "lm_head" in name or "embed_tokens" in name: 175 | if hasattr(module, "weight"): 176 | module = module.to(compute_dtype) 177 | if training_args.deepspeed is not None and training_args.local_rank == 0: 178 | model.print_trainable_parameters() 179 | 180 | if training_args.gradient_checkpointing: 181 | model.enable_input_require_grads() 182 | 183 | tokenizer = transformers.AutoTokenizer.from_pretrained( 184 | model_args.model_name_or_path, 185 | cache_dir=training_args.cache_dir, 186 | model_max_length=training_args.model_max_length, 187 | padding_side="right", 188 | use_fast=False, 189 | ) 190 | tokenizer.pad_token = tokenizer.unk_token 191 | 192 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, process_args=process_args) 193 | trainer = Trainer( 194 | model=model, tokenizer=tokenizer, args=training_args, **data_module 195 | ) 196 | 197 | model.config.use_cache = False 198 | 199 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 200 | trainer.train(resume_from_checkpoint=True) 201 | else: 202 | trainer.train() 203 | trainer.save_state() 204 | 205 | # check if zero3 mode enabled 206 | if deepspeed.is_deepspeed_zero3_enabled(): 207 | # use deepspeed engine internal function to gather state dict 208 | # state_dict_zero3 contains whole parameters of base and lora adapters 209 | # we will not extract lora parameters since peft save_pretrained will do that 210 | # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 211 | # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 212 | state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() 213 | if training_args.local_rank == 0: 214 | state_dict = state_dict_zero3 215 | else: 216 | # in other mode we use original code from fastchat team, to make sure our change is minimum 217 | state_dict = get_peft_state_maybe_zero_3( 218 | model.named_parameters(), lora_args.lora_bias 219 | ) 220 | 221 | if training_args.local_rank == 0: 222 | 
model.save_pretrained(training_args.output_dir, state_dict=state_dict) 223 | 224 | 225 | if __name__ == "__main__": 226 | train() 227 | -------------------------------------------------------------------------------- /llama/fashchat/train/train_lora_concat.py: -------------------------------------------------------------------------------- 1 | # Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> 2 | 3 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: 4 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from dataclasses import dataclass, field 19 | import logging 20 | import pathlib 21 | import typing 22 | import os 23 | 24 | from deepspeed import zero 25 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 26 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 27 | import transformers 28 | from transformers import Trainer, BitsAndBytesConfig, deepspeed 29 | import torch 30 | 31 | from fastchat.train.train_concat import ( 32 | DataArguments, 33 | ModelArguments, 34 | make_supervised_data_module, 35 | ) 36 | 37 | from fastchat.train.llama_flash_attn_monkey_patch import ( 38 | replace_llama_attn_with_flash_attn, 39 | ) 40 | 41 | 42 | @dataclass 43 | class TrainingArguments(transformers.TrainingArguments): 44 | cache_dir: typing.Optional[str] = field(default=None) 45 | optim: str = field(default="adamw_torch") 46 | model_max_length: int = field( 47 | default=4096, 48 | metadata={ 49 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
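# Defaults to 4096 here (train_lora.py uses 512), presumably to leave room for
# the concatenated samples built by train_concat's data module.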
50 | }, 51 | ) 52 | flash_attn: bool = False 53 | 54 | 55 | @dataclass 56 | class LoraArguments: 57 | lora_r: int = 8 58 | lora_alpha: int = 16 59 | lora_dropout: float = 0.05 60 | lora_target_modules: typing.List[str] = field( 61 | default_factory=lambda: ["q_proj", "v_proj"] 62 | ) 63 | lora_weight_path: str = "" 64 | lora_bias: str = "none" 65 | q_lora: bool = False 66 | 67 | 68 | def maybe_zero_3(param): 69 | if hasattr(param, "ds_id"): 70 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE 71 | with zero.GatheredParameters([param]): 72 | param = param.data.detach().cpu().clone() 73 | else: 74 | param = param.detach().cpu().clone() 75 | return param 76 | 77 | 78 | # Borrowed from peft.utils.get_peft_model_state_dict 79 | def get_peft_state_maybe_zero_3(named_params, bias): 80 | if bias == "none": 81 | to_return = {k: t for k, t in named_params if "lora_" in k} 82 | elif bias == "all": 83 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 84 | elif bias == "lora_only": 85 | to_return = {} 86 | maybe_lora_bias = {} 87 | lora_bias_names = set() 88 | for k, t in named_params: 89 | if "lora_" in k: 90 | to_return[k] = t 91 | bias_name = k.split("lora_")[0] + "bias" 92 | lora_bias_names.add(bias_name) 93 | elif "bias" in k: 94 | maybe_lora_bias[k] = t 95 | for k, t in maybe_lora_bias: 96 | if bias_name in lora_bias_names: 97 | to_return[bias_name] = t 98 | else: 99 | raise NotImplementedError 100 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} 101 | return to_return 102 | 103 | 104 | def train(): 105 | parser = transformers.HfArgumentParser( 106 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments) 107 | ) 108 | parser.add_argument( 109 | "--pre_process_func", type=str, required=True, help="how to process data" 110 | ) 111 | ( 112 | model_args, 113 | data_args, 114 | training_args, 115 | lora_args, 116 | process_args 117 | ) = parser.parse_args_into_dataclasses() 118 | 119 | if training_args.flash_attn: 120 | replace_llama_attn_with_flash_attn() 121 | 122 | device_map = None 123 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 124 | ddp = world_size != 1 125 | if lora_args.q_lora: 126 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None 127 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): 128 | logging.warning( 129 | "FSDP and ZeRO3 are both currently incompatible with QLoRA." 
130 | ) 131 | 132 | compute_dtype = ( 133 | torch.float16 134 | if training_args.fp16 135 | else (torch.bfloat16 if training_args.bf16 else torch.float32) 136 | ) 137 | 138 | model = transformers.AutoModelForCausalLM.from_pretrained( 139 | model_args.model_name_or_path, 140 | cache_dir=training_args.cache_dir, 141 | device_map=device_map, 142 | quantization_config=BitsAndBytesConfig( 143 | load_in_4bit=True, 144 | bnb_4bit_use_double_quant=True, 145 | bnb_4bit_quant_type="nf4", 146 | bnb_4bit_compute_dtype=compute_dtype, 147 | ) 148 | if lora_args.q_lora 149 | else None, 150 | ) 151 | lora_config = LoraConfig( 152 | r=lora_args.lora_r, 153 | lora_alpha=lora_args.lora_alpha, 154 | target_modules=lora_args.lora_target_modules, 155 | lora_dropout=lora_args.lora_dropout, 156 | bias=lora_args.lora_bias, 157 | task_type="CAUSAL_LM", 158 | ) 159 | 160 | if lora_args.q_lora: 161 | model = prepare_model_for_kbit_training( 162 | model, use_gradient_checkpointing=training_args.gradient_checkpointing 163 | ) 164 | if not ddp and torch.cuda.device_count() > 1: 165 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 166 | model.is_parallelizable = True 167 | model.model_parallel = True 168 | 169 | model = get_peft_model(model, lora_config) 170 | if training_args.flash_attn: 171 | for name, module in model.named_modules(): 172 | if "norm" in name: 173 | module = module.to(compute_dtype) 174 | if "lm_head" in name or "embed_tokens" in name: 175 | if hasattr(module, "weight"): 176 | module = module.to(compute_dtype) 177 | if training_args.deepspeed is not None and training_args.local_rank == 0: 178 | model.print_trainable_parameters() 179 | 180 | if training_args.gradient_checkpointing: 181 | model.enable_input_require_grads() 182 | 183 | tokenizer = transformers.AutoTokenizer.from_pretrained( 184 | model_args.model_name_or_path, 185 | cache_dir=training_args.cache_dir, 186 | model_max_length=training_args.model_max_length, 187 | padding_side="right", 188 | use_fast=False, 189 | ) 190 | tokenizer.pad_token = tokenizer.unk_token 191 | 192 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, process_args=process_args) 193 | trainer = Trainer( 194 | model=model, tokenizer=tokenizer, args=training_args, **data_module 195 | ) 196 | 197 | model.config.use_cache = False 198 | 199 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 200 | trainer.train(resume_from_checkpoint=True) 201 | else: 202 | trainer.train() 203 | trainer.save_state() 204 | 205 | # check if zero3 mode enabled 206 | if deepspeed.is_deepspeed_zero3_enabled(): 207 | # use deepspeed engine internal function to gather state dict 208 | # state_dict_zero3 contains whole parameters of base and lora adapters 209 | # we will not extract lora parameters since peft save_pretrained will do that 210 | # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 211 | # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 212 | state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() 213 | if training_args.local_rank == 0: 214 | state_dict = state_dict_zero3 215 | else: 216 | # in other mode we use original code from fastchat team, to make sure our change is minimum 217 | state_dict = get_peft_state_maybe_zero_3( 218 | model.named_parameters(), lora_args.lora_bias 219 | ) 220 | 221 | if training_args.local_rank == 0: 222 | 
model.save_pretrained(training_args.output_dir, state_dict=state_dict) 223 | 224 | 225 | if __name__ == "__main__": 226 | train() 227 | -------------------------------------------------------------------------------- /llama/fashchat/train/train_lora_scratch.py: -------------------------------------------------------------------------------- 1 | # Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> 2 | 3 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: 4 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from dataclasses import dataclass, field 19 | import logging 20 | import pathlib 21 | import typing 22 | import os 23 | 24 | from deepspeed import zero 25 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 26 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 27 | import transformers 28 | from transformers import Trainer, BitsAndBytesConfig, deepspeed 29 | import torch 30 | 31 | from fastchat.train.train_scratch import ( 32 | DataArguments, 33 | ModelArguments, 34 | make_supervised_data_module, 35 | ) 36 | 37 | from fastchat.train.llama_flash_attn_monkey_patch import ( 38 | replace_llama_attn_with_flash_attn, 39 | ) 40 | 41 | 42 | @dataclass 43 | class TrainingArguments(transformers.TrainingArguments): 44 | cache_dir: typing.Optional[str] = field(default=None) 45 | optim: str = field(default="adamw_torch") 46 | model_max_length: int = field( 47 | default=512, 48 | metadata={ 49 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
50 | }, 51 | ) 52 | flash_attn: bool = False 53 | 54 | 55 | @dataclass 56 | class LoraArguments: 57 | lora_r: int = 8 58 | lora_alpha: int = 16 59 | lora_dropout: float = 0.05 60 | lora_target_modules: typing.List[str] = field( 61 | default_factory=lambda: ["q_proj", "v_proj"] 62 | ) 63 | lora_weight_path: str = "" 64 | lora_bias: str = "none" 65 | q_lora: bool = False 66 | 67 | 68 | def maybe_zero_3(param): 69 | if hasattr(param, "ds_id"): 70 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE 71 | with zero.GatheredParameters([param]): 72 | param = param.data.detach().cpu().clone() 73 | else: 74 | param = param.detach().cpu().clone() 75 | return param 76 | 77 | 78 | # Borrowed from peft.utils.get_peft_model_state_dict 79 | def get_peft_state_maybe_zero_3(named_params, bias): 80 | if bias == "none": 81 | to_return = {k: t for k, t in named_params if "lora_" in k} 82 | elif bias == "all": 83 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 84 | elif bias == "lora_only": 85 | to_return = {} 86 | maybe_lora_bias = {} 87 | lora_bias_names = set() 88 | for k, t in named_params: 89 | if "lora_" in k: 90 | to_return[k] = t 91 | bias_name = k.split("lora_")[0] + "bias" 92 | lora_bias_names.add(bias_name) 93 | elif "bias" in k: 94 | maybe_lora_bias[k] = t 95 | for k, t in maybe_lora_bias.items():  # keep only the bias terms that belong to LoRA-adapted modules 96 | if k in lora_bias_names: 97 | to_return[k] = t 98 | else: 99 | raise NotImplementedError 100 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} 101 | return to_return 102 | 103 | 104 | def train(): 105 | parser = transformers.HfArgumentParser( 106 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments) 107 | ) 108 | parser.add_argument( 109 | "--pre_process_func", type=str, required=True, help="how to process data" 110 | ) 111 | ( 112 | model_args, 113 | data_args, 114 | training_args, 115 | lora_args, 116 | process_args 117 | ) = parser.parse_args_into_dataclasses() 118 | 119 | if training_args.flash_attn: 120 | replace_llama_attn_with_flash_attn() 121 | 122 | device_map = None 123 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 124 | ddp = world_size != 1 125 | if lora_args.q_lora: 126 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None 127 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): 128 | logging.warning( 129 | "FSDP and ZeRO3 are both currently incompatible with QLoRA."
130 | ) 131 | 132 | compute_dtype = ( 133 | torch.float16 134 | if training_args.fp16 135 | else (torch.bfloat16 if training_args.bf16 else torch.float32) 136 | ) 137 | 138 | model = transformers.AutoModelForCausalLM.from_pretrained( 139 | model_args.model_name_or_path, 140 | cache_dir=training_args.cache_dir, 141 | device_map=device_map, 142 | quantization_config=BitsAndBytesConfig( 143 | load_in_4bit=True, 144 | bnb_4bit_use_double_quant=True, 145 | bnb_4bit_quant_type="nf4", 146 | bnb_4bit_compute_dtype=compute_dtype, 147 | ) 148 | if lora_args.q_lora 149 | else None, 150 | ) 151 | lora_config = LoraConfig( 152 | r=lora_args.lora_r, 153 | lora_alpha=lora_args.lora_alpha, 154 | target_modules=lora_args.lora_target_modules, 155 | lora_dropout=lora_args.lora_dropout, 156 | bias=lora_args.lora_bias, 157 | task_type="CAUSAL_LM", 158 | ) 159 | 160 | if lora_args.q_lora: 161 | model = prepare_model_for_kbit_training( 162 | model, use_gradient_checkpointing=training_args.gradient_checkpointing 163 | ) 164 | if not ddp and torch.cuda.device_count() > 1: 165 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 166 | model.is_parallelizable = True 167 | model.model_parallel = True 168 | 169 | model = get_peft_model(model, lora_config) 170 | if training_args.flash_attn: 171 | for name, module in model.named_modules(): 172 | if "norm" in name: 173 | module = module.to(compute_dtype) 174 | if "lm_head" in name or "embed_tokens" in name: 175 | if hasattr(module, "weight"): 176 | module = module.to(compute_dtype) 177 | if training_args.deepspeed is not None and training_args.local_rank == 0: 178 | model.print_trainable_parameters() 179 | 180 | if training_args.gradient_checkpointing: 181 | model.enable_input_require_grads() 182 | 183 | tokenizer = transformers.AutoTokenizer.from_pretrained( 184 | model_args.model_name_or_path, 185 | cache_dir=training_args.cache_dir, 186 | model_max_length=training_args.model_max_length, 187 | padding_side="right", 188 | use_fast=False, 189 | ) 190 | tokenizer.pad_token = tokenizer.unk_token 191 | 192 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, process_args=process_args) 193 | trainer = Trainer( 194 | model=model, tokenizer=tokenizer, args=training_args, **data_module 195 | ) 196 | 197 | model.config.use_cache = False 198 | 199 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 200 | trainer.train(resume_from_checkpoint=True) 201 | else: 202 | trainer.train() 203 | trainer.save_state() 204 | 205 | # check if zero3 mode enabled 206 | if deepspeed.is_deepspeed_zero3_enabled(): 207 | # use deepspeed engine internal function to gather state dict 208 | # state_dict_zero3 contains whole parameters of base and lora adapters 209 | # we will not extract lora parameters since peft save_pretrained will do that 210 | # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 211 | # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 212 | state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() 213 | if training_args.local_rank == 0: 214 | state_dict = state_dict_zero3 215 | else: 216 | # in other mode we use original code from fastchat team, to make sure our change is minimum 217 | state_dict = get_peft_state_maybe_zero_3( 218 | model.named_parameters(), lora_args.lora_bias 219 | ) 220 | 221 | if training_args.local_rank == 0: 222 | 
model.save_pretrained(training_args.output_dir, state_dict=state_dict) 223 | 224 | 225 | if __name__ == "__main__": 226 | train() 227 | -------------------------------------------------------------------------------- /llama/fashchat/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 2 | 3 | # Need to call this before importing transformers. 4 | from fastchat.train.llama2_flash_attn_monkey_patch import ( 5 | replace_llama_attn_with_flash_attn, 6 | ) 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from fastchat.train.train_from_scratch import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llama/fashchat/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities. 3 | """ 4 | from asyncio import AbstractEventLoop 5 | import json 6 | import logging 7 | import logging.handlers 8 | import os 9 | import platform 10 | import sys 11 | from typing import AsyncGenerator, Generator 12 | import warnings 13 | 14 | import requests 15 | 16 | from fastchat.constants import LOGDIR 17 | 18 | 19 | handler = None 20 | visited_loggers = set() 21 | 22 | 23 | def build_logger(logger_name, logger_filename): 24 | global handler 25 | 26 | formatter = logging.Formatter( 27 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 28 | datefmt="%Y-%m-%d %H:%M:%S", 29 | ) 30 | 31 | # Set the format of root handlers 32 | if not logging.getLogger().handlers: 33 | if sys.version_info[1] >= 9: 34 | # This is for windows 35 | logging.basicConfig(level=logging.INFO, encoding="utf-8") 36 | else: 37 | if platform.system() == "Windows": 38 | warnings.warn( 39 | "If you are running on Windows, " 40 | "we recommend you use Python >= 3.9 for UTF-8 encoding." 41 | ) 42 | logging.basicConfig(level=logging.INFO) 43 | logging.getLogger().handlers[0].setFormatter(formatter) 44 | 45 | # Redirect stdout and stderr to loggers 46 | stdout_logger = logging.getLogger("stdout") 47 | stdout_logger.setLevel(logging.INFO) 48 | sl = StreamToLogger(stdout_logger, logging.INFO) 49 | sys.stdout = sl 50 | 51 | stderr_logger = logging.getLogger("stderr") 52 | stderr_logger.setLevel(logging.ERROR) 53 | sl = StreamToLogger(stderr_logger, logging.ERROR) 54 | sys.stderr = sl 55 | 56 | # Get logger 57 | logger = logging.getLogger(logger_name) 58 | logger.setLevel(logging.INFO) 59 | 60 | # if LOGDIR is empty, then don't try output log to local file 61 | if LOGDIR != "": 62 | os.makedirs(LOGDIR, exist_ok=True) 63 | filename = os.path.join(LOGDIR, logger_filename) 64 | handler = logging.handlers.TimedRotatingFileHandler( 65 | filename, when="D", utc=True, encoding="utf-8" 66 | ) 67 | handler.setFormatter(formatter) 68 | 69 | for l in [stdout_logger, stderr_logger, logger]: 70 | if l in visited_loggers: 71 | continue 72 | visited_loggers.add(l) 73 | l.addHandler(handler) 74 | 75 | return logger 76 | 77 | 78 | class StreamToLogger(object): 79 | """ 80 | Fake file-like stream object that redirects writes to a logger instance. 
81 | """ 82 | 83 | def __init__(self, logger, log_level=logging.INFO): 84 | self.terminal = sys.stdout 85 | self.logger = logger 86 | self.log_level = log_level 87 | self.linebuf = "" 88 | 89 | def __getattr__(self, attr): 90 | return getattr(self.terminal, attr) 91 | 92 | def write(self, buf): 93 | temp_linebuf = self.linebuf + buf 94 | self.linebuf = "" 95 | for line in temp_linebuf.splitlines(True): 96 | # From the io.TextIOWrapper docs: 97 | # On output, if newline is None, any '\n' characters written 98 | # are translated to the system default line separator. 99 | # By default sys.stdout.write() expects '\n' newlines and then 100 | # translates them so this is still cross platform. 101 | if line[-1] == "\n": 102 | encoded_message = line.encode("utf-8", "ignore").decode("utf-8") 103 | self.logger.log(self.log_level, encoded_message.rstrip()) 104 | else: 105 | self.linebuf += line 106 | 107 | def flush(self): 108 | if self.linebuf != "": 109 | encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") 110 | self.logger.log(self.log_level, encoded_message.rstrip()) 111 | self.linebuf = "" 112 | 113 | 114 | def disable_torch_init(): 115 | """ 116 | Disable the redundant torch default initialization to accelerate model creation. 117 | """ 118 | import torch 119 | 120 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 121 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 122 | 123 | 124 | def get_gpu_memory(max_gpus=None): 125 | """Get available memory for each GPU.""" 126 | import torch 127 | 128 | gpu_memory = [] 129 | num_gpus = ( 130 | torch.cuda.device_count() 131 | if max_gpus is None 132 | else min(max_gpus, torch.cuda.device_count()) 133 | ) 134 | 135 | for gpu_id in range(num_gpus): 136 | with torch.cuda.device(gpu_id): 137 | device = torch.cuda.current_device() 138 | gpu_properties = torch.cuda.get_device_properties(device) 139 | total_memory = gpu_properties.total_memory / (1024**3) 140 | allocated_memory = torch.cuda.memory_allocated() / (1024**3) 141 | available_memory = total_memory - allocated_memory 142 | gpu_memory.append(available_memory) 143 | return gpu_memory 144 | 145 | 146 | def oai_moderation(text): 147 | """ 148 | Check whether the text violates OpenAI moderation API. 149 | """ 150 | import openai 151 | 152 | openai.api_base = "https://api.openai.com/v1" 153 | openai.api_key = os.environ["OPENAI_API_KEY"] 154 | openai.api_type = "open_ai" 155 | openai.api_version = None 156 | 157 | MAX_RETRY = 3 158 | for i in range(MAX_RETRY): 159 | try: 160 | res = openai.Moderation.create(input=text) 161 | flagged = res["results"][0]["flagged"] 162 | break 163 | except (openai.error.OpenAIError, KeyError, IndexError) as e: 164 | # flag true to be conservative 165 | flagged = True 166 | print(f"MODERATION ERROR: {e}\nInput: {text}") 167 | return flagged 168 | 169 | 170 | def moderation_filter(text, model_list): 171 | MODEL_KEYWORDS = ["claude"] 172 | 173 | for keyword in MODEL_KEYWORDS: 174 | for model in model_list: 175 | if keyword in model and oai_moderation(text): 176 | return True 177 | return False 178 | 179 | 180 | def clean_flant5_ckpt(ckpt_path): 181 | """ 182 | Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, 183 | Use this function to make sure it can be correctly loaded. 
184 | """ 185 | import torch 186 | 187 | index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") 188 | index_json = json.load(open(index_file, "r")) 189 | 190 | weightmap = index_json["weight_map"] 191 | 192 | share_weight_file = weightmap["shared.weight"] 193 | share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ 194 | "shared.weight" 195 | ] 196 | 197 | for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: 198 | weight_file = weightmap[weight_name] 199 | weight = torch.load(os.path.join(ckpt_path, weight_file)) 200 | weight[weight_name] = share_weight 201 | torch.save(weight, os.path.join(ckpt_path, weight_file)) 202 | 203 | 204 | def pretty_print_semaphore(semaphore): 205 | """Print a semaphore in better format.""" 206 | if semaphore is None: 207 | return "None" 208 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 209 | 210 | 211 | """A javascript function to get url parameters for the gradio web server.""" 212 | get_window_url_params_js = """ 213 | function() { 214 | const params = new URLSearchParams(window.location.search); 215 | url_params = Object.fromEntries(params); 216 | console.log("url_params", url_params); 217 | return url_params; 218 | } 219 | """ 220 | 221 | 222 | get_window_url_params_with_tos_js = """ 223 | function() { 224 | const params = new URLSearchParams(window.location.search); 225 | url_params = Object.fromEntries(params); 226 | console.log("url_params", url_params); 227 | 228 | msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nThe service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license." 
229 | alert(msg); 230 | 231 | return url_params; 232 | } 233 | """ 234 | 235 | 236 | def iter_over_async( 237 | async_gen: AsyncGenerator, event_loop: AbstractEventLoop 238 | ) -> Generator: 239 | """ 240 | Convert async generator to sync generator 241 | 242 | :param async_gen: the AsyncGenerator to convert 243 | :param event_loop: the event loop to run on 244 | :returns: Sync generator 245 | """ 246 | ait = async_gen.__aiter__() 247 | 248 | async def get_next(): 249 | try: 250 | obj = await ait.__anext__() 251 | return False, obj 252 | except StopAsyncIteration: 253 | return True, None 254 | 255 | while True: 256 | done, obj = event_loop.run_until_complete(get_next()) 257 | if done: 258 | break 259 | yield obj 260 | 261 | 262 | def detect_language(text: str) -> str: 263 | """Detect the langauge of a string.""" 264 | import polyglot # pip3 install polyglot pyicu pycld2 265 | from polyglot.detect import Detector 266 | from polyglot.detect.base import logger as polyglot_logger 267 | import pycld2 268 | 269 | polyglot_logger.setLevel("ERROR") 270 | 271 | try: 272 | lang_code = Detector(text).language.name 273 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 274 | lang_code = "unknown" 275 | return lang_code 276 | 277 | 278 | def parse_gradio_auth_creds(filename: str): 279 | """Parse a username:password file for gradio authorization.""" 280 | gradio_auth_creds = [] 281 | with open(filename, "r", encoding="utf8") as file: 282 | for line in file.readlines(): 283 | gradio_auth_creds += [x.strip() for x in line.split(",") if x.strip()] 284 | if gradio_auth_creds: 285 | auth = [tuple(cred.split(":")) for cred in gradio_auth_creds] 286 | else: 287 | auth = None 288 | return auth 289 | 290 | 291 | def is_partial_stop(output: str, stop_str: str): 292 | """Check whether the output contains a partial stop str.""" 293 | for i in range(0, min(len(output), len(stop_str))): 294 | if stop_str.startswith(output[-i:]): 295 | return True 296 | return False 297 | 298 | 299 | def run_cmd(cmd: str): 300 | """Run a bash command.""" 301 | print(cmd) 302 | return os.system(cmd) 303 | 304 | 305 | def is_sentence_complete(output: str): 306 | """Check whether the output is a complete sentence.""" 307 | end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”") 308 | return output.endswith(end_symbols) 309 | 310 | 311 | # Models don't use the same configuration key for determining the maximum 312 | # sequence length. Store them here so we can sanely check them. 313 | # NOTE: The ordering here is important. Some models have two of these and we 314 | # have a preference for which value gets used. 
315 | SEQUENCE_LENGTH_KEYS = [ 316 | "max_sequence_length", 317 | "seq_length", 318 | "max_position_embeddings", 319 | "max_seq_len", 320 | "model_max_length", 321 | ] 322 | 323 | 324 | def get_context_length(config): 325 | """Get the context length of a model from a huggingface model config.""" 326 | rope_scaling = getattr(config, "rope_scaling", None) 327 | if rope_scaling: 328 | rope_scaling_factor = config.rope_scaling["factor"] 329 | else: 330 | rope_scaling_factor = 1 331 | 332 | for key in SEQUENCE_LENGTH_KEYS: 333 | val = getattr(config, key, None) 334 | if val is not None: 335 | return int(rope_scaling_factor * val) 336 | return 2048 337 | 338 | 339 | def str_to_torch_dtype(dtype: str): 340 | import torch 341 | 342 | if dtype is None: 343 | return None 344 | elif dtype == "float32": 345 | return torch.float32 346 | elif dtype == "float16": 347 | return torch.float16 348 | elif dtype == "bfloat16": 349 | return torch.bfloat16 350 | else: 351 | raise ValueError(f"Unrecognized dtype: {dtype}") 352 | -------------------------------------------------------------------------------- /llama/playground/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/Case_or_Rule/1f96b052b6a83300220812eee8c4e43af6789489/llama/playground/.DS_Store -------------------------------------------------------------------------------- /llama/playground/deepspeed_config_s2.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "offload_optimizer": { 5 | "device": "cpu" 6 | }, 7 | "contiguous_gradients": true, 8 | "overlap_comm": true 9 | }, 10 | "fp16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "gradient_accumulation_steps": "auto" 15 | } -------------------------------------------------------------------------------- /llama/playground/deepspeed_config_s3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "zero_optimization": { 11 | "stage": 3, 12 | "offload_optimizer": { 13 | "device": "cpu", 14 | "pin_memory": true 15 | }, 16 | "offload_param": { 17 | "device": "cpu", 18 | "pin_memory": true 19 | }, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "stage3_max_live_parameters" : 1e9, 23 | "stage3_max_reuse_distance" : 1e9, 24 | "stage3_prefetch_bucket_size" : 5e8, 25 | "stage3_param_persistence_threshold" : 1e6, 26 | "sub_group_size" : 1e12, 27 | "stage3_gather_16bit_weights_on_model_save": true 28 | }, 29 | "train_batch_size": "auto", 30 | "train_micro_batch_size_per_gpu": "auto", 31 | "gradient_accumulation_steps": "auto" 32 | } -------------------------------------------------------------------------------- /llama/playground/test_embedding/README.md: -------------------------------------------------------------------------------- 1 | ## Machine Learning with Embeddings 2 | You can use embeddings to 3 | - Evaluate text similarity, see [test_sentence_similarity.py](test_sentence_similarity.py) 4 | - Build your own classifier, see [test_classification.py](test_classification.py) 5 | - Search relative texts, see [test_semantic_search.py](test_semantic_search.py) 6 | 7 | To these tests, you need to download the data 
[here](https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews). You also need an OpenAI API key for comparison. 8 | 9 | Run with: 10 | ```bash 11 | cd playground/test_embedding 12 | python3 test_classification.py 13 | ``` 14 | 15 | The script will train classifiers based on `vicuna-7b`, `text-similarity-ada-001` and `text-embedding-ada-002` and report the accuracy of each classifier. 16 | -------------------------------------------------------------------------------- /llama/playground/test_embedding/test_classification.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import openai 6 | import pandas as pd 7 | import requests 8 | from sklearn.ensemble import RandomForestClassifier 9 | from sklearn.model_selection import train_test_split 10 | from sklearn.metrics import classification_report, accuracy_score 11 | 12 | 13 | np.set_printoptions(threshold=10000) 14 | 15 | 16 | def get_embedding_from_api(word, model="vicuna-7b-v1.1"): 17 | if "ada" in model: 18 | resp = openai.Embedding.create( 19 | model=model, 20 | input=word, 21 | ) 22 | embedding = np.array(resp["data"][0]["embedding"]) 23 | return embedding 24 | 25 | url = "http://localhost:8000/v1/embeddings" 26 | headers = {"Content-Type": "application/json"} 27 | data = json.dumps({"model": model, "input": word}) 28 | 29 | response = requests.post(url, headers=headers, data=data) 30 | if response.status_code == 200: 31 | embedding = np.array(response.json()["data"][0]["embedding"]) 32 | return embedding 33 | else: 34 | print(f"Error: {response.status_code} - {response.text}") 35 | return None 36 | 37 | 38 | def create_embedding_data_frame(data_path, model, max_tokens=500): 39 | df = pd.read_csv(data_path, index_col=0) 40 | df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] 41 | df = df.dropna() 42 | df["combined"] = ( 43 | "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() 44 | ) 45 | top_n = 1000 46 | df = df.sort_values("Time").tail(top_n * 2) 47 | df.drop("Time", axis=1, inplace=True) 48 | 49 | df["n_tokens"] = df.combined.apply(lambda x: len(x)) 50 | df = df[df.n_tokens <= max_tokens].tail(top_n) 51 | df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) 52 | return df 53 | 54 | 55 | def train_random_forest(df): 56 | X_train, X_test, y_train, y_test = train_test_split( 57 | list(df.embedding.values), df.Score, test_size=0.2, random_state=42 58 | ) 59 | 60 | clf = RandomForestClassifier(n_estimators=100) 61 | clf.fit(X_train, y_train) 62 | preds = clf.predict(X_test) 63 | 64 | report = classification_report(y_test, preds) 65 | accuracy = accuracy_score(y_test, preds) 66 | return clf, accuracy, report 67 | 68 | 69 | input_datapath = "amazon_fine_food_review.csv" 70 | if not os.path.exists(input_datapath): 71 | raise Exception( 72 | f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" 73 | ) 74 | 75 | df = create_embedding_data_frame(input_datapath, "vicuna-7b-v1.1") 76 | clf, accuracy, report = train_random_forest(df) 77 | print(f"Vicuna-7b-v1.1 accuracy:{accuracy}") 78 | df = create_embedding_data_frame(input_datapath, "text-similarity-ada-001") 79 | clf, accuracy, report = train_random_forest(df) 80 | print(f"text-similarity-ada-001 accuracy:{accuracy}") 81 | df = create_embedding_data_frame(input_datapath, "text-embedding-ada-002") 82 | clf, accuracy, report = train_random_forest(df) 83 | print(f"text-embedding-ada-002 
accuracy:{accuracy}") 84 | -------------------------------------------------------------------------------- /llama/playground/test_embedding/test_semantic_search.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import openai 6 | import pandas as pd 7 | import requests 8 | from scipy.spatial.distance import cosine 9 | 10 | 11 | def cosine_similarity(vec1, vec2): 12 | try: 13 | return 1 - cosine(vec1, vec2) 14 | except: 15 | print(vec1.shape, vec2.shape) 16 | 17 | 18 | def get_embedding_from_api(word, model="vicuna-7b-v1.1"): 19 | if "ada" in model: 20 | resp = openai.Embedding.create( 21 | model=model, 22 | input=word, 23 | ) 24 | embedding = np.array(resp["data"][0]["embedding"]) 25 | return embedding 26 | 27 | url = "http://localhost:8000/v1/embeddings" 28 | headers = {"Content-Type": "application/json"} 29 | data = json.dumps({"model": model, "input": word}) 30 | 31 | response = requests.post(url, headers=headers, data=data) 32 | if response.status_code == 200: 33 | embedding = np.array(response.json()["data"][0]["embedding"]) 34 | return embedding 35 | else: 36 | print(f"Error: {response.status_code} - {response.text}") 37 | return None 38 | 39 | 40 | def create_embedding_data_frame(data_path, model, max_tokens=500): 41 | df = pd.read_csv(data_path, index_col=0) 42 | df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] 43 | df = df.dropna() 44 | df["combined"] = ( 45 | "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() 46 | ) 47 | top_n = 1000 48 | df = df.sort_values("Time").tail(top_n * 2) 49 | df.drop("Time", axis=1, inplace=True) 50 | 51 | df["n_tokens"] = df.combined.apply(lambda x: len(x)) 52 | df = df[df.n_tokens <= max_tokens].tail(top_n) 53 | df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) 54 | return df 55 | 56 | 57 | def search_reviews(df, product_description, n=3, pprint=False, model="vicuna-7b-v1.1"): 58 | product_embedding = get_embedding_from_api(product_description, model=model) 59 | df["similarity"] = df.embedding.apply( 60 | lambda x: cosine_similarity(x, product_embedding) 61 | ) 62 | 63 | results = ( 64 | df.sort_values("similarity", ascending=False) 65 | .head(n) 66 | .combined.str.replace("Title: ", "") 67 | .str.replace("; Content:", ": ") 68 | ) 69 | if pprint: 70 | for r in results: 71 | print(r[:200]) 72 | print() 73 | return results 74 | 75 | 76 | def print_model_search(input_path, model): 77 | print(f"Model: {model}") 78 | df = create_embedding_data_frame(input_path, model) 79 | print("search: delicious beans") 80 | results = search_reviews(df, "delicious beans", n=5, model=model) 81 | print(results) 82 | print("search: whole wheat pasta") 83 | results = search_reviews(df, "whole wheat pasta", n=5, model=model) 84 | print(results) 85 | print("search: bad delivery") 86 | results = search_reviews(df, "bad delivery", n=5, model=model) 87 | print(results) 88 | 89 | 90 | input_datapath = "amazon_fine_food_review.csv" 91 | if not os.path.exists(input_datapath): 92 | raise Exception( 93 | f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" 94 | ) 95 | 96 | 97 | print_model_search(input_datapath, "vicuna-7b-v1.1") 98 | print_model_search(input_datapath, "text-similarity-ada-001") 99 | print_model_search(input_datapath, "text-embedding-ada-002") 100 | -------------------------------------------------------------------------------- 
/llama/playground/test_embedding/test_sentence_similarity.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import openai 6 | import requests 7 | from scipy.spatial.distance import cosine 8 | 9 | 10 | def get_embedding_from_api(word, model="vicuna-7b-v1.5"): 11 | if "ada" in model: 12 | resp = openai.Embedding.create( 13 | model=model, 14 | input=word, 15 | ) 16 | embedding = np.array(resp["data"][0]["embedding"]) 17 | return embedding 18 | 19 | url = "http://localhost:8000/v1/embeddings" 20 | headers = {"Content-Type": "application/json"} 21 | data = json.dumps({"model": model, "input": word}) 22 | 23 | response = requests.post(url, headers=headers, data=data) 24 | if response.status_code == 200: 25 | embedding = np.array(response.json()["data"][0]["embedding"]) 26 | return embedding 27 | else: 28 | print(f"Error: {response.status_code} - {response.text}") 29 | return None 30 | 31 | 32 | def cosine_similarity(vec1, vec2): 33 | return 1 - cosine(vec1, vec2) 34 | 35 | 36 | def print_cosine_similarity(embeddings, texts): 37 | for i in range(len(texts)): 38 | for j in range(i + 1, len(texts)): 39 | sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]]) 40 | print(f"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}") 41 | 42 | 43 | texts = [ 44 | "The quick brown fox", 45 | "The quick brown dog", 46 | "The fast brown fox", 47 | "A completely different sentence", 48 | ] 49 | 50 | embeddings = {} 51 | for text in texts: 52 | embeddings[text] = get_embedding_from_api(text) 53 | 54 | print("Vicuna-7B:") 55 | print_cosine_similarity(embeddings, texts) 56 | 57 | for text in texts: 58 | embeddings[text] = get_embedding_from_api(text, model="text-similarity-ada-001") 59 | 60 | print("text-similarity-ada-001:") 61 | print_cosine_similarity(embeddings, texts) 62 | 63 | for text in texts: 64 | embeddings[text] = get_embedding_from_api(text, model="text-embedding-ada-002") 65 | 66 | print("text-embedding-ada-002:") 67 | print_cosine_similarity(embeddings, texts) 68 | -------------------------------------------------------------------------------- /llama/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "fschat" 7 | version = "0.2.34" 8 | description = "An open platform for training, serving, and evaluating large language model based chatbots." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "aiohttp", "fastapi", "httpx", "markdown2[all]", "nh3", "numpy", 17 | "prompt_toolkit>=3.0.0", "pydantic<2,>=1", "requests", "rich>=10.0.0", 18 | "shortuuid", "tiktoken", "uvicorn", 19 | ] 20 | 21 | [project.optional-dependencies] 22 | model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"] 23 | webui = ["gradio"] 24 | train = ["einops", "flash-attn>=2.0", "wandb"] 25 | llm_judge = ["openai<1", "anthropic>=0.3", "ray"] 26 | dev = ["black==23.3.0", "pylint==2.8.2"] 27 | 28 | [project.urls] 29 | "Homepage" = "https://github.com/lm-sys/fastchat" 30 | "Bug Tracker" = "https://github.com/lm-sys/fastchat/issues" 31 | 32 | [tool.setuptools.packages.find] 33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 34 | 35 | [tool.wheel] 36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 37 | -------------------------------------------------------------------------------- /llama/train.sh: -------------------------------------------------------------------------------- 1 | PATH_TO_DEEPSPEED_CONFIG=playground/deepspeed_config_s2.json 2 | deepspeed fastchat/train/train_lora.py \ 3 | --pre_process None \ 4 | --model_name_or_path MODEL_PATH \ 5 | --lora_r 8 \ 6 | --lora_alpha 16 \ 7 | --lora_dropout 0.05 \ 8 | --data_path $DATA_PATH \ 9 | --output_dir OUTPUT_DIR \ 10 | --num_train_epochs 1 \ 11 | --fp16 True \ 12 | --per_device_train_batch_size 1 \ 13 | --per_device_eval_batch_size 2 \ 14 | --gradient_accumulation_steps 1 \ 15 | --evaluation_strategy "steps" \ 16 | --eval_steps 1000 \ 17 | --eval_data_path $TEST_DATA_PATH \ 18 | --save_strategy "steps" \ 19 | --save_steps 625 \ 20 | --save_total_limit 5 \ 21 | --learning_rate 2e-5 \ 22 | --weight_decay 0. 
\ 23 | --warmup_ratio 0.03 \ 24 | --lr_scheduler_type "cosine" \ 25 | --logging_strategy "steps" \ 26 | --logging_steps 1 \ 27 | --tf32 True \ 28 | --model_max_length 4096 \ 29 | --q_lora False \ 30 | --deepspeed $PATH_TO_DEEPSPEED_CONFIG \ 31 | --gradient_checkpointing True \ 32 | --flash_attn True 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | interval==1.0.0 2 | matplotlib==3.8.2 3 | numpy==1.26.4 4 | pandas==2.2.0 5 | seaborn==0.13.2 6 | torch==2.1.2+cu118 7 | tqdm==4.63.0 8 | transformers==4.37.0 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer 4 | import transformers 5 | from transformers import AutoTokenizer 6 | from dataset import GPT2Dataset, GPT2DatasetReplace, TestDataset 7 | from tqdm import tqdm 8 | from torch.utils.tensorboard import SummaryWriter 9 | import time 10 | import re 11 | 12 | 13 | task = "addition" 14 | title = "1hole_(50, 50)_10_441_0-100" 15 | model_name = "gpt2" 16 | device = torch.device("cuda:0") 17 | print(f"running {task} - {title}...") 18 | 19 | save_model_path = f"save_model_{model_name}" 20 | 21 | # hyperparameters here 22 | log_step = 200 23 | num_epoch = 100 24 | batchsize = 30 25 | lr = 1e-4 26 | weight_decay= 0 27 | random_seed = 42 28 | torch.manual_seed(random_seed) 29 | import random 30 | random.seed(random_seed) 31 | 32 | load_checkpoint = None 33 | # tensorboard writer 34 | timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 35 | writer = SummaryWriter(log_dir='log/{}_{}_{}_{}'.format(model_name, task, title, timestamp)) 36 | 37 | # load pretrain model 38 | print(f"loading pretrained model: {model_name}...") 39 | model = GPT2LMHeadModel.from_pretrained(f"pretrained_models/{model_name}") 40 | tokenizer = GPT2Tokenizer.from_pretrained(f"pretrained_models/{model_name}") 41 | print("done") 42 | 43 | if task == "mod_addition" and "cot" in title: 44 | l = 100 45 | else: 46 | l = 50 47 | print(f"setting max length to {l}") 48 | train_dataset = GPT2Dataset(file_path='datasets/{}/{}/train.json'.format(task,title), max_length=l) 49 | valid_dataset = TestDataset(file_path='datasets/{}/{}/test.json'.format(task,title)) 50 | test_dataset = TestDataset(file_path='datasets/{}/{}/test.json'.format(task,title)) 51 | 52 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True) 53 | 54 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) 55 | num_training_steps = num_epoch * len(train_dataloader) 56 | num_warmup_steps = 0.01 * num_training_steps 57 | lr_scheduler = transformers.get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) 58 | 59 | progress_bar = tqdm(range(num_training_steps)) 60 | 61 | def extract_answer(s: str, mode=task): 62 | if mode in ["addition", "mod_addition", "base_addition", "linear_regression"]: 63 | return re.findall(r'[0-9]+\.?[0-9]*', s)[-1] 64 | 65 | elif mode in ["chickens_and_rabbits"]: 66 | return re.findall(r'[0-9]+\.?[0-9]*', s)[-2:] 67 | 68 | elif mode in ["addition_code"]: 69 | try: 70 | l = eval(re.findall(r'\[.+\]', s)[-1]) 71 | return l 72 | except: 73 | print(s) 74 | return None 75 | 76 | 
def valid_and_test(model, valid_dataset, test_dataset, device, step): 77 | with torch.no_grad(): 78 | model.eval() 79 | valid_correct = 0 80 | ctr = 0 81 | for valid_question, valid_answer in tqdm(valid_dataset): 82 | valid_answer = extract_answer(str(valid_answer)) 83 | outputs = model.generate(valid_question.to(device), max_length=l, num_beams=1, do_sample=False, pad_token_id=50257) # no padding, greedy decoding 84 | generated_answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] 85 | if ctr % 500 == 0: 86 | tqdm.write('-'*40) 87 | tqdm.write(generated_answer) 88 | tqdm.write('### The groundtruth is {}'.format(valid_answer)) 89 | generated_answer = generated_answer[len(tokenizer.decode(valid_question.squeeze())):] 90 | generated_answer = extract_answer(generated_answer) 91 | if generated_answer is None: 92 | continue 93 | if generated_answer == valid_answer: 94 | valid_correct += 1 95 | ctr += 1 96 | tqdm.write('valid accuracy: {}, num of valid samples: {}'.format(valid_correct/len(valid_dataset), len(valid_dataset))) # accuracy = correct predictions / number of valid samples 97 | writer.add_scalar('valid_accuracy', valid_correct/len(valid_dataset), step) 98 | 99 | # main loop 100 | step = 0 101 | for epoch in range(num_epoch): 102 | optimizer.zero_grad() 103 | model.to(device) 104 | model.train() 105 | for batch in train_dataloader: 106 | batch = batch[0] # the labels from the dataset are not used because they are identical to the input_ids 107 | labels = batch 108 | outputs = model(batch.to(device), labels=labels.to(device)) 109 | loss = outputs.loss 110 | loss.backward() 111 | 112 | optimizer.step() 113 | lr_scheduler.step() 114 | optimizer.zero_grad() 115 | progress_bar.update(1) 116 | 117 | writer.add_scalar('loss', loss.item(), progress_bar.n) 118 | writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], progress_bar.n) 119 | if progress_bar.n % 600 == 0: 120 | valid_and_test(model, valid_dataset, test_dataset, device, step=progress_bar.n) 121 | if step % log_step == 0: 122 | tqdm.write('epoch {}, step {}, loss {}, lr: {}'.format(epoch, progress_bar.n, loss.item(), lr_scheduler.get_last_lr()[0])) 123 | step += 1 124 | 125 | model.save_pretrained(os.path.join(save_model_path, task, title, f"model_{epoch}")) 126 | --------------------------------------------------------------------------------
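The evaluation loop in `train.py` above counts a generation as correct when the last number extracted from the model's output matches the last number extracted from the ground-truth answer. A minimal sketch of that extraction step for the addition-style tasks; the sample completions below are made-up illustrations, not strings from the datasets:

```python
import re

# Same pattern used by extract_answer() in train.py for the addition-style tasks:
# find every integer/decimal in the string and keep the last match.
pattern = r'[0-9]+\.?[0-9]*'

samples = [
    "23+58=81",                   # plain "a+b=c" completion
    "23+58=81 so the sum is 81",  # completion with extra trailing text
]
for s in samples:
    print(re.findall(pattern, s)[-1])  # prints "81" for both strings
```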