├── .gitattributes ├── .gitignore ├── .vsls.json ├── LICENSE ├── README.md ├── conftest.py ├── convert ├── bin_to_safetensors.py ├── convert_hf_to_lit.py ├── convert_lit_to_hf.py ├── convert_pretrained_checkpoint.py └── requirements.txt ├── dpo ├── README.md ├── __init__.py ├── adapters │ ├── __init__.py │ ├── base.py │ ├── distilabel_capybara_dpo.py │ ├── intel_orca_pair.py │ ├── nectar.py │ ├── open_hermes_preferences.py │ ├── ultrafeedback_binarized.py │ ├── ultrafeedback_binarized_test.jsonl │ └── ultrafeedback_binarized_test.py ├── ds_config_stage1.json ├── ds_config_stage2.json ├── ds_config_stage3.json ├── main.py ├── make_dataset.py └── requirements.txt ├── eval ├── README.md ├── eval.py ├── huggingface_llm_eval.sh └── huggingface_llm_eval_fast.sh ├── finetune ├── .DS_Store ├── __init__.py ├── custom_dataclasses.py ├── datacollators.py ├── full.py └── requirements.txt ├── generation_config ├── README.md ├── generation_config-16k.json └── generation_config.json ├── inference ├── README.md ├── generate.py ├── generate_hf.py └── generate_lit.py ├── lit_gpt ├── __init__.py ├── adapter.py ├── adapter_v2.py ├── args.py ├── config.py ├── config_base.py ├── datamodules │ ├── __init__.py │ ├── base.py │ ├── base_test.py │ ├── capybara.py │ ├── capybara_test.jsonl │ ├── capybara_test.py │ ├── create_test_jsonl.py │ ├── deita.py │ ├── deita_e2e_test.py │ ├── deita_test.jsonl │ ├── deita_test.py │ ├── llama_instruct.py │ ├── llama_instruct_test.jsonl │ ├── llama_instruct_test.py │ ├── meta_math_qa.py │ ├── meta_math_qa_e2e_test.py │ ├── meta_math_qa_test.jsonl │ ├── meta_math_qa_test.py │ ├── sft_dataset_base.py │ ├── sft_multiturn_dataset_base.py │ ├── sft_multiturn_dataset_base_test.py │ ├── slim-orca-idontknow_test.parquet │ ├── slim_orca_dedup.py │ ├── slim_orca_dedup_test.jsonl │ ├── slim_orca_dedup_test.py │ ├── slim_orca_idontknow.py │ ├── slim_orca_idontknow_e2e_test.py │ ├── slim_orca_idontknow_test.py │ ├── typings │ │ ├── base.py │ │ ├── formatted_dataset.py │ │ └── prepared_dataset.py │ ├── ultrachat_200k.py │ ├── ultrachat_200k_test.jsonl │ ├── ultrachat_200k_test.py │ ├── wizardlm_evol_instruct_v2.py │ ├── wizardlm_evol_instruct_v2_test.jsonl │ └── wizardlm_evol_instruct_v2_test.py ├── fused_cross_entropy.py ├── fused_rotary_embedding.py ├── generate │ ├── __init__.py │ └── base.py ├── lora.py ├── model.py ├── packed_dataset.py ├── prompts.py ├── prompts_test.py ├── rmsnorm.py ├── speed_monitor.py ├── tokenizer.py ├── utils.py └── utils_old.py ├── media ├── pints-1.5.png └── pints_ai-banner.png ├── prepare_dataset ├── __init__.py ├── prepare_dataset_test.py ├── preparer.py ├── shuffle.py └── standard_parquet.py ├── pretrain ├── __init__.py ├── custom_types │ ├── __init__.py │ ├── training_params.py │ └── training_state.py ├── main.py ├── main_save_checkpoint_test.py ├── main_train_test.py └── requirements.txt ├── pyproject.toml ├── requirements.test.txt ├── requirements.txt ├── tokenizer ├── README.md ├── convert │ └── convert.py ├── llama2 │ ├── tokenizer.json │ ├── tokenizer.model │ └── tokenizer_config.json ├── llama_vs_mistral │ ├── dataset_compression_cmp.py │ └── mistral_llama_tokenizer_cmp.py ├── mistral_v1 │ ├── config.json │ ├── generation_config.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer.model │ └── tokenizer_config.json ├── modifier │ ├── check.py │ ├── compare.py │ ├── investigate_tokenizer_oddity.py │ └── modify.py └── pints │ ├── added_tokens.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer.model │ └── 
tokenizer_config.json └── util ├── __init__.py └── splitter.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store/ 2 | __pycache__/ 3 | .conda/ 4 | .idea/workspace.xml 5 | .idea/ 6 | .ipynb_checkpoints/ 7 | 8 | consolidated_output/ 9 | data/ 10 | deepspeed_cache/ 11 | flash-attention/ 12 | lightning_logs/ 13 | logs/ 14 | models/ 15 | output/ 16 | temp/ 17 | wandb/ 18 | ~$* 19 | 20 | # Compressed dataset filetypes 21 | *.tar 22 | *.gz 23 | *.zip 24 | 25 | # EleutherAI evaluation harness 26 | lm-evaluation-harness/ 27 | lm_cache/ 28 | -------------------------------------------------------------------------------- /.vsls.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json.schemastore.org/vsls", 3 | "gitignore": "none", 4 | "excludeFiles": ["*.p12", "*.cer", "token", ".gitignore"], 5 | "hideFiles": ["bin", "obj"] 6 | } 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Pints.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # conftest.py 2 | import pytest 3 | from unittest.mock import MagicMock 4 | import sys 5 | import torch 6 | from typing import Optional 7 | 8 | # xformers.ops 9 | mock_swiGLU = MagicMock(name='SwiGLU') 10 | xformers_ops_mock = MagicMock(SwiGLU=mock_swiGLU) 11 | sys.modules['xformers.ops'] = xformers_ops_mock 12 | 13 | # FLASH ATTENTION 14 | # Without CUDA, all these needs to be mocked. 
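# (rotary_emb and xentropy_cuda_lib below are CUDA-only flash-attention extensions,
# see finetune/requirements.txt; mocking them lets lit_gpt be imported and the
# tests run on CPU-only machines.)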
15 | 16 | # lit_gpt/fused_rotary_embedding 17 | sys.modules['rotary_emb'] = MagicMock() 18 | sys.modules['einops'] = MagicMock() 19 | 20 | # lit_gpt/fused_cross_entropy 21 | sys.modules['xentropy_cuda_lib'] = MagicMock() 22 | 23 | 24 | class MockTokenizer: 25 | """A dummy tokenizer that encodes each character as its ASCII code.""" 26 | 27 | bos_id = 1 28 | eos_id = 2 29 | 30 | def encode( 31 | self, 32 | string: str, 33 | device: Optional[torch.device] = None, 34 | bos: bool = False, 35 | eos: bool = False, 36 | max_length=-1, 37 | ) -> torch.Tensor: 38 | tokens = [ord(c) for c in string] 39 | if bos: 40 | bos_id = self.bos_id 41 | if bos_id is None: 42 | raise NotImplementedError( 43 | 'Attempted to add bos, but this tokenizer does not defined a bos token.' 44 | ) 45 | tokens = [bos_id] + tokens 46 | 47 | if max_length > 0: 48 | if eos: 49 | if self.eos_id is None: 50 | raise NotImplementedError( 51 | 'Attempted to add eos, but this tokenizer does not defined an eos token' 52 | ) 53 | if len(tokens) >= max_length: 54 | tokens = tokens[:max_length] 55 | tokens[-1] = self.eos_id 56 | else: 57 | tokens = tokens + [self.eos_id] 58 | else: 59 | if eos: 60 | if self.eos_id is None: 61 | raise NotImplementedError( 62 | 'Attempted to add eos, but this tokenizer does not defined an eos token' 63 | ) 64 | tokens = tokens + [self.eos_id] 65 | 66 | return torch.tensor(tokens) 67 | 68 | def decode(self, tokens: torch.Tensor) -> str: 69 | decoded = '' 70 | for token in tokens.tolist(): 71 | if token == -100: 72 | decoded += '' 73 | continue 74 | 75 | if token == 1: 76 | decoded += '' 77 | continue 78 | 79 | if token == 2: 80 | decoded += '' 81 | continue 82 | 83 | decoded += chr(int(token)) 84 | 85 | return decoded 86 | 87 | # return ''.join(chr(int(t)) for t in tokens.tolist()) 88 | 89 | 90 | @pytest.fixture() 91 | def MOCK_TOKENIZER_FIXTURE(): 92 | return MockTokenizer() 93 | -------------------------------------------------------------------------------- /convert/convert_pretrained_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | # Support running without installing as a package 4 | # ruff: noqa: E402 5 | from pathlib import Path 6 | import sys 7 | 8 | wd = Path(__file__).parent.parent.resolve() 9 | sys.path.append(str(wd)) 10 | 11 | import torch 12 | from lit_gpt.utils import CLI, copy_config_files, incremental_save 13 | 14 | 15 | @torch.inference_mode() 16 | def convert_pretrained_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: 17 | """Convert a checkpoint after pretraining. 18 | 19 | The pretrained checkpoint contains optimizer states and several other metadata that are not needed after training 20 | is finished. This script will export the state-dict of the model and place it in the chosen output folder, 21 | which then can be loaded by other scripts for inference, evaluation, etc. 22 | 23 | Args: 24 | checkpoint_dir: Path to a checkpoint directory produced by ``litgpt.pretrain``. 25 | output_dir: The output folder where the converted state-dict file and config files will be saved to. 26 | """ 27 | 28 | if output_dir.is_dir() and output_dir.glob('*'): 29 | raise FileExistsError( 30 | f'The output folder exists and is not empty: {str(output_dir)}.' 31 | ' Please delete it first or choose a different name.' 
32 | ) 33 | 34 | output_dir.mkdir(parents=True) 35 | checkpoint_file = checkpoint_dir / 'lit_model.pth' 36 | output_checkpoint_file = output_dir / 'lit_model.pth' 37 | 38 | # TODO: Consolidate sharded checkpoint if applicable 39 | # Extract the model state dict and save to output folder 40 | with incremental_save(output_checkpoint_file) as saver: 41 | print('Processing', checkpoint_file) 42 | full_checkpoint = torch.load(str(checkpoint_file), mmap=True) 43 | loaded_state_dict = full_checkpoint['model'] 44 | converted_state_dict = {} 45 | for param_name, param in loaded_state_dict.items(): 46 | saver.store_early(param) 47 | # remove prefix for compiled model (if any) 48 | param_name = param_name.replace('_orig_mod.', '') 49 | converted_state_dict[param_name] = param 50 | print(f'Saving converted checkpoint to {str(output_checkpoint_file)}.') 51 | saver.save(converted_state_dict) 52 | 53 | copy_config_files(checkpoint_dir, output_dir) 54 | 55 | 56 | if __name__ == '__main__': 57 | CLI(convert_pretrained_checkpoint) 58 | -------------------------------------------------------------------------------- /convert/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | jsonargparse[signatures,url] 3 | safetensors -------------------------------------------------------------------------------- /dpo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/dpo/__init__.py -------------------------------------------------------------------------------- /dpo/adapters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/dpo/adapters/__init__.py -------------------------------------------------------------------------------- /dpo/adapters/base.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import cpu_count 2 | from typing import List, Literal, TypedDict 3 | 4 | # `cpu_count() - 2` is to keep 2 cpus for other uses (and 2 because usually 2 virtual cores is one phyiscal core) 5 | # But in non multi-core machines, we just use 1 6 | MAX_USABLE_CORES = min(cpu_count() - 2, 10) # 10 is the limit 7 | NUM_PROCESSES = max(MAX_USABLE_CORES, 1) 8 | 9 | Split = Literal['train', 'test'] 10 | 11 | 12 | class FormattedDatasetRow(TypedDict): 13 | prompt: str 14 | chosen: str 15 | rejected: str 16 | 17 | 18 | class DatasetAdapter: 19 | 20 | splits: List[Split] = [] 21 | 22 | def load_dataset(self, split: Split, **kwargs) -> List[FormattedDatasetRow]: 23 | """ 24 | Processes and returns the dataset. 
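        Implementations are expected to return rows shaped like ``FormattedDatasetRow``
        (values illustrative only): a ChatML-formatted `prompt` ending with
        "<|im_start|>assistant", plus `chosen`/`rejected` completions ending with "<|im_end|>".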
25 | """ 26 | raise NotImplementedError 27 | -------------------------------------------------------------------------------- /dpo/adapters/distilabel_capybara_dpo.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, TypedDict 2 | 3 | from datasets import Dataset 4 | from datasets import load_dataset as load 5 | 6 | from dpo.adapters.base import NUM_PROCESSES, DatasetAdapter, FormattedDatasetRow, Split 7 | 8 | HUGGINGFACE_ID = 'argilla/distilabel-capybara-dpo-7k-binarized' 9 | 10 | 11 | class DistilabelCapybaraDPOItem(TypedDict): 12 | content: str 13 | role: Literal['user', 'assistant'] 14 | 15 | 16 | class DistilabelCapybaraDPORow(TypedDict): 17 | source: str 18 | category: str 19 | prompt: str 20 | candidates_completions: List[str] 21 | candidate_policies: List[str] 22 | ranks: List[int] 23 | rank_str: str 24 | chosen_policy: str 25 | chosen: List[DistilabelCapybaraDPOItem] 26 | rejected_policy: str 27 | rejected: List[DistilabelCapybaraDPOItem] 28 | 29 | 30 | class DistilabelCapybaraDPO(DatasetAdapter): 31 | 32 | splits: List[Split] = ['train'] 33 | 34 | # TODO: Refactor prompt_template, use tokenizer.apply_chat_template in main.py instead. 35 | def load_dataset(self, split: Split, prompt_template='chatml') -> List[FormattedDatasetRow]: 36 | """ 37 | Args: 38 | prompt_template: The prompt template you want to use. 39 | """ 40 | 41 | # Load dataset from Huggingface 42 | dataset = load(HUGGINGFACE_ID, split='train') 43 | 44 | # Format rows 45 | formatted_dataset = format_dataset(dataset, prompt_template) 46 | 47 | return formatted_dataset 48 | 49 | 50 | def format_dataset(dataset: Dataset, prompt_template: str): 51 | # Print first 10 rows just to check 52 | test_rows = dataset.select(range(0, 10)) 53 | test_rows = test_rows.map( 54 | format_rows(prompt_template), num_proc=NUM_PROCESSES 55 | ) 56 | print_test_rows(test_rows) 57 | 58 | print('Dataset length:', len(dataset)) 59 | # Format Dataset 60 | formatted_dataset = dataset.map( 61 | format_rows(prompt_template), num_proc=NUM_PROCESSES 62 | ) 63 | 64 | return formatted_dataset 65 | 66 | 67 | def format_rows(prompt_template='chatml') -> FormattedDatasetRow: 68 | if prompt_template != 'chatml': 69 | raise ValueError(f'prompt_template "{prompt_template}" not supported.') 70 | 71 | def _format_row(row: DistilabelCapybaraDPORow) -> FormattedDatasetRow: 72 | assert row['chosen'][0]['content'] == row['rejected'][0]['content'] 73 | 74 | formatted_prompt = f"""<|im_start|>user 75 | {row['chosen'][0]['content']}<|im_end|> 76 | <|im_start|>assistant 77 | """ 78 | 79 | formatted_chosen = f"{row['chosen'][1]['content']}<|im_end|>" 80 | formatted_rejected = f"{row['rejected'][1]['content']}<|im_end|>" 81 | 82 | return { 83 | 'prompt': formatted_prompt, 84 | 'chosen': formatted_chosen, 85 | 'rejected': formatted_rejected, 86 | } 87 | 88 | return _format_row 89 | 90 | 91 | def print_test_rows(row: DistilabelCapybaraDPORow) -> None: 92 | print('=' * 80) 93 | print(row['prompt']) 94 | print('=' * 80) 95 | -------------------------------------------------------------------------------- /dpo/adapters/intel_orca_pair.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from datasets import load_dataset as load 4 | 5 | from dpo.adapters.base import DatasetAdapter, FormattedDatasetRow, Split 6 | 7 | # TODO: Not ready for use, need to add in chat template formatting 8 | 9 | HUGGINGFACE_ID = 'Intel/orca_dpo_pairs' 10 | 11 | 
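# Sketch of the ChatML formatting that the TODO above refers to. This is an
# illustrative assumption modelled on the sibling adapters (nectar.py,
# ultrafeedback_binarized.py) and is not wired into the class below.
def _format_row_chatml(row: dict) -> FormattedDatasetRow:
    formatted_prompt = (
        '<|im_start|>user\n'
        f"{row['question']}<|im_end|>\n"
        '<|im_start|>assistant\n'
    )
    return {
        'prompt': formatted_prompt,
        'chosen': f"{row['chosen']}<|im_end|>",
        'rejected': f"{row['rejected']}<|im_end|>",
    }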
12 | class OrcaDPOPairs(DatasetAdapter): 13 | 14 | splits: List[Split] = ['train'] 15 | 16 | def load_dataset(self, split: Split) -> List[FormattedDatasetRow]: 17 | # Load huggingface dataset 18 | dataset = load(HUGGINGFACE_ID, split='train') 19 | 20 | # Remove unused column 21 | dataset = dataset.remove_columns(['system']) 22 | 23 | # Rename columns to names required by dpo trainer 24 | dataset.rename_column('question', 'prompt') 25 | 26 | return dataset 27 | -------------------------------------------------------------------------------- /dpo/adapters/nectar.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, TypedDict 3 | 4 | import datasets 5 | from datasets import load_dataset as load 6 | 7 | from dpo.adapters.base import NUM_PROCESSES, DatasetAdapter, FormattedDatasetRow, Split 8 | 9 | HUGGINGFACE_ID = 'berkeley-nest/Nectar' 10 | SOURCES_TO_SKIP = ['anthropic-hh'] # Skip this as we want the model to be uncensored 11 | 12 | class NectarAnswer(TypedDict): 13 | answer: str 14 | model: str 15 | rank: int 16 | 17 | 18 | class NectarRow(TypedDict): 19 | prompt: str 20 | answers: List[NectarAnswer] 21 | source: List[str] 22 | 23 | 24 | class Nectar(DatasetAdapter): 25 | 26 | splits: List[Split] = ['train'] 27 | 28 | def __init__(self, seed: int = None): 29 | self.seed = seed 30 | 31 | def load_dataset(self, split: Split, prompt_template='chatml') -> datasets.Dataset: 32 | # Load huggingface dataset 33 | dataset = load(HUGGINGFACE_ID, split='train', num_proc=NUM_PROCESSES) 34 | 35 | # Drop Unused Columns 36 | dataset = dataset.remove_columns( 37 | ['turns', 'num_responses', 'good_natured'] 38 | ) 39 | 40 | # Print first 10 rows 41 | test_rows = dataset.select(range(0, 10)) 42 | test_rows.map(format_rows_print, num_proc=NUM_PROCESSES) 43 | 44 | filtered_dataset = dataset.filter(filter_rows) 45 | 46 | # Format Dataset 47 | random.seed(self.seed) 48 | formattedDataset = filtered_dataset.map( 49 | format_rows(prompt_template), 50 | num_proc=NUM_PROCESSES 51 | ) 52 | 53 | return formattedDataset 54 | 55 | 56 | 57 | def filter_rows(row: NectarRow): 58 | # Remove smaples from the blacklisted sources 59 | for source in SOURCES_TO_SKIP: 60 | if source in row['source']: 61 | return False 62 | 63 | return True 64 | 65 | 66 | def format_rows(prompt_template='chatml'): 67 | if prompt_template == 'chatml': 68 | def _format_rows(row: NectarRow) -> FormattedDatasetRow: 69 | chosen = row['answers'][0]['answer'] 70 | # chosen = choose_top_answer( 71 | # answers=row['answers'], 72 | # filter_out_gpt=False, 73 | # ) 74 | 75 | # https://arxiv.org/pdf/2310.16944 76 | # We construct binary preferences from UltraFeedback by selecting the 77 | # highest mean score as the “chosen” response and one of the remaining three at random as 78 | # “rejected”. We opted for random selection instead of selecting the lowest-scored response 79 | # to encourage diversity and make the DPO objective more challenging. As noted above, this 80 | # step is computed offline and does not involve any sampling from the reference model. 81 | rejected_pool = row['answers'][1:] 82 | rejected = random.choice(rejected_pool)['answer'] # Randomize the rejected answer 83 | 84 | # The DPOTrainer tokenization will not add a stop token. So we have to add it here ourselves. 85 | # and have the model learn to output the end token. 
86 | return { 87 | 'chosen': f'{chosen}<|im_end|>', 88 | 'rejected': f'{rejected}<|im_end|>', 89 | 'prompt': format_prompt(row['prompt'], 'chatml'), 90 | } 91 | 92 | return _format_rows 93 | 94 | raise NotImplementedError(f'Prompt template [{prompt_template}] not implemented.') 95 | 96 | 97 | def choose_top_answer(answers: List[NectarAnswer], filter_out_gpt: bool) -> str: 98 | # TODO: Due to implementation of random rejected answer, this needs to be refactored and cannot be used. 99 | # if filter_out_gpt: 100 | # # TODO: Find a way to filter out chat-gpt answers as there are rows where there are no non-chatGPT answers 101 | # # If we want to filter out gpt, we will choose the top-ranked that is not from GPT model 102 | # for answer in answers: 103 | # if 'gpt' not in answer['model']: 104 | # return answer['answer'] 105 | # raise NotImplementedError( 106 | # 'Have yet to implement a way to filter out rows where there are no non-chatGPT answers' 107 | # ) 108 | 109 | # If we want are not filtering out gpt, then we will just choose the first answer (since that will be the top-ranked one) 110 | return answers[0]['answer'] 111 | 112 | 113 | def format_prompt(prompt: str, promptFormat: str = 'chatml') -> str: 114 | assert promptFormat == 'chatml', \ 115 | f'Error, {promptFormat} has not been implemented.' 116 | 117 | # Remove spaces and newlines at the start and end 118 | prompt = prompt.strip() 119 | 120 | # Human 121 | assert prompt.startswith('Human: ') 122 | prompt = prompt.removeprefix('Human: ') 123 | prompt = '<|im_start|>user\n' + prompt 124 | 125 | # Assistant 126 | assert prompt.endswith('Assistant:') 127 | prompt = prompt.removesuffix('Assistant:') 128 | prompt = prompt.strip() 129 | 130 | prompt += '<|im_end|>\n<|im_start|>assistant\n' 131 | 132 | return prompt 133 | 134 | 135 | def format_rows_print(row: NectarRow) -> None: 136 | return 137 | to_print = format_rows()(row) 138 | print('=' * 80) 139 | print(to_print['prompt']) 140 | print('=' * 80) 141 | -------------------------------------------------------------------------------- /dpo/adapters/open_hermes_preferences.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, TypedDict 2 | 3 | from datasets import Dataset 4 | from datasets import load_dataset as load 5 | 6 | from dpo.adapters.base import NUM_PROCESSES, DatasetAdapter, FormattedDatasetRow, Split 7 | 8 | HUGGINGFACE_ID = 'argilla/OpenHermesPreferences' 9 | 10 | 11 | class OpenHermesItem(TypedDict): 12 | content: str 13 | role: Literal['user', 'assistant'] 14 | 15 | 16 | class OpenHermesRow(TypedDict): 17 | source: str 18 | category: str 19 | prompt: str 20 | candidates_completions: List[str] 21 | candidate_policies: List[str] 22 | ranks: List[int] 23 | rank_str: str 24 | chosen_policy: str 25 | chosen: List[OpenHermesItem] 26 | rejected_policy: str 27 | rejected: List[OpenHermesItem] 28 | 29 | 30 | class OpenHermesPreferences(DatasetAdapter): 31 | 32 | splits: List[Split] = ['train'] 33 | 34 | # TODO: Refactor prompt_template, use tokenizer.apply_chat_template in main.py instead. 35 | def load_dataset(self, split: Split, prompt_template='chatml') -> List[FormattedDatasetRow]: 36 | """ 37 | Args: 38 | prompt_template: The prompt template you want to use. 
39 | """ 40 | 41 | # Load dataset from Huggingface 42 | dataset = load(HUGGINGFACE_ID, split='train') 43 | 44 | # Format rows 45 | formatted_dataset = format_dataset(dataset, prompt_template) 46 | 47 | return formatted_dataset 48 | 49 | 50 | def format_dataset(dataset: Dataset, prompt_template: str): 51 | # Print first 10 rows just to check 52 | test_rows = dataset.select(range(0, 10)) 53 | test_rows = test_rows.map( 54 | format_rows(prompt_template), num_proc=NUM_PROCESSES 55 | ) 56 | print_test_rows(test_rows) 57 | 58 | print('Dataset length:', len(dataset)) 59 | # Format Dataset 60 | formatted_dataset = dataset.map( 61 | format_rows(prompt_template), num_proc=NUM_PROCESSES 62 | ) 63 | 64 | return formatted_dataset 65 | 66 | 67 | def format_rows(prompt_template='chatml') -> FormattedDatasetRow: 68 | if prompt_template != 'chatml': 69 | raise ValueError(f'prompt_template "{prompt_template}" not supported.') 70 | 71 | def _format_row(row: OpenHermesRow) -> FormattedDatasetRow: 72 | assert row['chosen'][0]['content'] == row['rejected'][0]['content'] 73 | 74 | formatted_prompt = f"""<|im_start|>user 75 | {row['chosen'][0]['content']}<|im_end|> 76 | <|im_start|>assistant 77 | """ 78 | 79 | formatted_chosen = f"{row['chosen'][1]['content']}<|im_end|>" 80 | formatted_rejected = f"{row['rejected'][1]['content']}<|im_end|>" 81 | 82 | return { 83 | 'prompt': formatted_prompt, 84 | 'chosen': formatted_chosen, 85 | 'rejected': formatted_rejected, 86 | } 87 | 88 | return _format_row 89 | 90 | 91 | def print_test_rows(row: OpenHermesRow) -> None: 92 | print('=' * 80) 93 | print(row['prompt']) 94 | print('=' * 80) 95 | -------------------------------------------------------------------------------- /dpo/adapters/ultrafeedback_binarized.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, TypedDict 2 | 3 | from datasets import Dataset 4 | from datasets import load_dataset as load 5 | 6 | from dpo.adapters.base import NUM_PROCESSES, DatasetAdapter, FormattedDatasetRow, Split 7 | 8 | HUGGINGFACE_ID = 'HuggingFaceH4/ultrafeedback_binarized' 9 | 10 | # This is the distance between the DPO pairs before we use it. If the scores does not have a wide distance, 11 | # it may not be useful for the model to learn, given both are just as good, or just as poor 12 | SCORE_DISTANCE = 2.0 13 | 14 | 15 | class UltraFeedbackItem(TypedDict): 16 | content: str 17 | role: Literal['user', 'assistant'] 18 | 19 | 20 | class UltraFeedbackRow(TypedDict): 21 | prompt_id: str 22 | prompt: str 23 | chosen: List[UltraFeedbackItem] 24 | rejected: List[UltraFeedbackItem] 25 | score_chosen: float 26 | score_rejected: float 27 | 28 | 29 | class UltraFeedbackBinarized(DatasetAdapter): 30 | 31 | splits: List[Split] = ['train', 'test'] 32 | 33 | def __init__(self, score_distance: float = SCORE_DISTANCE): 34 | self.score_distance = score_distance 35 | 36 | def load_dataset( 37 | self, 38 | split: Split, 39 | prompt_template='chatml' 40 | ) -> List[FormattedDatasetRow]: 41 | """ 42 | Args: 43 | split: The dataset split that you want to get. Ultrafeedback has `train` and `test`. 44 | prompt_template: The prompt template you want to use. 45 | score_distance: Minimum score distance for included DPO pairs. 
46 | I.e, 1.0 means a pair with chosen score of 9.0, and rejected score of 8.0, will be accepted (9 - 8 = 1, which is not less than 1.0) 47 | """ 48 | 49 | if split.lower() == 'train': 50 | split = 'train_prefs' 51 | elif split.lower() == 'test': 52 | split = 'test_prefs' 53 | else: 54 | raise Exception(f"Unknown split '{split}") 55 | 56 | # Load dataset from Huggingface 57 | dataset = load(HUGGINGFACE_ID, split=split) 58 | 59 | # Drop unused columns 60 | dataset = dataset.remove_columns(['messages']) 61 | 62 | # Format rows 63 | formatted_dataset = format_dataset(dataset, prompt_template, self.score_distance) 64 | 65 | # Drop other unused columns post-processing 66 | formatted_dataset = formatted_dataset.remove_columns( 67 | ['prompt_id', 'score_chosen', 'score_rejected'] 68 | ) 69 | 70 | return formatted_dataset 71 | 72 | 73 | def format_dataset(dataset: Dataset, prompt_template: str, score_distance: float): 74 | # Print first 10 rows just to check 75 | test_rows = dataset.select(range(0, 10)) 76 | test_rows = test_rows.map( 77 | format_rows(prompt_template), num_proc=NUM_PROCESSES 78 | ) 79 | print_test_rows(test_rows) 80 | 81 | print('Dataset length before filtering:', len(dataset)) 82 | filtered_dataset = filter_dataset(dataset, score_distance) 83 | print('Dataset length after filter:', len(filtered_dataset)) 84 | # Format Dataset 85 | formatted_dataset = filtered_dataset.map( 86 | format_rows(prompt_template), num_proc=NUM_PROCESSES 87 | ) 88 | 89 | return formatted_dataset 90 | 91 | 92 | def format_rows(prompt_template='chatml') -> FormattedDatasetRow: 93 | if prompt_template == 'chatml': 94 | 95 | def _format_row(row: UltraFeedbackRow) -> FormattedDatasetRow: 96 | dpo_triplets = get_dpo_triplets(row) 97 | 98 | formatted_prompt = f"""\ 99 | <|im_start|>user 100 | {dpo_triplets['prompt']}<|im_end|> 101 | <|im_start|>assistant 102 | """ 103 | 104 | formatted_chosen = f"{dpo_triplets['chosen']}<|im_end|>" 105 | formatted_rejected = f"{dpo_triplets['rejected']}<|im_end|>" 106 | 107 | return { 108 | 'prompt': formatted_prompt, 109 | 'chosen': formatted_chosen, 110 | 'rejected': formatted_rejected, 111 | } 112 | 113 | return _format_row 114 | 115 | 116 | def get_dpo_triplets(row: UltraFeedbackRow) -> FormattedDatasetRow: 117 | return { 118 | 'prompt': row['prompt'], 119 | 'chosen': row['chosen'][1]['content'], 120 | 'rejected': row['rejected'][1]['content'], 121 | } 122 | 123 | 124 | def print_test_rows(row: UltraFeedbackRow) -> None: 125 | print('=' * 80) 126 | print(row['prompt']) 127 | print('=' * 80) 128 | 129 | 130 | def filter_dataset(dataset: Dataset, score_distance: float): 131 | def filter_rows(score_distance: float): 132 | # Put all the filter criteria here 133 | def _filter_rows(row: UltraFeedbackRow): 134 | # Remove a malfeascant row about covid. 
It has wrong answer 135 | remove_prompt_ids = [ 136 | '744aa5f9a6cbab1c168c606df3e1daf63a6cba08e20d5b8526a70627606b9f2e' 137 | ] 138 | if row['prompt_id'] in remove_prompt_ids: 139 | return False 140 | 141 | distance = row['score_chosen'] - row['score_rejected'] 142 | 143 | if distance < score_distance: 144 | return False 145 | 146 | return True 147 | 148 | return _filter_rows 149 | 150 | return dataset.filter(filter_rows(score_distance=score_distance)) 151 | -------------------------------------------------------------------------------- /dpo/adapters/ultrafeedback_binarized_test.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | from dpo.adapters.ultrafeedback_binarized import format_dataset 4 | 5 | 6 | def test_format_dataset_chatml(): 7 | MOCK_DATA = load_dataset( 8 | 'json', 9 | split=['train'], 10 | data_files='./dpo/adapters/ultrafeedback_binarized_test.jsonl', 11 | streaming=False, 12 | )[0] 13 | 14 | train_data = format_dataset( 15 | dataset=MOCK_DATA, 16 | prompt_template='chatml', 17 | score_distance=0, 18 | ) 19 | 20 | train_data_row_one = train_data[0] 21 | mock_data_row_one = MOCK_DATA[0] 22 | 23 | expected_prompt = f"""\ 24 | <|im_start|>user 25 | {mock_data_row_one['prompt']}<|im_end|> 26 | <|im_start|>assistant 27 | """ 28 | 29 | assert len(MOCK_DATA) == len(train_data) 30 | expected_chosen = f"""{mock_data_row_one['chosen'][1]["content"]}<|im_end|>""" 31 | expected_rejected = f"""{mock_data_row_one['rejected'][1]["content"]}<|im_end|>""" 32 | 33 | assert expected_prompt == train_data_row_one['prompt'] 34 | assert expected_chosen == train_data_row_one['chosen'] 35 | assert expected_rejected == train_data_row_one['rejected'] 36 | -------------------------------------------------------------------------------- /dpo/ds_config_stage1.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1 4 | }, 5 | "gradient_accumulation_steps": "auto", 6 | "train_micro_batch_size_per_gpu": "auto", 7 | "gradient_clipping": "auto", 8 | "bf16": { 9 | "enabled": "auto" 10 | } 11 | } -------------------------------------------------------------------------------- /dpo/ds_config_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2 4 | }, 5 | "gradient_accumulation_steps": "auto", 6 | "train_micro_batch_size_per_gpu": "auto", 7 | "gradient_clipping": "auto", 8 | "bf16": { 9 | "enabled": "auto" 10 | } 11 | } -------------------------------------------------------------------------------- /dpo/ds_config_stage3.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "offload_param": { 5 | "device": "cpu", 6 | "pin_memory": true 7 | }, 8 | "offload_optimizer": { 9 | "device": "cpu", 10 | "pin_memory": true 11 | } 12 | }, 13 | "gradient_accumulation_steps": "auto", 14 | "train_micro_batch_size_per_gpu": "auto", 15 | "gradient_clipping": "auto", 16 | "bf16": { 17 | "enabled": "auto" 18 | } 19 | } -------------------------------------------------------------------------------- /dpo/make_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from datasets import Dataset, concatenate_datasets 4 | 5 | from dpo.adapters.base import DatasetAdapter, FormattedDatasetRow 6 | 7 | # The amount of test data to 
split from datasets that doesn't have 8 | # test/eval sets 9 | TEST_SIZE = 0.1 10 | 11 | 12 | def make_dataset( 13 | adapters: List[DatasetAdapter], 14 | seed = None 15 | ) -> Tuple[List[FormattedDatasetRow], List[FormattedDatasetRow]]: 16 | """ 17 | Make dataset and collator for Direct-Preference Optimization. 18 | Datasets are expected to have the following columns: {`prompt`, `chosen`, `rejected` } 19 | Returns a tuple of (train_dataset, eval_dataset) 20 | """ 21 | 22 | train_dataset_list: List[Dataset] = [] 23 | eval_dataset_list: List[Dataset] = [] 24 | 25 | # Get all datasets indicated in config.py, then append them to list 26 | for adapter in adapters: 27 | assert 'train' in adapter.splits, \ 28 | f"'train' split not found in {adapter.__name__}." 29 | 30 | # TODO: Remove hardcoded 'train' and 'test' splits 31 | dataset = adapter.load_dataset('train') 32 | if 'test' in adapter.splits: 33 | eval_dataset = adapter.load_dataset('test') 34 | 35 | train_dataset_list.append(dataset) 36 | eval_dataset_list.append(eval_dataset) 37 | 38 | else: 39 | dataset = dataset.train_test_split(test_size=TEST_SIZE, seed=seed) 40 | train_dataset_list.append(dataset['train']) 41 | eval_dataset_list.append(dataset['test']) 42 | 43 | # We concatenate all datasets in list first, then shuffle before getting train eval split 44 | dataset = concatenate_datasets(train_dataset_list) 45 | eval_dataset = concatenate_datasets(eval_dataset_list) 46 | 47 | return dataset, eval_dataset 48 | -------------------------------------------------------------------------------- /dpo/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | transformers==4.42.3 3 | trl==0.8.1 4 | datasets==2.18.0 5 | click 6 | deepspeed==0.14.4 7 | peft==0.11.1 -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # LM Eval Harness 2 | 3 | ## Setting up the lm-evaluation-harness repository 4 | 5 | Clone the lm-evaluation-harness repository and checkout the branch used by Huggingface OpenLLM Leaderboard 6 | 7 | ```bash 8 | git clone https://github.com/EleutherAI/lm-evaluation-harness.git && \ 9 | cd lm-evaluation-harness && git checkout b281b09 10 | ``` 11 | 12 | ## Set up conda enviroment inside the repository 13 | 14 | ```bash 15 | conda create --prefix=./.conda python=3.10 && \ 16 | conda activate ./.conda 17 | ``` 18 | 19 | ## Install the requirements, and go back to the `eval`. 20 | 21 | ```bash 22 | pip install -r requirements.txt && cd .. 23 | ``` 24 | 25 | ## Run the script 26 | 27 | ``` 28 | bash huggingface_llm_eval_fast.sh /path/to/model /path/to/output/results 29 | ``` 30 | 31 | You can also use the non-fast version `huggingface_llm_eval.sh` that runs with `batch size 1` that will be fully accurate with huggingface leaderboard. However, it takes a lot longer, up to 1.5-2x the time. 32 | 33 | ## Parse the results 34 | 35 | ```python 36 | python eval.py /path/to/results 37 | ``` 38 | 39 | ### Notes 40 | 41 | The script needs two arguments 42 | | argument | Description | 43 | | -------- | ---------- | 44 | | /path/to/model | This is the path to the model safetensor we want to evaluate | 45 | | /path/to/output | This is the directory we want to save the result jsons | 46 | 47 | After evaluation is complete, copy `eval.py` inside the output dir which has all the evaluation result jsons. Run with `python eval.py`. 
This will load the eval results and print the final scores. 48 | -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from sys import argv 3 | from pathlib import Path 4 | from pprint import pprint 5 | 6 | """ 7 | Usage: 8 | 9 | python eval.py /path/to/results 10 | """ 11 | 12 | result_dir = argv[1] 13 | assert result_dir, 'You must provide the path to the results.' 14 | 15 | result_dir = Path(result_dir).absolute().resolve() 16 | assert result_dir.exists(), f'Directory {str(result_dir)} does not exist. Please check.' 17 | 18 | 19 | def get_arc_result(json: dict): 20 | return json['results']['arc_challenge']['acc_norm'] * 100 21 | 22 | 23 | def get_hellaswag_result(json: dict): 24 | return json['results']['hellaswag']['acc_norm'] * 100 25 | 26 | 27 | def get_mmlu_result(json: dict): 28 | # MMLU reults needs to be averaged across a few tests... 29 | total_acc = sum(json['results'][key]['acc'] for key in json['results']) 30 | avg_acc = total_acc / len(json['results']) 31 | return avg_acc * 100 32 | 33 | 34 | def get_truthfulQA_result(json: dict): 35 | return json['results']['truthfulqa_mc']['mc2'] * 100 36 | 37 | 38 | def get_winogrande_result(json: dict): 39 | return json['results']['winogrande']['acc'] * 100 40 | 41 | 42 | def get_gsm8k_result(json: dict): 43 | return json['results']['gsm8k']['acc'] * 100 44 | 45 | 46 | evaluations = [ 47 | {'name': 'arc', 'getter': get_arc_result}, 48 | {'name': 'hellaswag', 'getter': get_hellaswag_result}, 49 | {'name': 'mmlu', 'getter': get_mmlu_result}, 50 | {'name': 'truthfulQA', 'getter': get_truthfulQA_result}, 51 | {'name': 'winogrande', 'getter': get_winogrande_result}, 52 | {'name': 'gsm8k', 'getter': get_gsm8k_result}, 53 | ] 54 | 55 | total_score = 0 56 | results = {} 57 | 58 | for evaluation in evaluations: 59 | json_path = result_dir / f'{evaluation["name"]}.json' 60 | if not json_path.exists(): 61 | print( 62 | f'WARN: Not able to find `{evaluation["name"]}` results at {str(json_path)}. Perhaps the test failed?' 63 | ) 64 | 65 | match input( 66 | f'Results will not include failed `{evaluation["name"]}` results. Do you want to continue? 
(yes/no): ' 67 | ): 68 | case 'yes': 69 | continue 70 | case _: 71 | exit() 72 | 73 | with open(json_path, 'r') as file: 74 | result_json = json.load(file) 75 | 76 | score = evaluation['getter'](result_json) 77 | results[evaluation['name']] = score 78 | total_score += score 79 | 80 | results['average'] = total_score / len(evaluations) 81 | 82 | pprint(results) 83 | 84 | all_results_path = result_dir / 'all.json' 85 | with open(all_results_path, 'w') as file: 86 | file.write(json.dumps(results)) 87 | 88 | print(f'Results saved to {str(all_results_path)}') 89 | -------------------------------------------------------------------------------- /eval/huggingface_llm_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ###################################################################### 3 | # https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard 4 | # The tasks and few shots parameters are: 5 | 6 | # ARC: 25-shot, arc-challenge (acc_norm) 7 | # HellaSwag: 10-shot, hellaswag (acc_norm) 8 | # TruthfulQA: 0-shot, truthfulqa-mc (mc2) 9 | # MMLU: 5-shot, hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions (average of all the results acc) 10 | # Winogrande: 5-shot, winogrande (acc) 11 | # GSM8k: 5-shot, gsm8k (acc) 12 | ####################################################################### 13 | # Usage 14 | # bash huggingface_leaderboard_eval_script.sh /path/to/model 15 | 16 | # Check if the correct number of arguments was provided 17 | if [ "$#" -ne 2 ]; then 18 | echo "Usage: $0 /path/to/model /path/to/output" 19 | exit 1 20 | fi 21 | 22 | # Assign the first argument to a variable for the model path 23 | MODEL_PATH="$1" 24 | 25 | # Assign the second argument to a variable for the 
output path 26 | OUTPUT_PATH="$2" 27 | 28 | # Define arrays for task_list, n_few_shot, and output_path for each benchmark 29 | declare -a task_lists=("arc_challenge" "hellaswag" "truthfulqa_mc" "hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions" "winogrande" "gsm8k") 30 | declare -a few_shots=(25 10 0 5 5 5) 31 | declare -a output_paths=("arc.json" "hellaswag.json" "truthfulQA.json" "mmlu.json" "winogrande.json" "gsm8k.json") 32 | 33 | CURRENT_DIR=$(pwd) 34 | 35 | # Loop through each benchmark configuration 36 | for i in {0..5}; do 37 | 38 | output_file="${OUTPUT_PATH}/${output_paths[$i]}" 39 | 40 | if [ -f "$output_file" ]; then 41 | echo "${output_file} already exists, skipping benchmark.." 42 | else 43 | echo "Running benchmark $(($i + 1))" 44 | # Hellaswag and Winogrande HF_DATASETS_TRUST_REMOTE_CODE=1 to run 45 | # There was a breaking change in `datasets` module that necessitated this force overwrite. 46 | PYTHONPATH=$PYTHONPATH:$CURRENT_DIR/lm-evaluation-harness HF_DATASETS_TRUST_REMOTE_CODE=1 python lm-evaluation-harness/main.py \ 47 | --model=hf-causal-experimental \ 48 | --model_args="pretrained=${MODEL_PATH},use_accelerate=True,trust_remote_code=True" \ 49 | --tasks=${task_lists[$i]} \ 50 | --num_fewshot=${few_shots[$i]} \ 51 | --batch_size=1 \ 52 | --output_path="${output_file}" 53 | echo "Benchmark $(($i + 1)) completed" 54 | fi 55 | done 56 | 57 | echo "All benchmarks completed!" 
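# The per-task result JSONs written to ${OUTPUT_PATH} can then be aggregated with
# eval.py (see eval/README.md), e.g. from the eval/ directory: python eval.py "${OUTPUT_PATH}"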
58 | -------------------------------------------------------------------------------- /eval/huggingface_llm_eval_fast.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ###################################################################### 3 | # https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard 4 | # The tasks and few shots parameters are: 5 | 6 | # ARC: 25-shot, arc-challenge (acc_norm) 7 | # HellaSwag: 10-shot, hellaswag (acc_norm) 8 | # TruthfulQA: 0-shot, truthfulqa-mc (mc2) 9 | # MMLU: 5-shot, hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions (average of all the results acc) 10 | # Winogrande: 5-shot, winogrande (acc) 11 | # GSM8k: 5-shot, gsm8k (acc) 12 | ####################################################################### 13 | # Usage 14 | # bash huggingface_leaderboard_eval_script.sh /path/to/model 15 | 16 | # Check if the correct number of arguments was provided 17 | if [ "$#" -ne 2 ]; then 18 | echo "Usage: $0 /path/to/model /path/to/output" 19 | exit 1 20 | fi 21 | 22 | # Assign the first argument to a variable for the model path 23 | MODEL_PATH="$1" 24 | 25 | # Assign the second argument to a variable for the output path 26 | OUTPUT_PATH="$2" 27 | 28 | # Define arrays for task_list, n_few_shot, and output_path for each benchmark 29 | declare -a task_lists=("arc_challenge" "hellaswag" "truthfulqa_mc" 
"hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions" "winogrande" "gsm8k") 30 | declare -a few_shots=(25 10 0 5 5 5) 31 | declare -a output_paths=("arc.json" "hellaswag.json" "truthfulQA.json" "mmlu.json" "winogrande.json" "gsm8k.json") 32 | # batch_size "auto" is only useful for large evaluation 33 | # otherwise for small evaluations like truthfulQA and winogrande, the 34 | # time taken to find the batch size is not worth it, better to run with batch size 1 35 | declare -a batch_sizes=("auto" "auto" 1 "auto" 1 "auto") 36 | 37 | 38 | CURRENT_DIR=$(pwd) 39 | 40 | # Loop through each benchmark configuration 41 | for i in {0..5}; do 42 | 43 | output_file="${OUTPUT_PATH}/${output_paths[$i]}" 44 | 45 | if [ -f "$output_file" ]; then 46 | echo "${output_file} already exists, skipping benchmark.." 47 | else 48 | echo "Running benchmark $(($i + 1))" 49 | # Hellaswag and Winogrande HF_DATASETS_TRUST_REMOTE_CODE=1 to run 50 | # There was a breaking change in `datasets` module that necessitated this force overwrite. 51 | PYTHONPATH=$PYTHONPATH:$CURRENT_DIR/lm-evaluation-harness HF_DATASETS_TRUST_REMOTE_CODE=1 python lm-evaluation-harness/main.py \ 52 | --model=hf-causal-experimental \ 53 | --model_args="pretrained=${MODEL_PATH},use_accelerate=True,trust_remote_code=True" \ 54 | --tasks=${task_lists[$i]} \ 55 | --num_fewshot=${few_shots[$i]} \ 56 | --batch_size=${batch_sizes[$i]} \ 57 | --output_path="${output_file}" 58 | 59 | echo "Benchmark $(($i + 1)) completed" 60 | fi 61 | done 62 | 63 | echo "All benchmarks completed!" 
64 | -------------------------------------------------------------------------------- /finetune/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/finetune/.DS_Store -------------------------------------------------------------------------------- /finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/finetune/__init__.py -------------------------------------------------------------------------------- /finetune/custom_dataclasses.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from transformers import Seq2SeqTrainingArguments 4 | 5 | @dataclass 6 | class ModelArguments: 7 | model_name_or_path: Optional[str] = field( 8 | default="EleutherAI/pythia-12b" 9 | ) 10 | trust_remote_code: Optional[bool] = field( 11 | default=False, 12 | metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."} 13 | ) 14 | 15 | @dataclass 16 | class DataArguments: 17 | eval_dataset_size: int = field( 18 | default=1024, metadata={"help": "Size of validation dataset."} 19 | ) 20 | max_train_samples: Optional[int] = field( 21 | default=None, 22 | metadata={ 23 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 24 | "value if set." 25 | }, 26 | ) 27 | max_eval_samples: Optional[int] = field( 28 | default=None, 29 | metadata={ 30 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 31 | "value if set." 32 | }, 33 | ) 34 | source_max_len: int = field( 35 | default=1024, 36 | metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."}, 37 | ) 38 | target_max_len: int = field( 39 | default=256, 40 | metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."}, 41 | ) 42 | dataset: str = field( 43 | default='alpaca', 44 | metadata={"help": "Which dataset to finetune on. See datamodule for options."} 45 | ) 46 | dataset_format: Optional[str] = field( 47 | default=None, 48 | metadata={"help": "Which dataset format is used. [alpaca|chip2|self-instruct|hh-rlhf]"} 49 | ) 50 | 51 | @dataclass 52 | class TrainingArguments(Seq2SeqTrainingArguments): 53 | 54 | train_on_source: Optional[bool] = field( 55 | default=False, 56 | metadata={"help": "Whether to train on the input in addition to the target text."} 57 | ) 58 | 59 | report_to: str = field( 60 | default='none', 61 | metadata={"help": "To use wandb or something else for reporting."} 62 | ) 63 | output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'}) 64 | optim: str = field(default='adamw_torch', metadata={"help": 'The optimizer to be used'}) 65 | per_device_train_batch_size: int = field(default=16, metadata={"help": 'The training batch size per GPU. 
Increase for better speed.'}) 66 | gradient_accumulation_steps: int = field(default=1, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'}) 67 | max_steps: int = field(default=10000, metadata={"help": 'How many optimizer update steps to take'}) 68 | weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) 69 | learning_rate: float = field(default=0.0002, metadata={"help": 'The learnign rate'}) 70 | remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'}) 71 | max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'}) 72 | gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'}) 73 | do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'}) 74 | lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'}) 75 | warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'}) 76 | logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'}) 77 | group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'}) 78 | save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'}) 79 | save_steps: int = field(default=250, metadata={"help": 'How often to save a model'}) 80 | save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'}) 81 | template: str = field(default='chatml', metadata={"help": 'Template to used for finetuning'}) 82 | 83 | @dataclass 84 | class GenerationArguments: 85 | # For more hyperparameters check: 86 | # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig 87 | # Length arguments 88 | max_new_tokens: Optional[int] = field( 89 | default=256, 90 | metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops" 91 | "if predict_with_generate is set."} 92 | ) 93 | min_new_tokens : Optional[int] = field( 94 | default=None, 95 | metadata={"help": "Minimum number of new tokens to generate."} 96 | ) 97 | 98 | # Generation strategy 99 | do_sample: Optional[bool] = field(default=False) 100 | num_beams: Optional[int] = field(default=1) 101 | num_beam_groups: Optional[int] = field(default=1) 102 | penalty_alpha: Optional[float] = field(default=None) 103 | use_cache: Optional[bool] = field(default=True) 104 | 105 | # Hyperparameters for logit manipulation 106 | temperature: Optional[float] = field(default=1.0) 107 | top_k: Optional[int] = field(default=50) 108 | top_p: Optional[float] = field(default=1.0) 109 | typical_p: Optional[float] = field(default=1.0) 110 | diversity_penalty: Optional[float] = field(default=0.0) 111 | repetition_penalty: Optional[float] = field(default=1.0) 112 | length_penalty: Optional[float] = field(default=1.0) 113 | no_repeat_ngram_size: Optional[int] = field(default=0) 114 | -------------------------------------------------------------------------------- /finetune/requirements.txt: 
-------------------------------------------------------------------------------- /finetune/requirements.txt: --------------------------------------------------------------------------------
1 | accelerate==0.27.2
2 | chardet==5.2.0
3 | datasets==2.20.0
4 | transformers==4.43.3
5 | docstring_parser==0.16
6 | rotary_emb @ git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=csrc/rotary
7 | xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=csrc/xentropy
8 | dropout_layer_norm @ git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=csrc/layer_norm
9 |
-------------------------------------------------------------------------------- /generation_config/README.md: --------------------------------------------------------------------------------
1 | # Generation config
2 | 
3 | You need to include the respective `generation_config` for 2k or 16k models during inference.
4 | 
5 | WARN: Do not include these files during training/finetuning, as they will cause errors.
-------------------------------------------------------------------------------- /generation_config/generation_config-16k.json: --------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 1,
4 | "eos_token_id": 2,
5 | "max_length": 16384,
6 | "pad_token_id": 32001,
7 | "temperature": 0.1,
8 | "repetition_penalty": 1.3,
9 | "transformers_version": "4.38.0"
10 | }
-------------------------------------------------------------------------------- /generation_config/generation_config.json: --------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 1,
4 | "eos_token_id": 2,
5 | "max_length": 2048,
6 | "pad_token_id": 32001,
7 | "temperature": 0.1,
8 | "repetition_penalty": 1.3,
9 | "transformers_version": "4.38.0"
10 | }
-------------------------------------------------------------------------------- /inference/README.md: --------------------------------------------------------------------------------
1 | # Inference
2 | 
3 | ## How to run
4 | 
5 | 1. (If full checkpoints were saved) Convert the deepspeed checkpoint (`.pt`) files to litgpt weights using `python zero_to_fp32.py . lit_model.pth`.
6 | 1. `cd` into `Pints-Train/` and run `python -m inference.generate --model_config_name <model_config_name> --checkpoint_path <path_to_checkpoint>/lit_model.pth --tokenizer_path <path_to_tokenizer_dir>`.
7 | 1. Note that the prompts should be properly formatted with the correct prompt style (e.g., ChatML), otherwise the model might output random garbage.
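As a concrete (hypothetical) example, assuming the converted checkpoint sits under `models/my_run/` and the repo's bundled `tokenizer/pints` folder is used, the flags below map onto the parameters exposed by `inference/generate.py` (which currently hard-codes its model config):

```bash
# Paths are placeholders -- substitute your own converted checkpoint.
python -m inference.generate \
  --checkpoint_path models/my_run/lit_model.pth \
  --tokenizer_path tokenizer/pints \
  --prompt "Explain what a tokenizer does." \
  --max_new_tokens 256 \
  --temperature 0.8 \
  --top_k 200
```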
8 | -------------------------------------------------------------------------------- /inference/generate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional 3 | import warnings 4 | import torch 5 | from pathlib import Path 6 | from lit_gpt.tokenizer import Tokenizer 7 | import lightning 8 | import time 9 | from os.path import normpath, join 10 | from os import getcwd 11 | from jsonargparse import CLI 12 | from lit_gpt.model import GPT, Config 13 | 14 | # Referenced from https://github.com/Lightning-AI/lit-llama/blob/main/generate/full.py 15 | def main( 16 | prompt: str = "Hello, my name is", 17 | max_new_tokens: int = 50, 18 | top_k: int = 200, 19 | temperature: float = 0.8, 20 | checkpoint_path: Optional[Path] = None, 21 | tokenizer_path: Path = Path('tokenizer'), 22 | ): 23 | 24 | checkpoint_path = normpath(join(getcwd(), checkpoint_path)) 25 | print('checkpoint_path: ', checkpoint_path) 26 | checkpoint_path = Path(checkpoint_path) 27 | 28 | tokenizer_path = normpath(join(getcwd(), tokenizer_path)) 29 | print('tokenizer_path: ', tokenizer_path) 30 | tokenizer_path = Path(tokenizer_path) 31 | 32 | assert checkpoint_path.is_file(), checkpoint_path 33 | assert tokenizer_path.is_dir(), tokenizer_path 34 | 35 | config = Config.from_name('2.0-Pints-Upscaled') 36 | model = GPT(config) 37 | checkpoint = torch.load(checkpoint_path) 38 | checkpoint = checkpoint["model"] 39 | model.load_state_dict(checkpoint) 40 | model.eval() 41 | 42 | fabric = lightning.Fabric(devices=1, precision='bf16-true') 43 | # fabric = lightning.Fabric(devices=1, precision='32-true') 44 | model = fabric.setup(model) 45 | 46 | tokenizer = Tokenizer(tokenizer_path) 47 | 48 | prompt = f'''<|im_start|>system 49 | you are an expert in writing<|im_end|> 50 | <|im_start|>user 51 | {prompt}<|im_end|> 52 | <|im_start|>assistant\n''' 53 | 54 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device) 55 | print(encoded) 56 | prompt_length = encoded.size(0) 57 | lightning.seed_everything(1234) 58 | 59 | # Use `samples` to generate a few samples. 60 | # for i in range(samples): 61 | t0 = time.perf_counter() 62 | y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) 63 | t = time.perf_counter() - t0 64 | 65 | model.reset_cache() 66 | 67 | print(tokenizer.decode(y)) 68 | tokens_generated = y.size(0) - prompt_length 69 | print(f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) 70 | if fabric.device.type == "cuda": 71 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) 72 | 73 | 74 | # Referenced from https://github.com/Lightning-AI/lit-llama/blob/main/generate.py#L19 75 | @torch.no_grad() 76 | def generate( 77 | model, 78 | idx: torch.Tensor, 79 | max_new_tokens: int, 80 | *, 81 | max_seq_length: Optional[int] = None, 82 | temperature: float = 1.0, 83 | top_k: Optional[int] = None, 84 | eos_id: Optional[int] = None, 85 | ) -> torch.Tensor: 86 | """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 87 | 88 | The implementation of this function is modified from A. Karpathy's nanoGPT. 89 | 90 | Args: 91 | model: The model to use. 92 | idx: Tensor of shape (T) with indices of the prompt sequence. 93 | max_new_tokens: The number of new tokens to generate. 94 | max_seq_length: The maximum sequence length allowed. 
95 | temperature: Scales the predicted logits by 1 / temperature 96 | top_k: If specified, only sample among the tokens with the k highest probabilities 97 | eos_id: If specified, stop generating any more token once the token is triggered 98 | """ 99 | # create an empty tensor of the expected final shape and fill in the current tokens 100 | T = idx.size(0) 101 | T_new = T + max_new_tokens 102 | if max_seq_length is None: 103 | max_seq_length = min(T_new, model.config.block_size) 104 | 105 | device, dtype = idx.device, idx.dtype 106 | # create an empty tensor of the expected final shape and fill in the current tokens 107 | empty = torch.empty(T_new, dtype=dtype, device=device) 108 | empty[:T] = idx 109 | idx = empty 110 | input_pos = torch.arange(0, T, device=device) 111 | 112 | if idx.device.type == "xla": 113 | import torch_xla.core.xla_model as xm 114 | 115 | xm.mark_step() 116 | 117 | # generate max_new_tokens tokens 118 | for _ in range(max_new_tokens): 119 | x = idx.index_select(0, input_pos).view(1, -1) 120 | 121 | # forward 122 | logits = model(x, input_pos=input_pos, max_seq_length=max_seq_length) 123 | logits = logits[0, -1] / temperature 124 | 125 | # optionally crop the logits to only the top k options 126 | if top_k is not None: 127 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 128 | logits = torch.where(logits < v[[-1]], -float("Inf"), logits) 129 | 130 | probs = torch.nn.functional.softmax(logits, dim=-1) 131 | idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype) 132 | 133 | # advance 134 | input_pos = input_pos[-1:] + 1 135 | 136 | if idx.device.type == "xla": 137 | xm.mark_step() 138 | 139 | # concatenate the new generation 140 | idx = idx.index_copy(0, input_pos, idx_next) 141 | 142 | # if token is triggered, return the output (stop generation) 143 | if idx_next == eos_id: 144 | return idx[:input_pos] # include the EOS token 145 | 146 | return idx 147 | 148 | 149 | if __name__ == '__main__': 150 | torch.set_float32_matmul_precision("high") 151 | warnings.filterwarnings( 152 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 153 | "ignore", 154 | message="ComplexHalf support is experimental and many operators don't support it yet" 155 | ) 156 | warnings.filterwarnings( 157 | # Triggered in bitsandbytes/autograd/_functions.py:298 158 | "ignore", 159 | message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization", 160 | ) 161 | CLI(main) -------------------------------------------------------------------------------- /inference/generate_hf.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from pathlib import Path 4 | from os.path import normpath, join 5 | from os import getcwd 6 | from jsonargparse import CLI 7 | 8 | def main( 9 | checkpoint_path: Path = Path('checkpoint'), 10 | prompt: str = "???", 11 | max_new_tokens: int = 50, 12 | repetition_penalty: float = 1.0, 13 | temperature: float = 0.8, 14 | tokenizer_path: Path = Path('tokenizer'), 15 | ): 16 | checkpoint_path = normpath(join(getcwd(), checkpoint_path)) 17 | print('checkpoint_path: ', checkpoint_path) 18 | checkpoint_path = Path(checkpoint_path) 19 | 20 | tokenizer_path = normpath(join(getcwd(), tokenizer_path)) 21 | print('tokenizer_path: ', tokenizer_path) 22 | tokenizer_path = Path(tokenizer_path) 23 | 24 | assert checkpoint_path.is_dir(), checkpoint_path 25 | assert tokenizer_path.is_dir(), tokenizer_path 26 | 
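# The block below mirrors the lit-gpt script above, but through the Hugging Face stack:
# load the tokenizer, wrap the raw prompt in the same ChatML template used during
# finetuning, call `model.generate`, and decode only the newly generated tokens
# (everything after `input_ids_len`).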
27 | 28 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, device_map='cuda') 29 | 30 | prompt = f'''<|im_start|>system 31 | You are a helpful, respectful, and honest assistant.<|im_end|> 32 | <|im_start|>user 33 | {prompt}<|im_end|> 34 | <|im_start|>assistant\n''' 35 | 36 | input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda() 37 | input_ids_len = input_ids.shape[1] 38 | 39 | # Load the model 40 | model = AutoModelForCausalLM.from_pretrained(checkpoint_path, device_map='cuda', attn_implementation="eager") 41 | model.resize_token_embeddings(len(tokenizer)) 42 | 43 | # Generate output 44 | answer_ids = model.generate( 45 | input_ids, 46 | max_new_tokens=max_new_tokens, 47 | pad_token_id=tokenizer.pad_token_id, 48 | temperature=temperature, 49 | repetition_penalty=repetition_penalty 50 | )[0][input_ids_len:] 51 | print(tokenizer.pad_token_id) 52 | # Decode the generated answer 53 | answer = tokenizer.decode(answer_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 54 | print(answer) 55 | 56 | if __name__ == '__main__': 57 | CLI(main) 58 | -------------------------------------------------------------------------------- /inference/generate_lit.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import List, Union 4 | 5 | import lightning 6 | import torch 7 | from jsonargparse import CLI 8 | from lightning.fabric.accelerators.accelerator import Accelerator 9 | 10 | from lit_gpt.generate.base import generate 11 | from lit_gpt.model import GPT, Config 12 | from lit_gpt.tokenizer import Tokenizer 13 | from lit_gpt.utils import load_weights 14 | 15 | # use global variables to ensure that model is loaded only once, saving lots of time 16 | # WARNING: it won't work if get_lit_inferences() is called multiple times 17 | # with different models, checkpoints or tokenizers in the same run. 18 | # Start a new run if those are different. 19 | model = fabric = tokenizer = None 20 | 21 | 22 | def get_lit_inferences( 23 | model_config_name: str, 24 | checkpoint_path: Path, # The path to lit_model.pth. 25 | tokenizer_path: Path, # The path to the tokenizer folder. 26 | prompts: List[str] = ['<|im_start|>user\nHello, who is Bill Gates?<|im_end|>\n<|im_start|>assistant'], 27 | max_new_tokens: int = 512, # Set a low value for testing as model may be broken and not output eos 28 | top_k: int = 200, 29 | temperature: float = 0.01, # Don't set this to 0 as it would cause a runtime error. 
30 | accelerator: Union[str, Accelerator] = 'auto', 31 | devices: Union[List[int], str, int] = 'auto', 32 | precision: str = 'bf16-mixed' 33 | ): 34 | checkpoint_path = checkpoint_path.resolve() 35 | print('checkpoint_path: ', checkpoint_path) 36 | checkpoint_path = Path(checkpoint_path) 37 | 38 | tokenizer_path = tokenizer_path.resolve() 39 | print('tokenizer_path: ', tokenizer_path) 40 | tokenizer_path = Path(tokenizer_path) 41 | 42 | # Verify checkpoint and tokenizer paths 43 | assert checkpoint_path.is_file(), checkpoint_path 44 | assert tokenizer_path.is_dir(), tokenizer_path 45 | 46 | # Use global variables to ensure that model is loaded only once, saving lots of time 47 | global model, fabric, tokenizer 48 | if not model: 49 | # Initialize Fabric 50 | fabric = lightning.Fabric( 51 | accelerator=accelerator, 52 | devices=devices, 53 | precision=precision 54 | ) 55 | 56 | # Load the model config 57 | model_config = Config.from_name(model_config_name) 58 | 59 | with fabric.init_module(empty_init=True): 60 | # Initialize the model with the given config 61 | model = GPT(model_config) 62 | 63 | # Load the state dict into RAM 64 | state_dict = load_weights(checkpoint_path, remove_prefix=True) 65 | 66 | # Load the state dict into the model 67 | model.load_state_dict(state_dict) 68 | print(f'INFO: Model {model.__class__.__name__} loaded') 69 | 70 | model.eval() 71 | model = fabric.setup(model) 72 | 73 | # Initialize tokenizer 74 | tokenizer = Tokenizer(tokenizer_path) 75 | 76 | # Generate inferences 77 | results = [] 78 | for i, prompt in enumerate(prompts): 79 | # Encode the prompt. eos should be set to False, otherwise the model might think that 80 | # the conversation has ended, and start generating irrelevant information. 81 | # bos should be set to the same as the training data. 82 | encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device) 83 | 84 | encoded_result = generate( 85 | fabric=fabric, 86 | model=model, 87 | prompt=encoded, 88 | max_returned_tokens=encoded.size(0) + max_new_tokens, 89 | temperature=temperature, 90 | eos_id=tokenizer.eos_id, 91 | top_k=top_k, 92 | include_prompt=False, 93 | ) 94 | 95 | # Decode the result and print it 96 | result = tokenizer.decode(encoded_result) 97 | print('=' * 80) 98 | print(result) 99 | print('=' * 80) 100 | print(f'INFO: {i+1}/{len(prompts)} inferences generated.') 101 | results.append(result) 102 | 103 | return results 104 | 105 | 106 | if __name__ == '__main__': 107 | from jsonargparse import CLI 108 | 109 | torch.set_float32_matmul_precision('high') 110 | warnings.filterwarnings( 111 | # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 112 | 'ignore', 113 | message="ComplexHalf support is experimental and many operators don't support it yet" 114 | ) 115 | warnings.filterwarnings( 116 | # Triggered in bitsandbytes/autograd/_functions.py:298 117 | 'ignore', 118 | message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization', 119 | ) 120 | CLI(get_lit_inferences, as_positional=False) 121 | -------------------------------------------------------------------------------- /lit_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from lightning_utilities.core.imports import RequirementCache 2 | 3 | if not bool(RequirementCache("torch>=2.1.0dev")): 4 | raise ImportError( 5 | "Lit-GPT requires torch nightly (future torch 2.1). 
Please follow the installation instructions in the" 6 | " repository README.md" 7 | ) 8 | _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0") 9 | if not bool(_LIGHTNING_AVAILABLE): 10 | raise ImportError( 11 | "Lit-GPT requires Lightning nightly (future lightning 2.1). Please run:\n" 12 | f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}" 13 | ) 14 | -------------------------------------------------------------------------------- /lit_gpt/args.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | 7 | @dataclass 8 | class TrainArgs: 9 | """Training-related arguments""" 10 | 11 | save_interval: Optional[int] = 1000 12 | """Number of optimizer steps between saving checkpoints""" 13 | log_interval: int = 1 14 | """Number of iterations between logging calls""" 15 | global_batch_size: int = 64 16 | """Number of samples between optimizer steps across data-parallel ranks""" 17 | micro_batch_size: int = 4 18 | """Number of samples per data-parallel rank""" 19 | lr_warmup_steps: int = 100 20 | """Number of iterations with learning rate warmup active""" 21 | epochs: Optional[int] = None 22 | """Number of epochs to train on""" 23 | # TODO: `pretrain` is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs? 24 | max_tokens: Optional[int] = None 25 | """Total number of tokens to train on""" 26 | max_steps: Optional[int] = None 27 | """Limits the number of optimizer steps to run""" 28 | max_seq_length: Optional[int] = None 29 | """Limits the length of samples""" 30 | tie_embeddings: Optional[bool] = None 31 | """Whether to tie the embedding weights with the language modeling head weights""" 32 | 33 | # Optimization args 34 | 35 | # The default of 1e-3 that came from litgpt caused a loss spike. Hence, lowered to 1e-5. 
36 | # See original litgpt code: https://github.com/Lightning-AI/litgpt/blob/64bd9eb32e7fd2bebe8ff187c6f4847b85fe16e8/litgpt/args.py#L36 37 | learning_rate: float = 1e-5 38 | weight_decay: float = 0.02 39 | beta1: float = 0.9 40 | beta2: float = 0.95 41 | max_norm: Optional[float] = None 42 | # Tinyllama https://arxiv.org/pdf/2401.02385.pdf 43 | min_lr: float = 4e-5 44 | 45 | def gradient_accumulation_iters(self, devices: int) -> int: 46 | """Number of iterations between gradient synchronizations""" 47 | gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size 48 | assert gradient_accumulation_iters > 0 49 | return gradient_accumulation_iters 50 | 51 | def batch_size(self, devices: int) -> int: 52 | """Number of samples between optimizer steps per data-parallel rank""" 53 | batch_size = self.global_batch_size // devices 54 | assert batch_size > 0 55 | return batch_size 56 | 57 | 58 | @dataclass 59 | class EvalArgs: 60 | """Evaluation-related arguments""" 61 | 62 | interval: int = 600 63 | """Number of optimizer steps between evaluation calls""" 64 | max_new_tokens: Optional[int] = None 65 | """Number of tokens to generate""" 66 | max_iters: int = 100 67 | """Number of iterations""" 68 | -------------------------------------------------------------------------------- /lit_gpt/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Type 3 | import lit_gpt.model 4 | from lit_gpt.config_base import ConfigBase 5 | 6 | @dataclass 7 | class Config(ConfigBase): 8 | 9 | ''' 10 | Config subclasses ConfigBase to add on things that is not required 11 | outside of training, so that data processing can be done. 12 | Independent of additional training modules (that are imported by `lit_gpt.model`). 13 | For example, modules such as `flash-attention` won't run on non-CUDA devices. 14 | ''' 15 | 16 | @property 17 | def mlp_class(self) -> Type: 18 | # `self._mlp_class` cannot be the type to keep the config json serializable 19 | return getattr(lit_gpt.model, self._mlp_class) 20 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/lit_gpt/datamodules/__init__.py -------------------------------------------------------------------------------- /lit_gpt/datamodules/base.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | from abc import abstractmethod 3 | from functools import partial 4 | from typing import List, Optional, Union 5 | from multiprocessing import cpu_count 6 | 7 | import torch 8 | from lightning import LightningDataModule 9 | 10 | from lit_gpt.tokenizer import Tokenizer 11 | from lit_gpt.prompts import PromptStyle 12 | from lit_gpt.datamodules.typings.base import SFTDatasetItem, SFTCollatedBatch 13 | from lit_gpt.datamodules.typings.prepared_dataset import PreparedDataset 14 | 15 | 16 | class DataModule(LightningDataModule): 17 | """Base class for all data modules in LitGPT.""" 18 | 19 | """Whether to mask the prompt section from the label (with ``ignore_index``).""" 20 | mask_prompt: bool = True 21 | """The style to apply to instruction prompts. 
See `litgpt.prompts` for a list of available styles."""
22 | prompt_style: Union[str, PromptStyle] = 'chatml'
23 | """The index to use for elements to be ignored in the label."""
24 | ignore_index: int = -100
25 | """The random seed for shuffling the dataset."""
26 | seed: int = 42
27 | """How many DataLoader processes to use for loading."""
28 | num_workers = max(1, cpu_count() // 2) # Use half of available cores
29 | """Whether to include multi-turn conversations in the dataset."""
30 | include_multiturn_conversations: bool = True
31 | """The repo from where the data is downloaded"""
32 | repo_id: str = ''
33 | 
34 | @abstractmethod
35 | def connect(
36 | self,
37 | tokenizer: Optional[Tokenizer] = None,
38 | batch_size: int = 1,
39 | max_seq_length: Optional[int] = None,
40 | ) -> None:
41 | """All settings that can't be determined at the time of instantiation need to be passed through here
42 | before any dataloaders can be accessed.
43 | """
44 | 
45 | def setup(self, stage: str = '') -> PreparedDataset:
46 | # Stub is to redefine the default signature, because the concept of 'stage' does not exist in LitGPT
47 | pass
48 | 
49 | def __repr__(self) -> str:
50 | return f'{self.__class__.__name__}()'
51 | 
52 | 
53 | def get_sft_collate_fn(
54 | max_seq_length: int = -1,
55 | pad_id: int = 0,
56 | ignore_index: int = -100,
57 | eos_id: Optional[int] = 2,
58 | ):
59 | """Returns the collate function for supervised finetuning (needed in the DataLoader).
60 | 
61 | The collate function gets a list of dicts with keys `input_ids` and `labels`.
62 | It returns a dict with batched `input_ids` and `labels`. Also pads short sequences to the longest element in
63 | the batch. Optionally truncates all sequences to the specified maximum length.
64 | Args:
65 | max_seq_length: Anything more than this will be truncated.
66 | pad_id: This defaults to 0. It is not recommended to use the tokenizer's pad_id here.
67 | We need to ignore cross-entropy loss on the pad tokens, but if the real pad_id is provided, it will not be ignored.
68 | ignore_index: Defaults to -100, PyTorch's default.
69 | eos_id: Specify if you wish to use another eos_id; otherwise it defaults to 2. This will be added to the end
70 | if `max_seq_length` is reached. Otherwise, you can provide `None` to not add one.
71 | """
72 | 
73 | if pad_id == 0:
74 | print(
75 | 'WARN: Pad token is 0, which is usually the <unk> token. It is more correct to create and use a pad token.'
76 | )
77 | if pad_id is None:
78 | pad_id = 0
79 | print("WARN: datamodules/base.py: pad_id was None, changed to 0")
80 | 
81 | return partial(
82 | _sft_collate_fn,
83 | max_seq_length=max_seq_length,
84 | pad_id=pad_id,
85 | ignore_index=ignore_index,
86 | eos_id=eos_id,
87 | )
88 | 
89 | 
90 | def _sft_collate_fn(
91 | samples: List[SFTDatasetItem],
92 | max_seq_length: int,
93 | pad_id: int,
94 | ignore_index: int,
95 | eos_id: Optional[int],
96 | ) -> SFTCollatedBatch:
97 | batched = {}
98 | for key in ('input_ids', 'labels'):
99 | # So we want to pad the inputs only.
100 | # For the labels, we don't want the model to learn to predict the pad tokens 101 | # So we just pad it with ignore_index 102 | pad_value = pad_id if key == 'input_ids' else ignore_index 103 | 104 | batched[key] = torch.nn.utils.rnn.pad_sequence( 105 | [sample[key] for sample in samples], 106 | batch_first=True, 107 | padding_value=pad_value, 108 | 109 | ) 110 | 111 | # Truncate if needed 112 | if max_seq_length > 0: 113 | batched[key] = batched[key][:, :max_seq_length] 114 | 115 | if isinstance(eos_id, int): 116 | # Replace last token with eos_id only if it's not `ignore_index` or `pad_id`: 117 | condition = (batched[key][:, -1] != ignore_index) & ( 118 | batched[key][:, -1] != pad_id 119 | ) 120 | batched[key][condition, -1] = eos_id 121 | 122 | elif eos_id is None: 123 | pass 124 | else: 125 | raise TypeError( 126 | f'`eos_id` of value[{eos_id}] and type[{type(eos_id)}] is not valid.' 127 | ) 128 | 129 | return batched 130 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/base_test.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 2 | 3 | import pytest 4 | import torch 5 | 6 | from lit_gpt.datamodules.base import get_sft_collate_fn 7 | from lit_gpt.datamodules.sft_dataset_base import SFTDataset 8 | from lit_gpt.prompts import PromptStyle 9 | 10 | 11 | @pytest.mark.parametrize('mask_prompt', [True, False]) 12 | @pytest.mark.parametrize('ignore_index', [-1, -100]) 13 | @pytest.mark.parametrize('max_seq_length', [1000]) 14 | def test_sft_dataset(max_seq_length, ignore_index, mask_prompt, MOCK_TOKENIZER_FIXTURE): 15 | class Style(PromptStyle): 16 | def apply(self, prompt, **kwargs): 17 | return f'In: {prompt} Out:' 18 | 19 | i = ignore_index 20 | data = [ 21 | {'instruction': 'Foo', 'output': 'Bar'}, 22 | {'instruction': 'Boo', 'output': 'Ahh'}, 23 | ] 24 | 25 | dataset = SFTDataset( 26 | data=data, 27 | tokenizer=MOCK_TOKENIZER_FIXTURE, 28 | prompt_style=Style(), 29 | mask_prompt=mask_prompt, 30 | ignore_index=ignore_index, 31 | max_seq_length=max_seq_length, 32 | ) 33 | assert len(dataset) == len(data) 34 | 35 | expected_input_ids = torch.tensor( 36 | [1, 73, 110, 58, 32, 70, 111, 111, 32, 79, 117, 116, 58, 66, 97, 114, 2] 37 | ) 38 | # If prompt is not masked, labels == input_ids 39 | expected_labels = ( 40 | torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 2]) 41 | if mask_prompt 42 | else expected_input_ids 43 | ) 44 | 45 | assert torch.equal(dataset[0]['input_ids'], expected_input_ids[:max_seq_length]) 46 | assert torch.equal(dataset[0]['labels'], expected_labels[:max_seq_length]) 47 | 48 | 49 | @pytest.mark.parametrize('mask_prompt', [True, False]) 50 | @pytest.mark.parametrize('ignore_index', [-1, -100]) 51 | @pytest.mark.parametrize('max_seq_length', [5]) 52 | def test_sft_dataset_hit_max_seq_length( 53 | max_seq_length, ignore_index, mask_prompt, MOCK_TOKENIZER_FIXTURE 54 | ): 55 | class Style(PromptStyle): 56 | def apply(self, prompt, **kwargs): 57 | return f'In: {prompt} Out:' 58 | 59 | i = ignore_index 60 | data = [ 61 | {'instruction': 'Foo', 'output': 'Bar'}, 62 | {'instruction': 'Boo', 'output': 'Ahh'}, 63 | ] 64 | 65 | dataset = SFTDataset( 66 | data=data, 67 | tokenizer=MOCK_TOKENIZER_FIXTURE, 68 | prompt_style=Style(), 69 | mask_prompt=mask_prompt, 70 | ignore_index=ignore_index, 71 | max_seq_length=max_seq_length, 72 | ) 73 | assert len(dataset) == len(data) 74 | 75 | eos = 
torch.tensor([MOCK_TOKENIZER_FIXTURE.eos_id]) 76 | 77 | expected_input_ids = torch.tensor( 78 | [1, 73, 110, 58, 32, 70, 111, 111, 32, 79, 117, 116, 58, 66, 97, 114, 1] 79 | ) 80 | # If prompt is not masked, labels == input_ids 81 | expected_labels = ( 82 | torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 1]) 83 | if mask_prompt 84 | else expected_input_ids 85 | ) 86 | 87 | assert torch.equal( 88 | dataset[0]['input_ids'], 89 | torch.concat((expected_input_ids[: max_seq_length - 1], eos)), 90 | ) 91 | 92 | # If the prompt is not masked, we won't see the eos token 93 | # Although we should, but this is too trival a problem to fix 94 | # And it is more likely to occur for training very low context window models, 95 | # using dataset with high tokens length or without filtering them, 96 | # i.e, context window 2k, but dataset often has prompts (not even response yet) exceeding that. 97 | if mask_prompt: 98 | assert torch.equal(dataset[0]['labels'], expected_labels[:max_seq_length]) 99 | else: 100 | assert torch.equal( 101 | dataset[0]['labels'], 102 | torch.concat((expected_labels[: max_seq_length - 1], eos)), 103 | ) 104 | 105 | 106 | @pytest.mark.parametrize('ignore_index', [-1, -100]) 107 | @pytest.mark.parametrize('pad_id', [0, 100]) 108 | def test_sft_collate_fn_padding(pad_id, ignore_index): 109 | collate = get_sft_collate_fn(pad_id=pad_id, ignore_index=ignore_index) 110 | samples = [ 111 | {'input_ids': torch.tensor([1, 2, 3]), 'labels': torch.tensor([10, 20, 30])}, 112 | { 113 | 'input_ids': torch.tensor([4, 5, 6, 7, 8]), 114 | 'labels': torch.tensor([40, 50, 60, 70, 80]), 115 | }, 116 | ] 117 | expected = { 118 | 'input_ids': torch.tensor([[1, 2, 3, pad_id, pad_id], [4, 5, 6, 7, 8]]), 119 | 'labels': torch.tensor( 120 | [[10, 20, 30, ignore_index, ignore_index], [40, 50, 60, 70, 80]] 121 | ), 122 | } 123 | batch = collate(samples) 124 | assert all(torch.equal(batch[k], expected[k]) for k in ('input_ids', 'labels')) 125 | 126 | 127 | def test_sft_collate_fn_truncation(): 128 | default_eos_token = 2 129 | collate = get_sft_collate_fn(max_seq_length=2) 130 | samples = [ 131 | {'input_ids': torch.tensor([1, 2, 3]), 'labels': torch.tensor([10, 20, 30])}, 132 | { 133 | 'input_ids': torch.tensor([4, 5, 6, 7, 8]), 134 | 'labels': torch.tensor([40, 50, 60, 70, 80]), 135 | }, 136 | ] 137 | expected = { 138 | # It should truncate and make sure the last token is the eos token 139 | 'input_ids': torch.tensor([[1, default_eos_token], [4, default_eos_token]]), 140 | 'labels': torch.tensor([[10, default_eos_token], [40, default_eos_token]]), 141 | } 142 | batch = collate(samples) 143 | assert all(torch.equal(batch[k], expected[k]) for k in ('input_ids', 'labels')) 144 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/capybara.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, TypedDict 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_multiturn_dataset_base import SFTMultiTurnDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTDict, 12 | FormattedSFTMultiturnConversation, 13 | FormattedSFTMultiturnDataset, 14 | ) 15 | 16 | HUGGINGFACE_ID = 'LDJnr/Capybara' 17 | DOWNLOAD_DIR = 
'./data/sft/capybara' 18 | SPLIT = ['train'] 19 | 20 | 21 | class PreparedCapybara(TypedDict): 22 | train_dataset: SFTMultiTurnDataset 23 | val_dataset: None 24 | test_dataset: None 25 | 26 | 27 | class CapybaraMessage(TypedDict): 28 | input: str 29 | output: str 30 | 31 | 32 | class CapybaraRow(TypedDict): 33 | source: str 34 | conversation: List[CapybaraMessage] 35 | 36 | 37 | @dataclass 38 | class Capybara(DataModule): 39 | """Capybara data module for supervised finetuning.""" 40 | 41 | """Whether to include multi-turn conversations in the dataset.""" 42 | include_multiturn_conversations: bool = True 43 | 44 | """The directory in which the downloaded dataset gets saved.""" 45 | download_dir: Path = Path(DOWNLOAD_DIR) 46 | 47 | """The repo from where the data is downloaded""" 48 | repo_id: str = HUGGINGFACE_ID 49 | 50 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 51 | batch_size: int = field(default=1, init=False, repr=False) 52 | max_seq_length: int = field(default=-1, init=False, repr=False) 53 | train_dataset: Optional[SFTMultiTurnDataset] = field( 54 | default=None, init=False, repr=False 55 | ) 56 | test_dataset: Optional[SFTMultiTurnDataset] = field( 57 | default=None, init=False, repr=False 58 | ) 59 | 60 | def __init__(self, prompt_style='chatml', num_workers: Optional[int] = None): 61 | super().__init__() 62 | self.prompt_style = PromptStyle.from_name(prompt_style) 63 | if num_workers: 64 | self.num_workers = num_workers 65 | 66 | def connect( 67 | self, 68 | tokenizer: Optional[Tokenizer] = None, 69 | batch_size: int = 1, 70 | max_seq_length: Optional[int] = None, 71 | ) -> None: 72 | self.tokenizer = tokenizer 73 | self.batch_size = batch_size 74 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 75 | 76 | def prepare_data(self) -> List[Dataset]: 77 | return load_dataset( 78 | self.repo_id, 79 | split=SPLIT, 80 | cache_dir=self.download_dir, 81 | streaming=False, 82 | ) 83 | 84 | def setup(self, stage: str = '') -> PreparedCapybara: 85 | dataset = self.prepare_data() 86 | train_data = format_dataset(dataset[0], self.include_multiturn_conversations) 87 | # TODO: Capybara doesn't have test_data. 
We can split it out from the train_data 88 | # test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 89 | 90 | self.train_dataset = SFTMultiTurnDataset( 91 | data=train_data, 92 | tokenizer=self.tokenizer, 93 | prompt_style=self.prompt_style, 94 | max_seq_length=self.max_seq_length, 95 | mask_prompt=self.mask_prompt, 96 | ignore_index=self.ignore_index, 97 | ) 98 | # self.test_dataset = SFTMultiTurnDataset( 99 | # data=test_data, 100 | # tokenizer=self.tokenizer, 101 | # prompt_style=self.prompt_style, 102 | # max_seq_length=self.max_seq_length, 103 | # mask_prompt=self.mask_prompt, 104 | # ignore_index=self.ignore_index, 105 | # ) 106 | 107 | return { 108 | 'train_dataset': self.train_dataset, 109 | 'val_dataset': self.test_dataset, 110 | 'test_dataset': self.test_dataset, 111 | } 112 | 113 | 114 | def format_dataset( 115 | dataset: List[CapybaraRow], include_multi_turn_conversations: bool 116 | ) -> FormattedSFTMultiturnDataset: 117 | formatted: FormattedSFTMultiturnDataset = [] 118 | 119 | for entry in dataset: 120 | formatted_convo: FormattedSFTMultiturnConversation = [] 121 | convo = entry['conversation'] 122 | 123 | for i in range(0, len(convo)): 124 | formatted_sft_dict: FormattedSFTDict = { 125 | 'instruction': convo[i]['input'], 126 | 'input': '', 127 | 'output': convo[i]['output'], 128 | } 129 | formatted_convo.append(formatted_sft_dict) 130 | 131 | # If don't want to include multi turn, break after first 132 | # turn is appended: - no point including latter turns as 133 | # they become orphaned discussions without starting context 134 | if not include_multi_turn_conversations: 135 | break 136 | 137 | formatted.append(formatted_convo) 138 | 139 | return formatted 140 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/capybara_test.py: -------------------------------------------------------------------------------- 1 | from json import loads 2 | from lit_gpt.datamodules.capybara import format_dataset 3 | 4 | MOCK_DATA = [] 5 | 6 | with open( 7 | './lit_gpt/datamodules/capybara_test.jsonl', 'r', encoding='utf-8' 8 | ) as jsonl_file: 9 | for line in jsonl_file: 10 | MOCK_DATA.append(loads(line)) 11 | 12 | 13 | def test_capybara_format_dataset_multiturn(): 14 | train_data = format_dataset( 15 | dataset=MOCK_DATA, include_multi_turn_conversations=True 16 | ) 17 | 18 | train_data_row_one = train_data[0] 19 | mock_data_row_one = MOCK_DATA[0]['conversation'] 20 | 21 | # Multiturn row check 22 | # Each data row has an input/output pair, hence 1 to 1 conversion 23 | assert len(mock_data_row_one) == len(train_data_row_one) 24 | 25 | # The instruction and output pair in train_data is made 26 | # from `content` from the nth and nth+1 row in `messages`` 27 | assert mock_data_row_one[0]['input'] == train_data_row_one[0]['instruction'] 28 | assert mock_data_row_one[0]['output'] == train_data_row_one[0]['output'] 29 | assert len(train_data_row_one) == len(mock_data_row_one) 30 | 31 | 32 | def test_capybara_format_dataset(): 33 | train_data = format_dataset( 34 | dataset=MOCK_DATA, include_multi_turn_conversations=False 35 | ) 36 | 37 | train_data_row_two = train_data[1] 38 | mock_data_row_two = MOCK_DATA[1]['conversation'] 39 | 40 | assert mock_data_row_two[0]['input'] == train_data_row_two[0]['instruction'] 41 | assert mock_data_row_two[0]['output'] == train_data_row_two[0]['output'] 42 | 43 | # Because we don't include multiturn, 44 | # the behaviour is that only 1 instruction/output pair is made, 45 | assert 
len(train_data_row_two) == 1 46 | # despite the dataset having more than 1 pair 47 | assert len(mock_data_row_two) > 1 48 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/create_test_jsonl.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datasets import load_dataset 3 | 4 | 5 | def create_test_jsonl(huggingface_dataset_id: str, split: str, number_of_rows=16): 6 | dataset_stream = load_dataset(huggingface_dataset_id, split=split, streaming=True) 7 | 8 | sampled_rows = [] 9 | 10 | for i, row in enumerate(dataset_stream): 11 | sampled_rows.append(row) 12 | 13 | if len(sampled_rows) == number_of_rows: 14 | break 15 | 16 | filename = huggingface_dataset_id.split('/')[1] 17 | 18 | with open(f'{filename.lower()}_test.jsonl', 'w') as outfile: 19 | for row in sampled_rows: 20 | json.dump(row, outfile) # Dump each row as a JSON object 21 | outfile.write('\n') 22 | 23 | 24 | if __name__ == '__main__': 25 | from jsonargparse import CLI 26 | 27 | CLI(create_test_jsonl, as_positional=False) 28 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/deita.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, TypedDict, Literal 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_multiturn_dataset_base import SFTMultiTurnDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTDict, 12 | FormattedSFTMultiturnConversation, 13 | FormattedSFTMultiturnDataset, 14 | ) 15 | 16 | HUGGINGFACE_ID = 'HuggingFaceH4/deita-10k-v0-sft' 17 | DOWNLOAD_DIR = './data/sft/deita' 18 | SPLIT = ['train_sft', 'test_sft'] 19 | 20 | 21 | class PreparedDeita(TypedDict): 22 | train_dataset: SFTMultiTurnDataset 23 | val_dataset: SFTMultiTurnDataset 24 | test_dataset: SFTMultiTurnDataset 25 | 26 | 27 | class DeitaMessage(TypedDict): 28 | content: str 29 | role: Literal['user', 'assistant'] 30 | 31 | 32 | class DeitaRow(TypedDict): 33 | prompt: str 34 | prompt_id: str 35 | messages: List[DeitaMessage] 36 | 37 | 38 | @dataclass 39 | class Deita(DataModule): 40 | """Deita data module for supervised finetuning.""" 41 | 42 | """Whether to include multi-turn conversations in the dataset.""" 43 | include_multiturn_conversations: bool = True 44 | 45 | """The directory in which the downloaded dataset gets saved.""" 46 | download_dir: Path = Path(DOWNLOAD_DIR) 47 | 48 | """The repo from where the data is downloaded""" 49 | repo_id: str = HUGGINGFACE_ID 50 | 51 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 52 | batch_size: int = field(default=1, init=False, repr=False) 53 | max_seq_length: int = field(default=-1, init=False, repr=False) 54 | train_dataset: Optional[SFTMultiTurnDataset] = field( 55 | default=None, init=False, repr=False 56 | ) 57 | test_dataset: Optional[SFTMultiTurnDataset] = field( 58 | default=None, init=False, repr=False 59 | ) 60 | 61 | def __init__(self, prompt_style='chatml', num_workers: Optional[int] = None): 62 | super().__init__() 63 | self.prompt_style = PromptStyle.from_name(prompt_style) 64 | if num_workers: 65 | self.num_workers = num_workers 66 | 67 | def connect( 68 | self, 69 
| tokenizer: Optional[Tokenizer] = None, 70 | batch_size: int = 1, 71 | max_seq_length: Optional[int] = None, 72 | ) -> None: 73 | self.tokenizer = tokenizer 74 | self.batch_size = batch_size 75 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 76 | 77 | def prepare_data(self) -> List[Dataset]: 78 | return load_dataset( 79 | self.repo_id, 80 | split=SPLIT, 81 | cache_dir=self.download_dir, 82 | streaming=False, 83 | ) 84 | 85 | def setup(self, stage: str = '') -> PreparedDeita: 86 | dataset = self.prepare_data() 87 | train_data = format_dataset(dataset[0], self.include_multiturn_conversations) 88 | test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 89 | 90 | self.train_dataset = SFTMultiTurnDataset( 91 | data=train_data, 92 | tokenizer=self.tokenizer, 93 | prompt_style=self.prompt_style, 94 | max_seq_length=self.max_seq_length, 95 | mask_prompt=self.mask_prompt, 96 | ignore_index=self.ignore_index, 97 | ) 98 | self.test_dataset = SFTMultiTurnDataset( 99 | data=test_data, 100 | tokenizer=self.tokenizer, 101 | prompt_style=self.prompt_style, 102 | max_seq_length=self.max_seq_length, 103 | mask_prompt=self.mask_prompt, 104 | ignore_index=self.ignore_index, 105 | ) 106 | 107 | return { 108 | 'train_dataset': self.train_dataset, 109 | 'val_dataset': self.test_dataset, 110 | 'test_dataset': self.test_dataset, 111 | } 112 | 113 | 114 | def format_dataset( 115 | dataset: List[DeitaRow], include_multi_turn_conversations: bool 116 | ) -> FormattedSFTMultiturnDataset: 117 | formatted: FormattedSFTMultiturnDataset = [] 118 | 119 | for entry in dataset: 120 | formatted_convo: FormattedSFTMultiturnConversation = [] 121 | convo = entry['messages'] 122 | 123 | # Each conversation is a flat list of user-assistant pairs. 124 | # So we iterate in 2-step manner 125 | for i in range(0, len(convo) - 1, 2): 126 | if convo[i]['role'] != 'user': 127 | print( 128 | f'WARN: Deita row with prompt_id[{entry["prompt_id"]}] is corrupted. Expected role to be `user`, but is `{convo[i]["role"]}` instead.' 129 | ) 130 | if convo[i + 1]['role'] != 'assistant': 131 | print( 132 | f'WARN: Deita row with prompt_id[{entry["prompt_id"]}] is corrupted. Expected role to be `assistant`, but is `{convo[i+1]["role"]}` instead.' 
133 | ) 134 | 135 | formatted_sft_dict: FormattedSFTDict = { 136 | 'instruction': convo[i]['content'], 137 | 'input': '', 138 | 'output': convo[i + 1]['content'], 139 | } 140 | 141 | formatted_convo.append(formatted_sft_dict) 142 | 143 | # If don't want to include multi turn, break after first 144 | # turn is appended: - no point including latter turns as 145 | # they become orphaned discussions without starting context 146 | if not include_multi_turn_conversations: 147 | break 148 | 149 | formatted.append(formatted_convo) 150 | 151 | return formatted 152 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/deita_test.py: -------------------------------------------------------------------------------- 1 | from json import loads 2 | from lit_gpt.datamodules.deita import format_dataset 3 | 4 | MOCK_DATA = [] 5 | 6 | with open( 7 | './lit_gpt/datamodules/deita_test.jsonl', 'r', encoding='utf-8' 8 | ) as jsonl_file: 9 | for line in jsonl_file: 10 | MOCK_DATA.append(loads(line)) 11 | 12 | 13 | def test_deita_format_dataset_multiturn(): 14 | train_data = format_dataset( 15 | dataset=MOCK_DATA, include_multi_turn_conversations=True 16 | ) 17 | 18 | train_data_row_one = train_data[0] 19 | mock_data_row_one = MOCK_DATA[0]['messages'] 20 | 21 | # Multiturn row check 22 | # It comes in a flat list of user-assistant pair, that is turned into 1 train data row 23 | # Hence divde by 2 24 | assert len(mock_data_row_one) / 2 == len(train_data_row_one) 25 | 26 | # The instruction and output pair in train_data is made 27 | # from `content` from the nth and nth+1 row in `messages`` 28 | assert train_data_row_one[0]['instruction'] == mock_data_row_one[0]['content'] 29 | assert train_data_row_one[0]['output'] == mock_data_row_one[1]['content'] 30 | 31 | # Because 2 rows of `content` from the data is condensed into 1 row of train_data: 32 | assert len(train_data_row_one) == len(mock_data_row_one) / 2 33 | 34 | 35 | def test_deita_format_dataset(): 36 | train_data = format_dataset( 37 | dataset=MOCK_DATA, include_multi_turn_conversations=False 38 | ) 39 | 40 | train_data_row_two = train_data[1] 41 | mock_data_row_two = MOCK_DATA[1]['messages'] 42 | 43 | assert train_data_row_two[0]['instruction'] == mock_data_row_two[0]['content'] 44 | assert train_data_row_two[0]['output'] == mock_data_row_two[1]['content'] 45 | 46 | # Because we don't include multiturn, 47 | # the behaviour is that only 1 instruction/output pair is made, 48 | assert len(train_data_row_two) == 1 49 | # despite the dataset having more than 1 pair 50 | assert len(mock_data_row_two) > 2 51 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/llama_instruct.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, TypedDict 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_multiturn_dataset_base import SFTMultiTurnDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTDict, 12 | FormattedSFTMultiturnConversation, 13 | FormattedSFTMultiturnDataset, 14 | ) 15 | 16 | HUGGINGFACE_ID = 'togethercomputer/llama-instruct' 17 | DOWNLOAD_DIR = './data/sft/llama-instruct' 18 | SPLIT = ['train'] 19 | 20 | class 
PreparedLlamaInstruct(TypedDict): 21 | train_dataset: SFTMultiTurnDataset 22 | val_dataset: None 23 | test_dataset: None 24 | 25 | 26 | class LlamaInstructRow(TypedDict): 27 | text: str 28 | 29 | 30 | @dataclass 31 | class LlamaInstruct(DataModule): 32 | """LlamaInstruct data module for supervised finetuning.""" 33 | 34 | """Whether to include multi-turn conversations in the dataset.""" 35 | include_multiturn_conversations: bool = False 36 | 37 | """The directory in which the downloaded dataset gets saved.""" 38 | download_dir: Path = Path(DOWNLOAD_DIR) 39 | 40 | """The repo from where the data is downloaded""" 41 | repo_id: str = HUGGINGFACE_ID 42 | 43 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 44 | # `batch_size` has no use anymore, as the dataloaders are moved outside 45 | # in order to support training with multiple SFT datasets at one go. 46 | batch_size: int = field(default=1, init=False, repr=False) 47 | max_seq_length: int = field(default=-1, init=False, repr=False) 48 | train_dataset: Optional[SFTMultiTurnDataset] = field(default=None, init=False, repr=False) 49 | test_dataset: Optional[SFTMultiTurnDataset] = field(default=None, init=False, repr=False) 50 | 51 | def __init__(self, prompt_style='chatml', num_workers: Optional[int] = None): 52 | super().__init__() 53 | self.prompt_style = PromptStyle.from_name(prompt_style) 54 | if num_workers: 55 | self.num_workers = num_workers 56 | 57 | def connect( 58 | self, 59 | tokenizer: Optional[Tokenizer] = None, 60 | batch_size: int = 1, 61 | max_seq_length: Optional[int] = None, 62 | ) -> None: 63 | self.tokenizer = tokenizer 64 | self.batch_size = batch_size 65 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 66 | 67 | def prepare_data(self) -> List[Dataset]: 68 | return load_dataset( 69 | self.repo_id, 70 | split=SPLIT, 71 | cache_dir=self.download_dir, 72 | streaming=False, 73 | ) 74 | 75 | def setup(self, stage: str = '') -> PreparedLlamaInstruct: 76 | dataset = self.prepare_data() 77 | 78 | train_data = format_dataset(dataset[0], self.include_multiturn_conversations) 79 | # TODO: WizardLMEvolInstructV2 doesn't have test_data. We can split it out from the train_data 80 | # test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 81 | 82 | self.train_dataset = SFTMultiTurnDataset( 83 | data=train_data, 84 | tokenizer=self.tokenizer, 85 | prompt_style=self.prompt_style, 86 | max_seq_length=self.max_seq_length, 87 | mask_prompt=self.mask_prompt, 88 | ignore_index=self.ignore_index, 89 | ) 90 | 91 | return { 92 | 'train_dataset': self.train_dataset, 93 | 'val_dataset': self.test_dataset, 94 | 'test_dataset': self.test_dataset, 95 | } 96 | 97 | 98 | def format_dataset( 99 | dataset: List[LlamaInstructRow], 100 | include_multi_turn_conversations=bool, 101 | ) -> FormattedSFTMultiturnDataset: 102 | formatted: FormattedSFTMultiturnDataset = [] 103 | 104 | for entry in dataset: 105 | conversation = entry["text"] 106 | 107 | # It came with [INST] format. Turn all of it into a special token and split one time. 
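# For illustration (hypothetical record): a `text` field such as
#   '[INST] Hi there [/INST] Hello! [INST] And thanks [/INST] You are welcome.'
# becomes ['Hi there', 'Hello!', 'And thanks', 'You are welcome.'] after the
# replace/split below plus the whitespace-stripping loop, i.e. alternating
# user/assistant turns that can then be paired up two at a time.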
108 | split_turns = conversation.replace('[INST]', '<||SPLIT||>').replace('[/INST]', '<||SPLIT||>').split('<||SPLIT||>') 109 | 110 | # Strip the content of any spaces, turn into a flat list 111 | # `cleaned_split_turns` will contain pairs of alternating user and assistant content 112 | cleaned_split_turns: List[str] = [] 113 | for turn in split_turns: 114 | turn = turn.strip() 115 | if len(turn) < 1: 116 | continue 117 | 118 | cleaned_split_turns.append(turn) 119 | 120 | formatted_convo: FormattedSFTMultiturnConversation = [] 121 | for i in range(0, len(cleaned_split_turns) - 1, 2): 122 | formatted_sft_dict: FormattedSFTDict = { 123 | 'instruction': cleaned_split_turns[i], 124 | 'input': '', 125 | 'output': cleaned_split_turns[i+1] 126 | } 127 | formatted_convo.append(formatted_sft_dict) 128 | 129 | # If don't want to include multi turn, break after first 130 | # turn is appended: - no point including latter turns as 131 | # they become orphaned discussions without starting context 132 | if not include_multi_turn_conversations: 133 | break 134 | 135 | 136 | formatted.append(formatted_convo) 137 | 138 | return formatted 139 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/llama_instruct_test.py: -------------------------------------------------------------------------------- 1 | from json import loads 2 | from lit_gpt.datamodules.llama_instruct import format_dataset 3 | 4 | MOCK_DATA = [] 5 | 6 | with open( 7 | './lit_gpt/datamodules/llama_instruct_test.jsonl', 'r', encoding='utf-8' 8 | ) as jsonl_file: 9 | for line in jsonl_file: 10 | MOCK_DATA.append(loads(line)) 11 | 12 | 13 | def test_format_dataset_multiturn(): 14 | train_data = format_dataset( 15 | dataset=MOCK_DATA, include_multi_turn_conversations=True 16 | ) 17 | 18 | train_data_row_one = train_data[0] 19 | # Hard code cause if we do the same string manipulation process, it's just not testing the code. 20 | mock_data_row_one = { 21 | 'instruction': '''root@openvpn:/home/openvpn# ./openvpn-install.sh\nWelcome to OpenVPN-install!\nThe git repository is available at: https://github.com/angristan/openvpn-install\n\nIt looks like OpenVPN is already installed.\n\nWhat do you want to do?\n 1) Add a new user\n 2) Revoke existing user\n 3) Remove OpenVPN\n 4) Exit\nSelect an option [1-4]: 1\n\nTell me a name for the client.\nThe name must consist of alphanumeric character. It may also include an underscore or a dash.\nClient name: naam\n\nDo you want to protect the configuration file with a password?\n(e.g. 
encrypt the private key with a password)\n 1) Add a passwordless client\n 2) Use a password for the client\nSelect an option [1-2]: 1\n\nNote: using Easy-RSA configuration from: /etc/openvpn/easy-rsa/vars\nUsing SSL: openssl OpenSSL 3.0.2 15 Mar 2022 (Library: OpenSSL 3.0.2 15 Mar 2022)\n-----\nUsing configuration from /etc/openvpn/easy-rsa/pki/easy-rsa-54848.BT2FXv/tmp.dFLd6V\nEnter pass phrase for /etc/openvpn/easy-rsa/pki/private/ca.key:\nCheck that the request matches the signature\nSignature ok\nThe Subject's Distinguished Name is as follows\ncommonName :ASN.1 12:'naam'\nCertificate is to be certified until Apr 21 22:48:57 2025 GMT (825 days)\n\nWrite out database with 1 new entries\nData Base Updated\n\nClient naam added.\n\nThe configuration file has been written to /home/openvpn/naam.ovpn.\nDownload the .ovpn file and import it in your OpenVPN client.\nroot@openvpn:/home/openvpn# \n\n\nkan je op basis van deze gegevens een verslag maken hoe je iemand toevoegd op een openvpn server''', 22 | 'output': '''OpenVPN-installer shell script has been used to add a new user to an existing OpenVPN server. The script can be found on GitHub.\n\nThe user has been prompted to enter a name for the client, which must consist of alphanumeric characters and may include an underscore or a dash. The name \"naam\" has been entered.\n\nThe user has been asked whether to protect the configuration file with a password. The user has chosen not to add a password for the client.\n\nThe script has used the Easy-RSA configuration from /etc/openvpn/easy-rsa/vars and has generated a private key for the client using OpenSSL. The user has been prompted to enter a pass phrase for the private key.\n\nThe script has then generated a certificate for the client using the entered name and the private key. The certificate is valid until April 21, 2025.\n\nThe script has written the new entry to the database and has updated the configuration file. The client's configuration file has been written to /home/openvpn/naam.ovpn, and the user has been instructed to download the file and import it into their OpenVPN client.\n\nIn conclusion, the OpenVPN-installer script has been used to successfully add a new user to an existing OpenVPN server. The user has been prompted to enter a name for the client and has been given the option to protect the configuration file with a password. The script has generated a private key and certificate for the client and has updated the configuration file. 
The user can now download the configuration file and import it into their OpenVPN client.''' 23 | } 24 | 25 | # The instruction and output pair in train_data is made 26 | # from `content` from the nth and nth+1 row in `messages`` 27 | assert train_data_row_one[0]['instruction'] == mock_data_row_one['instruction'] 28 | assert train_data_row_one[0]['output'] == mock_data_row_one['output'] 29 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/meta_math_qa.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, Union, TypedDict 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_dataset_base import SFTDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTSingleturnConversation, 12 | FormattedSFTSingleturnDataset, 13 | ) 14 | 15 | HUGGINGFACE_ID = 'meta-math/MetaMathQA' 16 | DOWNLOAD_DIR = './data/sft/meta-math-qa' 17 | SPLIT = ['train'] 18 | 19 | 20 | class PreparedMetaMathQA(TypedDict): 21 | train_dataset: SFTDataset 22 | val_dataset: None 23 | test_dataset: None 24 | 25 | 26 | class MetaMathQARow(TypedDict): 27 | type: str 28 | query: str 29 | original_question: str 30 | response: str 31 | 32 | 33 | @dataclass 34 | class MetaMathQA(DataModule): 35 | """MetaMathQA data module for supervised finetuning.""" 36 | 37 | """Whether to include multi-turn conversations in the dataset.""" 38 | include_multiturn_conversations: bool = False 39 | 40 | """The directory in which the downloaded dataset gets saved.""" 41 | download_dir: Path = Path(DOWNLOAD_DIR) 42 | 43 | """The repo from where the data is downloaded""" 44 | repo_id: str = HUGGINGFACE_ID 45 | 46 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 47 | batch_size: int = field(default=1, init=False, repr=False) 48 | max_seq_length: int = field(default=-1, init=False, repr=False) 49 | train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) 50 | test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) 51 | 52 | def __init__(self, prompt_style='chatml', num_workers: Optional[int] = None): 53 | super().__init__() 54 | self.prompt_style = PromptStyle.from_name(prompt_style) 55 | if num_workers: 56 | self.num_workers = num_workers 57 | 58 | def connect( 59 | self, 60 | tokenizer: Optional[Tokenizer] = None, 61 | batch_size: int = 1, 62 | max_seq_length: Optional[int] = None, 63 | ) -> None: 64 | self.tokenizer = tokenizer 65 | self.batch_size = batch_size 66 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 67 | 68 | def prepare_data(self) -> List[Dataset]: 69 | return load_dataset( 70 | self.repo_id, 71 | split=SPLIT, 72 | cache_dir=self.download_dir, 73 | streaming=False, 74 | ) 75 | 76 | def setup(self, stage: str = '') -> PreparedMetaMathQA: 77 | dataset = self.prepare_data() 78 | 79 | train_data = format_dataset(dataset[0], self.include_multiturn_conversations) 80 | # TODO: MetaMathQA doesn't have test_data. 
We can split it out from the train_data 81 | # test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 82 | 83 | self.train_dataset = SFTDataset( 84 | data=train_data, 85 | tokenizer=self.tokenizer, 86 | prompt_style=self.prompt_style, 87 | max_seq_length=self.max_seq_length, 88 | mask_prompt=self.mask_prompt, 89 | ignore_index=self.ignore_index, 90 | ) 91 | # self.test_dataset = SFTDataset( 92 | # data=test_data, 93 | # tokenizer=self.tokenizer, 94 | # prompt_style=self.prompt_style, 95 | # max_seq_length=self.max_seq_length, 96 | # mask_prompt=self.mask_prompt, 97 | # ignore_index=self.ignore_index, 98 | # ) 99 | 100 | return { 101 | 'train_dataset': self.train_dataset, 102 | 'val_dataset': self.test_dataset, 103 | 'test_dataset': self.test_dataset, 104 | } 105 | 106 | 107 | def format_dataset( 108 | dataset: List[MetaMathQARow], 109 | # `include_multi_turn_conversations` kept for backward compatibility with litgpt 110 | include_multi_turn_conversations=False, 111 | ) -> FormattedSFTSingleturnDataset: 112 | formatted: FormattedSFTSingleturnDataset = [] 113 | 114 | for entry in dataset: 115 | formatted_sft_dict: FormattedSFTSingleturnConversation = { 116 | 'instruction': entry['query'], 117 | 'input': '', 118 | 'output': entry['response'], 119 | } 120 | formatted.append(formatted_sft_dict) 121 | 122 | return formatted 123 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/meta_math_qa_e2e_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from datasets import load_dataset 3 | from torch.utils.data import DataLoader 4 | 5 | from lit_gpt.datamodules.meta_math_qa import MetaMathQA 6 | from lit_gpt.utils import CycleIterator 7 | from lit_gpt.tokenizer import Tokenizer 8 | from lit_gpt.datamodules.base import get_sft_collate_fn 9 | 10 | tokenizer = Tokenizer(Path('tokenizer/pints')) 11 | 12 | 13 | def test_metamath_e2e(): 14 | batch_size = 2 15 | 16 | metamath = MetaMathQA(prompt_style='chatml') 17 | metamath.tokenizer = tokenizer 18 | metamath.include_multiturn_conversations = True 19 | 20 | # overwrite the prepare_data method to load custom 21 | def mock_prepare_data(): 22 | # We do it like this to make it return a (train, test) tuple that Deita has. 
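# (Note: `split=['train']` is a list, so `load_dataset` returns a list of `Dataset` objects rather than a
# single `Dataset`; that is why `[0]` is indexed when the (train, test) tuple is built below.)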
23 | mock_train = load_dataset( 24 | 'json', 25 | # Even though jsonl file doesn't have a split, `train` is needed by default 26 | split=['train'], 27 | data_files='lit_gpt/datamodules/meta_math_qa_test.jsonl', 28 | ) 29 | 30 | mock_test = load_dataset( 31 | 'json', 32 | # Even though jsonl file doesn't have a split, `train` is needed by default 33 | split=['train'], 34 | data_files='lit_gpt/datamodules/meta_math_qa_test.jsonl', 35 | ) 36 | return (mock_train[0], mock_test[0]) 37 | 38 | metamath.prepare_data = mock_prepare_data 39 | 40 | data = metamath.setup() 41 | 42 | train_dataloader = DataLoader( 43 | data['train_dataset'], 44 | batch_size=batch_size, 45 | collate_fn=get_sft_collate_fn(2048), 46 | ) 47 | 48 | deita_iterator = CycleIterator(train_dataloader) 49 | 50 | iter_batch = next(deita_iterator) 51 | 52 | # Check that we plucked out according to batch_size 53 | assert len(iter_batch['input_ids']) == batch_size 54 | assert len(iter_batch['labels']) == batch_size 55 | 56 | inputs_ids = iter_batch['input_ids'][1] 57 | inputs_ids[inputs_ids == -100] = 0 58 | 59 | decoded_input = tokenizer.decode(inputs_ids, skip_special_tokens=False) 60 | 61 | expected_input = """<|im_start|> user 62 | What is the total cost of purchasing equipment for all sixteen players on the football team, considering that each player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80?<|im_end|> 63 | <|im_start|> assistant 64 | Each player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80. 65 | So the total cost for each player is $25 + $15.20 + $6.80 = $47. 66 | Since there are sixteen players on the football team, the total cost for all of them is 16 * $47 = $752. 67 | #### 752 68 | The answer is: 752<|im_end|>""" 69 | 70 | labels = iter_batch['labels'][1] 71 | labels[labels == -100] = 0 72 | decoded_labels = tokenizer.decode( 73 | iter_batch['labels'][1], skip_special_tokens=False 74 | ) 75 | 76 | expected_labels = """Each player requires a $25 jersey, a $15.20 pair of shorts, and a pair of socks priced at $6.80. 77 | So the total cost for each player is $25 + $15.20 + $6.80 = $47. 78 | Since there are sixteen players on the football team, the total cost for all of them is 16 * $47 = $752. 
79 | #### 752 80 | The answer is: 752<|im_end|>""" 81 | 82 | print('=' * 80) 83 | print(decoded_input) 84 | print('=' * 80) 85 | print(decoded_labels) 86 | print('=' * 80) 87 | 88 | assert expected_input == decoded_input 89 | assert expected_labels == decoded_labels 90 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/meta_math_qa_test.py: -------------------------------------------------------------------------------- 1 | from json import loads 2 | from lit_gpt.datamodules.meta_math_qa import format_dataset 3 | 4 | MOCK_DATA = [] 5 | 6 | with open( 7 | './lit_gpt/datamodules/meta_math_qa_test.jsonl', 'r', encoding='utf-8' 8 | ) as jsonl_file: 9 | for line in jsonl_file: 10 | MOCK_DATA.append(loads(line)) 11 | 12 | 13 | def test_metamath_format_dataset(): 14 | train_data = format_dataset(dataset=MOCK_DATA) 15 | 16 | train_data_row_one = train_data[0] 17 | mock_data_row_one = MOCK_DATA[0] 18 | 19 | # The instruction and output pair in train_data is made 20 | # from `content` from the nth and nth+1 row in `messages`` 21 | assert mock_data_row_one['query'] == train_data_row_one['instruction'] 22 | assert mock_data_row_one['response'] == train_data_row_one['output'] 23 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/sft_dataset_base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union, Optional, Callable, Any 2 | from torch import int64 3 | from torch.utils.data import Dataset 4 | from lit_gpt.prompts import PromptStyle 5 | from lit_gpt.tokenizer import Tokenizer 6 | from lit_gpt.datamodules.typings.base import SFTDatasetItem 7 | 8 | class SFTDataset(Dataset): 9 | """An in-memory dataset for supervised finetuning with `input_ids` and `labels`. 10 | 11 | Args: 12 | data: A list of samples (dicts). The target/label must be stored under the key 'output' and the instruction 13 | or other data can be stored under any key as long as it is compatible with the given prompt template. 14 | tokenizer: The tokenizer to use. Should match the one that was used to pretrain the model. 15 | prompt_style: The style to apply to prompts. See `litgpt.prompts` for a list of available styles. 16 | max_seq_length: Truncate sequences that are longer than this value. By default, no truncation is applied. 17 | mask_prompt: Whether to mask the prompt section from the label (with ``ignore_index``). 18 | ignore_index: The index to use for elements to be ignored in the label. 19 | transform: An optional transform to apply to the sample before it gets tokenized. Use this to rename the 20 | keys in the dataset to the expected 'instruction' and 'output' keys. 21 | 22 | Returns a dict with two keys: 23 | input_ids: The encoded prompt + response 24 | labels: Same as input_ids, unless ``mask_prompt=True`` in which case the 'prompt' part is replaced with 25 | the ``ignore_index``. 
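    Illustrative example (hypothetical token ids, not from a real tokenizer): if the prompt encodes to
    [1, 11, 12] and prompt + response encodes to [1, 11, 12, 21, 22, 2], then with ``mask_prompt=True`` and
    ``ignore_index=-100`` the returned item is ``input_ids=[1, 11, 12, 21, 22, 2]`` and
    ``labels=[-100, -100, -100, 21, 22, 2]``.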
26 | """ 27 | 28 | def __init__( 29 | self, 30 | data: List[Dict[str, str]], 31 | tokenizer: Tokenizer, 32 | prompt_style: Union[str, PromptStyle], 33 | max_seq_length: int = -1, 34 | mask_prompt: bool = True, 35 | ignore_index: int = -100, 36 | transform: Optional[Callable[[Any], Any]] = None, 37 | ) -> None: 38 | self.data = data 39 | self.tokenizer = tokenizer 40 | self.prompt_style = ( 41 | prompt_style 42 | if isinstance(prompt_style, PromptStyle) 43 | else PromptStyle.from_name(prompt_style) 44 | ) 45 | self.max_seq_length = max_seq_length 46 | self.mask_prompt = mask_prompt 47 | self.ignore_index = ignore_index 48 | self.transform = transform 49 | 50 | def __len__(self) -> int: 51 | return len(self.data) 52 | 53 | def __getitem__(self, idx: int) -> SFTDatasetItem: 54 | example = self.data[idx] 55 | if self.transform is not None: 56 | example = self.transform(example) 57 | prompt = self.prompt_style.apply(prompt=example['instruction'], **example) 58 | prompt_and_response = prompt + example['output'] 59 | encoded_prompt = self.tokenizer.encode( 60 | prompt, bos=True, max_length=self.max_seq_length 61 | ) 62 | encoded_prompt_and_response = self.tokenizer.encode( 63 | prompt_and_response, bos=True, eos=True, max_length=self.max_seq_length 64 | ) 65 | 66 | # The labels are the full prompt with response, but with the prompt masked out 67 | labels = encoded_prompt_and_response.clone() 68 | if self.mask_prompt: 69 | labels[: len(encoded_prompt)] = self.ignore_index 70 | 71 | return { 72 | 'input_ids': encoded_prompt_and_response.type(int64), 73 | 'labels': labels.type(int64), 74 | } -------------------------------------------------------------------------------- /lit_gpt/datamodules/sft_multiturn_dataset_base_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from typing import List 3 | from torch import tensor, concat, equal 4 | from lit_gpt.datamodules.sft_multiturn_dataset_base import ( 5 | SFTMultiTurnDataset, 6 | MultiTurnDataRow, 7 | ) 8 | 9 | mocked_data: List[List[MultiTurnDataRow]] = [ 10 | [ 11 | {'instruction': 'Foo', 'input': '', 'output': 'Bar'}, 12 | {'instruction': 'Foo2', 'input': '', 'output': 'Bar2'}, 13 | ], 14 | [ 15 | { 16 | 'instruction': 'Foo3', 17 | 'input': '', 18 | 'output': 'Bar3', 19 | }, 20 | ], 21 | ] 22 | 23 | mocked_data_tokenized = concat( 24 | ( 25 | tensor([1, 60, 124, 105, 109, 95, 115, 116, 97, 114, 116, 124, 62, 117]), 26 | tensor([115, 101, 114, 10, 70, 111, 111, 60, 124, 105, 109, 95, 101, 110]), 27 | tensor([100, 124, 62, 10, 60, 124, 105, 109, 95, 115, 116, 97, 114, 116]), 28 | tensor([124, 62, 97, 115, 115, 105, 115, 116, 97, 110, 116, 10, 66, 97]), 29 | tensor([114, 2, 10, 60, 124, 105, 109, 95, 115, 116, 97, 114, 116, 124]), 30 | tensor([62, 117, 115, 101, 114, 10, 70, 111, 111, 50, 60, 124, 105, 109]), 31 | tensor([95, 101, 110, 100, 124, 62, 10, 60, 124, 105, 109, 95, 115, 116]), 32 | tensor([97, 114, 116, 124, 62, 97, 115, 115, 105, 115, 116, 97, 110, 116]), 33 | tensor([10, 66, 97, 114, 50, 2]), 34 | ) 35 | ) 36 | 37 | 38 | def mocked_data_labels(ignore_index=-100): 39 | i = ignore_index 40 | return concat( 41 | ( 42 | tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 43 | tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 44 | tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 45 | tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 46 | tensor([i, i, i, i, i, i, 66, 97, 114, 2, i, i]), 47 | tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 48 | tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 49 | 
tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 50 | tensor([i, i, i, i, i, i, i, i, i, i, i, i]), 51 | tensor([i, i, i, i, i, 66, 97, 114, 50, 2]), 52 | ) 53 | ) 54 | 55 | 56 | @pytest.mark.parametrize('mask_prompt', [True, False]) 57 | @pytest.mark.parametrize('ignore_index', [-1, -100]) 58 | @pytest.mark.parametrize('max_seq_length', [1000]) 59 | def test_getitem(max_seq_length, ignore_index, mask_prompt, MOCK_TOKENIZER_FIXTURE): 60 | dataset = SFTMultiTurnDataset( 61 | data=mocked_data, 62 | tokenizer=MOCK_TOKENIZER_FIXTURE, 63 | prompt_style='chatml', 64 | mask_prompt=mask_prompt, 65 | ignore_index=ignore_index, 66 | max_seq_length=max_seq_length, 67 | ) 68 | 69 | assert len(dataset) == len(mocked_data) 70 | 71 | expected_labels = ( 72 | mocked_data_labels(ignore_index) if mask_prompt else mocked_data_tokenized 73 | ) 74 | 75 | first_set = dataset[0] 76 | 77 | assert equal(first_set['input_ids'], mocked_data_tokenized[:max_seq_length]) 78 | assert equal(first_set['labels'], expected_labels) 79 | 80 | 81 | @pytest.mark.parametrize('mask_prompt', [True, False]) 82 | @pytest.mark.parametrize('ignore_index', [-1, -100]) 83 | @pytest.mark.parametrize('max_seq_length', [5]) 84 | def test_getitem_hit_max_seq_length( 85 | max_seq_length, ignore_index, mask_prompt, MOCK_TOKENIZER_FIXTURE 86 | ): 87 | dataset = SFTMultiTurnDataset( 88 | data=mocked_data, 89 | tokenizer=MOCK_TOKENIZER_FIXTURE, 90 | prompt_style='chatml', 91 | mask_prompt=mask_prompt, 92 | ignore_index=ignore_index, 93 | max_seq_length=max_seq_length, 94 | ) 95 | 96 | assert len(dataset) == len(mocked_data) 97 | 98 | eos = tensor([MOCK_TOKENIZER_FIXTURE.eos_id]) 99 | 100 | expected_labels = ( 101 | mocked_data_labels(ignore_index) if mask_prompt else mocked_data_tokenized 102 | ) 103 | 104 | first_set = dataset[0] 105 | 106 | assert equal( 107 | first_set['input_ids'], 108 | concat((mocked_data_tokenized[: max_seq_length - 1], eos)), 109 | ) 110 | 111 | # If the prompt is not masked, we won't see the eos token 112 | # Although we should, but this is too trival a problem to fix 113 | # And it is more likely to occur for training very low context window models, 114 | # using dataset with high tokens length or without filtering them, 115 | # i.e, context window 2k, but dataset often has prompts (not even response yet) exceeding that. 116 | if mask_prompt: 117 | assert equal(first_set['labels'], expected_labels[:max_seq_length]) 118 | else: 119 | assert equal( 120 | first_set['labels'], 121 | concat((expected_labels[: max_seq_length - 1], eos)), 122 | ) 123 | 124 | 125 | def test_getitem_check_labels(MOCK_TOKENIZER_FIXTURE): 126 | dataset = SFTMultiTurnDataset( 127 | data=mocked_data, 128 | tokenizer=MOCK_TOKENIZER_FIXTURE, 129 | prompt_style='chatml', 130 | ) 131 | 132 | labels_from_first_row = dataset[0]['labels'] 133 | labels_from_first_row_decoded: str = MOCK_TOKENIZER_FIXTURE.decode( 134 | labels_from_first_row 135 | ) 136 | 137 | print(labels_from_first_row_decoded.replace('', '')) 138 | # Removing away represents tokens that the model is trained on 139 | # Again, this is different from attention. All tokens in the sequence are attended to. 
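# (No `mask_prompt` argument is passed here, so the default masking applies: only the assistant response
# tokens for 'Bar' and 'Bar2', plus the eos token, survive in `labels`; every other position holds `ignore_index`.)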
140 | # Therefore, we should only have a concatenated string of responses: 141 | assert ( 142 | labels_from_first_row_decoded.replace('', '') 143 | == 'BarBar2' # is correct and not <|im_end|> because tokenizer is mocked 144 | ) 145 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/slim-orca-idontknow_test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/lit_gpt/datamodules/slim-orca-idontknow_test.parquet -------------------------------------------------------------------------------- /lit_gpt/datamodules/slim_orca_dedup.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, Union, TypedDict, Literal 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_dataset_base import SFTDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTSingleturnConversation, 12 | FormattedSFTSingleturnDataset, 13 | ) 14 | 15 | HUGGINGFACE_ID = 'Open-Orca/SlimOrca-Dedup' 16 | DOWNLOAD_DIR = './data/sft/slimorca-dedup' 17 | SPLIT = ['train'] 18 | 19 | class PreparedSlimOrcaDedup(TypedDict): 20 | train_dataset: SFTDataset 21 | val_dataset: None 22 | test_dataset: None 23 | 24 | 25 | 26 | # Use declarative syntax as 'from' will not be allowed using class syntax 27 | OpenOrcaConversationTurn = TypedDict( 28 | 'OpenOrcaConversation', { 29 | 'from': Literal['system', 'human', 'gpt'], 'value': str} 30 | ) 31 | 32 | 33 | class SlimOrcaDedupRow(TypedDict): 34 | conversations: List[OpenOrcaConversationTurn] 35 | 36 | 37 | @dataclass 38 | class SlimOrcaDedup(DataModule): 39 | """SlimOrcaDedup data module for supervised finetuning.""" 40 | 41 | """Whether to include multi-turn conversations in the dataset.""" 42 | include_multiturn_conversations: bool = False 43 | 44 | """The directory in which the downloaded dataset gets saved.""" 45 | download_dir: Path = Path(DOWNLOAD_DIR) 46 | 47 | """The repo from where the data is downloaded""" 48 | repo_id: str = HUGGINGFACE_ID 49 | 50 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 51 | batch_size: int = field(default=1, init=False, repr=False) 52 | max_seq_length: int = field(default=-1, init=False, repr=False) 53 | train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) 54 | test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) 55 | 56 | def __init__(self, prompt_style='chatml', num_workers: Optional[int] = None): 57 | super().__init__() 58 | self.prompt_style = PromptStyle.from_name(prompt_style) 59 | if num_workers: 60 | self.num_workers = num_workers 61 | 62 | def connect( 63 | self, 64 | tokenizer: Optional[Tokenizer] = None, 65 | batch_size: int = 1, 66 | max_seq_length: Optional[int] = None, 67 | ) -> None: 68 | self.tokenizer = tokenizer 69 | self.batch_size = batch_size 70 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 71 | 72 | def prepare_data(self) -> List[Dataset]: 73 | return load_dataset( 74 | self.repo_id, 75 | split=SPLIT, 76 | streaming=False, 77 | ) 78 | 79 | def setup(self, stage: str = '') -> PreparedSlimOrcaDedup: 80 | 
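# Only a 'train' split exists upstream (SPLIT = ['train']), so everything below is built from dataset[0];
# the commented-out `test_data` line would only apply once a held-out split is carved out of the train data.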
dataset = self.prepare_data() 81 | train_data = format_dataset( 82 | dataset[0], self.include_multiturn_conversations) 83 | # TODO: SlimOrcaDedup doesn't have test_data. We can split it out from the train_data 84 | # test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 85 | 86 | self.train_dataset = SFTDataset( 87 | data=train_data, 88 | tokenizer=self.tokenizer, 89 | prompt_style=self.prompt_style, 90 | max_seq_length=self.max_seq_length, 91 | mask_prompt=self.mask_prompt, 92 | ignore_index=self.ignore_index, 93 | ) 94 | 95 | return { 96 | 'train_dataset': self.train_dataset, 97 | 'val_dataset': self.test_dataset, 98 | 'test_dataset': self.test_dataset, 99 | } 100 | 101 | 102 | def format_dataset( 103 | dataset: List[SlimOrcaDedupRow], 104 | # `include_multi_turn_conversations` kept for backward compatibility with litgpt 105 | include_multi_turn_conversations: bool = False, 106 | ) -> FormattedSFTSingleturnDataset: 107 | formatted: FormattedSFTSingleturnDataset = [] 108 | 109 | for entry in dataset: 110 | conversation = entry['conversations'] 111 | 112 | # NOTE: We are not training with system messages. This is a deliberate decision. 113 | # We find that system messages generally weaken the model, and usually can be part of the instruction. 114 | # Additionally, the current litgpt instruction/input/output doesn't cater for system messages 115 | 116 | # system_message = conversation[0] 117 | # if system_message['from'] != 'system': 118 | # print( 119 | # f'WARN: A Slim orca row is corrupted. Expected role to be `user`, but is `{conversation[i]["from"]}` instead.' 120 | # ) 121 | 122 | # Start from index 1, which should be the human message 123 | human_message = conversation[1] 124 | if human_message['from'] != 'human': 125 | print( 126 | f'WARN: A SlimOrca row is corrupted. Expected role to be `user`, but is `{human_message["from"]}` instead.' 127 | ) 128 | 129 | ai_message = conversation[2] 130 | if ai_message['from'] != 'gpt': 131 | print( 132 | f'WARN: A SlimOrca row is corrupted. Expected role to be `assistant`, but is `{ai_message["from"]}` instead.'
133 | ) 134 | 135 | formatted_sft_dict: FormattedSFTSingleturnConversation = { 136 | 'instruction': human_message['value'], 137 | 'input': '', 138 | 'output': ai_message['value'], 139 | } 140 | 141 | formatted.append(formatted_sft_dict) 142 | 143 | return formatted 144 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/slim_orca_dedup_test.py: -------------------------------------------------------------------------------- 1 | from json import loads 2 | from lit_gpt.datamodules.slim_orca_dedup import format_dataset 3 | 4 | MOCK_DATA = [] 5 | with open( 6 | './lit_gpt/datamodules/slim_orca_dedup_test.jsonl', 'r', encoding='utf-8' 7 | ) as jsonl_file: 8 | for line in jsonl_file: 9 | MOCK_DATA.append(loads(line)) 10 | 11 | 12 | def test_format_dataset(): 13 | train_data = format_dataset( 14 | dataset=MOCK_DATA, include_multi_turn_conversations=None 15 | ) 16 | 17 | train_data_row_one = train_data[0] 18 | mock_data_row_one = MOCK_DATA[0]['conversations'] 19 | 20 | # The instruction and output pair in train_data is made 21 | # from `content` from the nth and nth+1 row in `messages`` 22 | assert mock_data_row_one[1]['value'] == train_data_row_one['instruction'] 23 | assert mock_data_row_one[2]['value'] == train_data_row_one['output'] 24 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/slim_orca_idontknow.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, Union, TypedDict, Literal 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_dataset_base import SFTDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTSingleturnConversation, 12 | FormattedSFTSingleturnDataset, 13 | ) 14 | 15 | DATA_FILES_PATH = '.data/sft/slim-orca-idontknow/slim-orca-idontknow.parquet' 16 | DOWNLOAD_DIR = './data/sft/slim-orca-idontknow' 17 | SPLIT = ['train'] 18 | 19 | class PreparedSlimOrcaIDK(TypedDict): 20 | train_dataset: SFTDataset 21 | val_dataset: None 22 | test_dataset: None 23 | 24 | 25 | # Use declarative syntax as 'from' will not be allowed using class syntax 26 | OpenOrcaConversationTurn = TypedDict( 27 | 'OpenOrcaConversation', {'from': Literal['system', 'human', 'gpt'], 'value': str} 28 | ) 29 | 30 | 31 | class SlimOrcaIDKRow(TypedDict): 32 | conversations: List[OpenOrcaConversationTurn] 33 | 34 | 35 | @dataclass 36 | class SlimOrcaIDK(DataModule): 37 | """SlimOrcaIDK data module for supervised finetuning.""" 38 | 39 | """Whether to include multi-turn conversations in the dataset.""" 40 | include_multiturn_conversations: bool = False 41 | 42 | """The directory in which the downloaded dataset gets saved.""" 43 | download_dir: Path = Path(DOWNLOAD_DIR) 44 | 45 | """The repo from where the data is downloaded""" 46 | repo_id: str = DATA_FILES_PATH 47 | 48 | data_files_path: str = DATA_FILES_PATH 49 | 50 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 51 | batch_size: int = field(default=1, init=False, repr=False) 52 | max_seq_length: int = field(default=-1, init=False, repr=False) 53 | train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) 54 | test_dataset: Optional[SFTDataset] = field(default=None, 
init=False, repr=False) 55 | 56 | def __init__( 57 | self, 58 | prompt_style='chatml', 59 | data_files_path='data/slim_orca_idontknow/slim-orca-idontknow.parquet', 60 | ): 61 | super().__init__() 62 | self.prompt_style = PromptStyle.from_name(prompt_style) 63 | self.data_files_path = self.repo_id = data_files_path 64 | 65 | def connect( 66 | self, 67 | tokenizer: Optional[Tokenizer] = None, 68 | batch_size: int = 1, 69 | max_seq_length: Optional[int] = None, 70 | ) -> None: 71 | self.tokenizer = tokenizer 72 | self.batch_size = batch_size 73 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 74 | 75 | def prepare_data(self) -> List[Dataset]: 76 | return load_dataset( 77 | 'parquet', 78 | split=SPLIT, 79 | data_files=self.repo_id, 80 | streaming=False, 81 | ) 82 | 83 | def setup(self, stage: str = '') -> PreparedSlimOrcaIDK: 84 | dataset = self.prepare_data() 85 | train_data = format_dataset(dataset[0], self.include_multiturn_conversations) 86 | # TODO: SlimOrcaIDK doesn't have test_data. We can split it out from the train_data 87 | # test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 88 | 89 | self.train_dataset = SFTDataset( 90 | data=train_data, 91 | tokenizer=self.tokenizer, 92 | prompt_style=self.prompt_style, 93 | max_seq_length=self.max_seq_length, 94 | mask_prompt=self.mask_prompt, 95 | ignore_index=self.ignore_index, 96 | ) 97 | # self.test_dataset = SFTDataset( 98 | # data=test_data, 99 | # tokenizer=self.tokenizer, 100 | # prompt_style=self.prompt_style, 101 | # max_seq_length=self.max_seq_length, 102 | # mask_prompt=self.mask_prompt, 103 | # ignore_index=self.ignore_index, 104 | # ) 105 | 106 | return { 107 | 'train_dataset': self.train_dataset, 108 | 'val_dataset': self.test_dataset, 109 | 'test_dataset': self.test_dataset, 110 | } 111 | 112 | 113 | def format_dataset( 114 | dataset: List[SlimOrcaIDKRow], 115 | # `include_multi_turn_conversations` kept for backward compatibility with litgpt 116 | include_multi_turn_conversations: False, 117 | ) -> FormattedSFTSingleturnDataset: 118 | formatted: FormattedSFTSingleturnDataset = [] 119 | 120 | for entry in dataset: 121 | conversation = entry['conversations'] 122 | 123 | # NOTE: We are not training with system message. This is a deliberate decision. 124 | # We find that system messages generally weaken the model, and usually can be part of the instruction. 125 | # Additionally, the current litgpt instruction/input/output doesn't cater for system messages 126 | 127 | # system_message = conversation[0] 128 | # if system_message['from'] != 'system': 129 | # print( 130 | # f'WARN: A Slim orca row is corrupted. Expected role to be `user`, but is `{conversation[i]["from"]}` instead.' 131 | # ) 132 | 133 | # Start from index 1, which should be human message 134 | human_message = conversation[1] 135 | if human_message['from'] != 'human': 136 | print( 137 | f'WARN: A Slim orca row is corrupted. Expected role to be `user`, but is `{human_message["from"]}` instead.' 138 | ) 139 | 140 | ai_message = conversation[2] 141 | if ai_message['from'] != 'gpt': 142 | print( 143 | f'WARN: A Slim orca row is corrupted. Expected role to be `assistant`, but is `{ai_message["from"]}` instead.' 
144 | ) 145 | 146 | formatted_sft_dict: FormattedSFTSingleturnConversation = { 147 | 'instruction': human_message['value'], 148 | 'input': '', 149 | 'output': ai_message['value'], 150 | } 151 | 152 | formatted.append(formatted_sft_dict) 153 | 154 | return formatted 155 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/slim_orca_idontknow_e2e_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from torch.utils.data import DataLoader 3 | 4 | from lit_gpt.datamodules.slim_orca_idontknow import SlimOrcaIDK 5 | from lit_gpt.utils import CycleIterator 6 | from lit_gpt.tokenizer import Tokenizer 7 | from lit_gpt.datamodules.base import get_sft_collate_fn 8 | 9 | tokenizer = Tokenizer(Path('tokenizer/pints_chat')) 10 | 11 | 12 | def test_slim_orca_idk_e2e(): 13 | batch_size = 2 14 | 15 | slim_orca_idk = SlimOrcaIDK( 16 | prompt_style='chatml', 17 | data_files_path='lit_gpt/datamodules/slim-orca-idontknow_test.parquet', 18 | ) 19 | slim_orca_idk.tokenizer = tokenizer 20 | slim_orca_idk.include_multiturn_conversations = True 21 | 22 | data = slim_orca_idk.setup() 23 | 24 | train_dataloader = DataLoader( 25 | data['train_dataset'], 26 | batch_size=batch_size, 27 | collate_fn=get_sft_collate_fn(2048), 28 | ) 29 | 30 | slim_orca_idk_iterator = CycleIterator(train_dataloader) 31 | 32 | iter_batch = next(slim_orca_idk_iterator) 33 | 34 | # Check that we plucked out according to batch_size 35 | assert len(iter_batch['input_ids']) == batch_size 36 | assert len(iter_batch['labels']) == batch_size 37 | 38 | decoded_input = tokenizer.decode( 39 | iter_batch['input_ids'][1], skip_special_tokens=False 40 | ) 41 | 42 | expected_input = """<|im_start|> user 43 | Chanakya, 4th Century BC Indian political philosopher. The Arthashastra provides an account of the science of politics for a wise ruler, policies for foreign affairs and wars, the system of a spy state and surveillance and economic stability of the state. Chanakya quotes several authorities including Bruhaspati, Ushanas, Prachetasa Manu, Parasara, and Ambi, and described himself as a descendant of a lineage of political philosophers, with his father Chanaka being his immediate predecessor. Another influential extant Indian treatise on political philosophy is the Sukra Neeti. An example of a code of law in ancient India is the Manusmṛti or Laws of Manu. 44 | Where is Sukra Neeti an example of a code of law? (If the question is unanswerable, say "unanswerable")<|im_end|> 45 | <|im_start|> assistant 46 | Unanswerable. The Sukra Neeti is not an example of a code of law but rather an influential Indian treatise on political philosophy. The example of a code of law in ancient India mentioned is the Manusmṛti or Laws of Manu.<|im_end|>""" 47 | 48 | labels = iter_batch['labels'][1] 49 | labels[labels == -100] = 0 50 | decoded_labels = tokenizer.decode( 51 | iter_batch['labels'][1], skip_special_tokens=False 52 | ) 53 | 54 | expected_labels = """Unanswerable. The Sukra Neeti is not an example of a code of law but rather an influential Indian treatise on political philosophy. 
The example of a code of law in ancient India mentioned is the Manusmṛti or Laws of Manu.<|im_end|>""" 55 | 56 | assert expected_input == decoded_input 57 | assert expected_labels == decoded_labels 58 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/slim_orca_idontknow_test.py: -------------------------------------------------------------------------------- 1 | from lit_gpt.datamodules.slim_orca_idontknow import format_dataset, SlimOrcaIDK 2 | 3 | slim_orca_idk = SlimOrcaIDK( 4 | data_files_path='lit_gpt/datamodules/slim-orca-idontknow_test.parquet' 5 | ) 6 | slim_orca_idk_dataset = slim_orca_idk.prepare_data() 7 | MOCK_DATA = slim_orca_idk_dataset[0] 8 | 9 | 10 | def test_slim_orca_idontknow_format_dataset(): 11 | train_data = format_dataset( 12 | dataset=MOCK_DATA, include_multi_turn_conversations=None 13 | ) 14 | 15 | train_data_row_one = train_data[0] 16 | mock_data_row_one = MOCK_DATA[0]['conversations'] 17 | 18 | # The instruction and output pair in train_data is taken 19 | # from the `value` of the human and gpt turns in `conversations` 20 | assert mock_data_row_one[1]['value'] == train_data_row_one['instruction'] 21 | assert mock_data_row_one[2]['value'] == train_data_row_one['output'] 22 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/typings/base.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict 2 | from torch import Tensor 3 | 4 | 5 | class SFTDatasetItem(TypedDict): 6 | input_ids: Tensor # 1-dimensionality in y direction 7 | labels: Tensor 8 | 9 | 10 | class SFTCollatedBatch(TypedDict): 11 | input_ids: Tensor # n-dimensionality in y direction, where n is batch size 12 | labels: Tensor 13 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/typings/formatted_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, TypedDict 2 | 3 | 4 | class FormattedSFTDict(TypedDict): 5 | instruction: str 6 | input: str 7 | output: str 8 | 9 | 10 | # In a single turn dataset, each instruction, input, and output tuple 11 | # is a complete conversation discussing the same subject. 12 | FormattedSFTSingleturnConversation = FormattedSFTDict 13 | FormattedSFTSingleturnDataset = List[FormattedSFTDict] 14 | 15 | # A `Conversation` comprises multiple instruction, input, and output tuples. 16 | # Everything inside one `Conversation` discusses the same subject 17 | FormattedSFTMultiturnConversation = List[FormattedSFTDict] 18 | # Therefore, a multiturn dataset will comprise multiple `Conversations` 19 | # each discussing a different subject.
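# For example (illustrative values only), a multiturn dataset with two conversations looks like:
# [
#     [  # conversation 1: two turns on the same subject
#         {'instruction': 'Hi', 'input': '', 'output': 'Hello!'},
#         {'instruction': 'How are you?', 'input': '', 'output': 'Doing well.'},
#     ],
#     [  # conversation 2: a single turn on a different subject
#         {'instruction': 'Name a colour.', 'input': '', 'output': 'Blue.'},
#     ],
# ]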
20 | FormattedSFTMultiturnDataset = List[List[FormattedSFTDict]] 21 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/typings/prepared_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Union, TypedDict 2 | from lit_gpt.datamodules.sft_dataset_base import SFTDataset 3 | from lit_gpt.datamodules.sft_multiturn_dataset_base import SFTMultiTurnDataset 4 | 5 | class PreparedDataset(TypedDict): 6 | train_dataset: Union[SFTDataset, SFTMultiTurnDataset, None] 7 | val_dataset: Union[SFTDataset, SFTMultiTurnDataset, None] 8 | test_dataset: Union[SFTDataset, SFTMultiTurnDataset, None] 9 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/ultrachat_200k.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, Union, TypedDict, Literal 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_multiturn_dataset_base import SFTMultiTurnDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTDict, 12 | FormattedSFTMultiturnConversation, 13 | FormattedSFTMultiturnDataset, 14 | ) 15 | 16 | HUGGINGFACE_ID = 'HuggingFaceH4/ultrachat_200k' 17 | DOWNLOAD_DIR = './data/sft/ultrachat_200k' 18 | SPLIT = ['train_sft'] 19 | 20 | class PreparedUltraChat(TypedDict): 21 | train_dataset: SFTMultiTurnDataset 22 | val_dataset: None 23 | test_dataset: None 24 | 25 | 26 | class UltraChatMessage(TypedDict): 27 | content: str 28 | role: Literal['user', 'assistant'] 29 | 30 | 31 | class UltraChatRow(TypedDict): 32 | prompt: str 33 | prompt_id: str 34 | messages: List[UltraChatMessage] 35 | 36 | 37 | @dataclass 38 | class UltraChat(DataModule): 39 | """UltraChat data module for supervised finetuning.""" 40 | 41 | """Whether to include multi-turn conversations in the dataset.""" 42 | include_multiturn_conversations: bool = True 43 | 44 | """The directory in which the downloaded dataset gets saved.""" 45 | download_dir: Path = Path(DOWNLOAD_DIR) 46 | 47 | """The repo from where the data is downloaded""" 48 | repo_id: str = HUGGINGFACE_ID 49 | 50 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 51 | batch_size: int = field(default=1, init=False, repr=False) 52 | max_seq_length: int = field(default=-1, init=False, repr=False) 53 | train_dataset: Optional[SFTMultiTurnDataset] = field( 54 | default=None, init=False, repr=False 55 | ) 56 | test_dataset: Optional[SFTMultiTurnDataset] = field( 57 | default=None, init=False, repr=False 58 | ) 59 | 60 | def __init__(self, prompt_style='chatml', num_workers: Optional[int] = None): 61 | super().__init__() 62 | self.prompt_style = PromptStyle.from_name(prompt_style) 63 | if num_workers: 64 | self.num_workers = num_workers 65 | 66 | def connect( 67 | self, 68 | tokenizer: Optional[Tokenizer] = None, 69 | batch_size: int = 1, 70 | max_seq_length: Optional[int] = None, 71 | ) -> None: 72 | self.tokenizer = tokenizer 73 | self.batch_size = batch_size 74 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 75 | 76 | def prepare_data(self) -> List[Dataset]: 77 | return load_dataset( 78 | self.repo_id, 79 | split=SPLIT, 80 | 
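# `cache_dir` keeps the downloaded shards under `download_dir` so repeated runs reuse the local copy.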
cache_dir=self.download_dir, 81 | streaming=False, 82 | ) 83 | 84 | def setup(self, stage: str = '') -> PreparedUltraChat: 85 | dataset = self.prepare_data() 86 | train_data = format_dataset(dataset[0], self.include_multiturn_conversations) 87 | # test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 88 | 89 | self.train_dataset = SFTMultiTurnDataset( 90 | data=train_data, 91 | tokenizer=self.tokenizer, 92 | prompt_style=self.prompt_style, 93 | max_seq_length=self.max_seq_length, 94 | mask_prompt=self.mask_prompt, 95 | ignore_index=self.ignore_index, 96 | ) 97 | # self.test_dataset = SFTMultiTurnDataset( 98 | # data=test_data, 99 | # tokenizer=self.tokenizer, 100 | # prompt_style=self.prompt_style, 101 | # max_seq_length=self.max_seq_length, 102 | # mask_prompt=self.mask_prompt, 103 | # ignore_index=self.ignore_index, 104 | # ) 105 | 106 | return { 107 | 'train_dataset': self.train_dataset, 108 | 'val_dataset': self.test_dataset, 109 | 'test_dataset': self.test_dataset, 110 | } 111 | 112 | 113 | def format_dataset( 114 | dataset: List[UltraChatRow], include_multi_turn_conversations: bool 115 | ) -> FormattedSFTMultiturnDataset: 116 | formatted: FormattedSFTMultiturnDataset = [] 117 | 118 | for entry in dataset: 119 | formatted_convo: FormattedSFTMultiturnConversation = [] 120 | convo = entry['messages'] 121 | 122 | # Each conversation is a flat list of user-assistant pairs. 123 | # So we iterate in 2-step manner 124 | for i in range(0, len(convo) - 1, 2): 125 | if convo[i]['role'] != 'user': 126 | print( 127 | f'WARN: UltraChat row with prompt_id[{entry["prompt_id"]}] is corrupted. Expected role to be `user`, but is `{convo[i]["role"]}` instead.' 128 | ) 129 | if convo[i + 1]['role'] != 'assistant': 130 | print( 131 | f'WARN: UltraChat row with prompt_id[{entry["prompt_id"]}] is corrupted. Expected role to be `assistant`, but is `{convo[i+1]["role"]}` instead.' 
132 | ) 133 | 134 | formatted_sft_dict: FormattedSFTDict = { 135 | 'instruction': convo[i]['content'], 136 | 'input': '', 137 | 'output': convo[i + 1]['content'], 138 | } 139 | 140 | formatted_convo.append(formatted_sft_dict) 141 | 142 | # If don't want to include multi turn, break after first 143 | # turn is appended: - no point including latter turns as 144 | # they become orphaned discussions without starting context 145 | if not include_multi_turn_conversations: 146 | break 147 | 148 | formatted.append(formatted_convo) 149 | 150 | return formatted 151 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/ultrachat_200k_test.py: -------------------------------------------------------------------------------- 1 | from json import loads 2 | from lit_gpt.datamodules.ultrachat_200k import format_dataset 3 | 4 | MOCK_DATA = [] 5 | 6 | with open( 7 | './lit_gpt/datamodules/ultrachat_200k_test.jsonl', 'r', encoding='utf-8' 8 | ) as jsonl_file: 9 | for line in jsonl_file: 10 | MOCK_DATA.append(loads(line)) 11 | 12 | 13 | def test_format_dataset_multiturn(): 14 | train_data = format_dataset( 15 | dataset=MOCK_DATA, include_multi_turn_conversations=True 16 | ) 17 | 18 | train_data_row_one = train_data[0] 19 | mock_data_row_one = MOCK_DATA[0]['messages'] 20 | 21 | # Multiturn row check 22 | # It comes in a flat list of user-assistant pair, that is turned into 1 train data row 23 | # Hence divde by 2 24 | assert len(mock_data_row_one) / 2 == len(train_data_row_one) 25 | 26 | # The instruction and output pair in train_data is made 27 | # from `content` from the nth and nth+1 row in `messages`` 28 | assert train_data_row_one[0]['instruction'] == mock_data_row_one[0]['content'] 29 | assert train_data_row_one[0]['output'] == mock_data_row_one[1]['content'] 30 | 31 | # Because 2 rows of `content` from the data is condensed into 1 row of train_data: 32 | assert len(train_data_row_one) == len(mock_data_row_one) / 2 33 | 34 | 35 | def test_format_dataset(): 36 | train_data = format_dataset( 37 | dataset=MOCK_DATA, include_multi_turn_conversations=False 38 | ) 39 | 40 | train_data_row_two = train_data[1] 41 | mock_data_row_two = MOCK_DATA[1]['messages'] 42 | 43 | assert train_data_row_two[0]['instruction'] == mock_data_row_two[0]['content'] 44 | assert train_data_row_two[0]['output'] == mock_data_row_two[1]['content'] 45 | 46 | # Because we don't include multiturn, 47 | # the behaviour is that only 1 instruction/output pair is made, 48 | assert len(train_data_row_two) == 1 49 | # despite the dataset having more than 1 pair 50 | assert len(mock_data_row_two) > 2 51 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/wizardlm_evol_instruct_v2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import List, Optional, Union, TypedDict, Literal 4 | from datasets import load_dataset, Dataset 5 | 6 | from lit_gpt.prompts import PromptStyle 7 | from lit_gpt.datamodules.base import DataModule 8 | from lit_gpt.datamodules.sft_dataset_base import SFTDataset 9 | from lit_gpt.tokenizer import Tokenizer 10 | from lit_gpt.datamodules.typings.formatted_dataset import ( 11 | FormattedSFTSingleturnConversation, 12 | FormattedSFTSingleturnDataset, 13 | ) 14 | 15 | HUGGINGFACE_ID = 'WizardLM/WizardLM_evol_instruct_V2_196k' 16 | DOWNLOAD_DIR = './data/sft/wizardlm-evol-instruct-v2' 17 | SPLIT = ['train'] 18 
| 19 | class PreparedWizardLMEvolInstructV2(TypedDict): 20 | train_dataset: SFTDataset 21 | val_dataset: None 22 | test_dataset: None 23 | 24 | 25 | # Use declarative syntax as 'from' will not be allowed using class syntax 26 | WizardLMEvolInstructV2Conversations = List[ 27 | TypedDict( 28 | 'WizardLMEvolInstructV2Conversations', 29 | {'from': Literal['human', 'gpt'], 'value': str}, 30 | ) 31 | ] 32 | 33 | 34 | class WizardLMEvolInstructV2Row(TypedDict): 35 | conversations: WizardLMEvolInstructV2Conversations 36 | idx: str 37 | 38 | 39 | @dataclass 40 | class WizardLMEvolInstructV2(DataModule): 41 | """WizardLMEvolInstructV2 data module for supervised finetuning.""" 42 | 43 | """Whether to include multi-turn conversations in the dataset.""" 44 | include_multiturn_conversations: bool = False 45 | """The directory in which the downloaded dataset gets saved.""" 46 | download_dir: Path = Path(DOWNLOAD_DIR) 47 | """The repo from where the data is downloaded""" 48 | repo_id: str = HUGGINGFACE_ID 49 | 50 | tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) 51 | # `batch_size` has no use anymore, as the dataloaders are moved outside 52 | # in order to support training with multiple SFT datasets at one go. 53 | batch_size: int = field(default=1, init=False, repr=False) 54 | max_seq_length: int = field(default=-1, init=False, repr=False) 55 | train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) 56 | test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) 57 | 58 | def __init__(self, prompt_style='chatml', num_workers: Optional[int] = None): 59 | super().__init__() 60 | self.prompt_style = PromptStyle.from_name(prompt_style) 61 | if num_workers: 62 | self.num_workers = num_workers 63 | 64 | def connect( 65 | self, 66 | tokenizer: Optional[Tokenizer] = None, 67 | batch_size: int = 1, 68 | max_seq_length: Optional[int] = None, 69 | ) -> None: 70 | self.tokenizer = tokenizer 71 | self.batch_size = batch_size 72 | self.max_seq_length = -1 if max_seq_length is None else max_seq_length 73 | 74 | def prepare_data(self) -> List[Dataset]: 75 | return load_dataset( 76 | self.repo_id, 77 | split=SPLIT, 78 | cache_dir=self.download_dir, 79 | streaming=False, 80 | ) 81 | 82 | def setup(self, stage: str = '') -> PreparedWizardLMEvolInstructV2: 83 | dataset = self.prepare_data() 84 | 85 | train_data = format_dataset(dataset[0], self.include_multiturn_conversations) 86 | # TODO: WizardLMEvolInstructV2 doesn't have test_data. 
We can split it out from the train_data 87 | # test_data = format_dataset(dataset[1], self.include_multiturn_conversations) 88 | 89 | self.train_dataset = SFTDataset( 90 | data=train_data, 91 | tokenizer=self.tokenizer, 92 | prompt_style=self.prompt_style, 93 | max_seq_length=self.max_seq_length, 94 | mask_prompt=self.mask_prompt, 95 | ignore_index=self.ignore_index, 96 | ) 97 | # self.test_dataset = SFTDataset( 98 | # data=test_data, 99 | # tokenizer=self.tokenizer, 100 | # prompt_style=self.prompt_style, 101 | # max_seq_length=self.max_seq_length, 102 | # mask_prompt=self.mask_prompt, 103 | # ignore_index=self.ignore_index, 104 | # ) 105 | 106 | return { 107 | 'train_dataset': self.train_dataset, 108 | 'val_dataset': self.test_dataset, 109 | 'test_dataset': self.test_dataset, 110 | } 111 | 112 | 113 | def format_dataset( 114 | dataset: List[WizardLMEvolInstructV2Row], 115 | # `include_multi_turn_conversations` kept for backward compatibility with litgpt 116 | include_multi_turn_conversations=False, 117 | ) -> FormattedSFTSingleturnDataset: 118 | formatted: FormattedSFTSingleturnDataset = [] 119 | 120 | for entry in dataset: 121 | conversation = entry['conversations'] 122 | 123 | human_message = conversation[0] 124 | if human_message['from'] != 'human': 125 | print( 126 | f'WARN: WizardLM row[{entry["idx"]}] is corrupted. Expected role to be `user`, but is `{human_message["from"]}` instead.' 127 | ) 128 | 129 | ai_message = conversation[1] 130 | if ai_message['from'] != 'gpt': 131 | print( 132 | f'WARN: WizardLM row[{entry["idx"]}] is corrupted. Expected role to be `assistant`, but is `{ai_message["from"]}` instead.' 133 | ) 134 | 135 | formatted_sft_dict: FormattedSFTSingleturnConversation = { 136 | 'instruction': human_message['value'], 137 | 'input': '', 138 | 'output': ai_message['value'], 139 | } 140 | formatted.append(formatted_sft_dict) 141 | 142 | return formatted 143 | -------------------------------------------------------------------------------- /lit_gpt/datamodules/wizardlm_evol_instruct_v2_test.py: -------------------------------------------------------------------------------- 1 | from json import loads 2 | from lit_gpt.datamodules.wizardlm_evol_instruct_v2 import format_dataset 3 | 4 | MOCK_DATA = [] 5 | 6 | with open( 7 | './lit_gpt/datamodules/wizardlm_evol_instruct_v2_test.jsonl', 'r', encoding='utf-8' 8 | ) as jsonl_file: 9 | for line in jsonl_file: 10 | MOCK_DATA.append(loads(line)) 11 | 12 | 13 | def test_format_dataset(): 14 | train_data = format_dataset(dataset=MOCK_DATA) 15 | 16 | train_data_row_one = train_data[0] 17 | mock_data_row_one = MOCK_DATA[0] 18 | 19 | # The instruction and output pair in train_data is made 20 | # from `content` from the nth and nth+1 row in `messages`` 21 | assert ( 22 | mock_data_row_one['conversations'][0]['value'] 23 | == train_data_row_one['instruction'] 24 | ) 25 | assert ( 26 | mock_data_row_one['conversations'][1]['value'] == train_data_row_one['output'] 27 | ) 28 | -------------------------------------------------------------------------------- /lit_gpt/fused_rotary_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 
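# Fused rotary positional embedding (RoPE). `ApplyRotaryEmb` rotates the first `rotary_dim` dimensions of each
# attention head, supporting both the GPT-NeoX (half/half) and GPT-J (interleaved) layouts and an optional
# in-place mode; the elementwise rotation itself is delegated to the `rotary_emb` CUDA extension from the
# flash-attention project.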
2 | 3 | import math 4 | from typing import Optional, Tuple 5 | 6 | import rotary_emb 7 | import torch 8 | from einops import rearrange, repeat 9 | 10 | class ApplyRotaryEmb(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, cos, sin, interleaved=False, inplace=False): 13 | """ 14 | x: (batch_size, seqlen, nheads, headdim) 15 | cos, sin: (seqlen, rotary_dim / 2) 16 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead 17 | of 1st half and 2nd half (GPT-NeoX style). 18 | rotary_dim must be <= headdim 19 | Apply rotary embedding to the first rotary_dim of x. 20 | """ 21 | batch, seqlen, nheads, headdim = x.shape 22 | rotary_seqlen, rotary_dim = cos.shape 23 | rotary_dim *= 2 24 | assert rotary_dim <= headdim 25 | assert seqlen <= rotary_seqlen 26 | assert sin.shape == (rotary_seqlen, rotary_dim // 2) 27 | x_ro = x[..., :rotary_dim] 28 | x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) 29 | out = torch.empty_like(x) if not inplace else x 30 | out_ro = out[..., :rotary_dim] 31 | if inplace: 32 | o1, o2 = x1, x2 33 | else: 34 | o1, o2 = ( 35 | out_ro.chunk(2, dim=-1) 36 | if not interleaved 37 | else (out_ro[..., ::2], out_ro[..., 1::2]) 38 | ) 39 | rotary_emb.apply_rotary( 40 | x1, 41 | x2, 42 | rearrange(cos[:seqlen], "s d -> s 1 d"), 43 | rearrange(sin[:seqlen], "s d -> s 1 d"), 44 | o1, 45 | o2, 46 | False, 47 | ) 48 | if not inplace and rotary_dim < headdim: 49 | out[..., rotary_dim:].copy_(x[..., rotary_dim:]) 50 | ctx.save_for_backward(cos, sin) 51 | ctx.interleaved = interleaved 52 | ctx.inplace = inplace 53 | return out if not inplace else x 54 | 55 | @staticmethod 56 | def backward(ctx, do): 57 | cos, sin = ctx.saved_tensors 58 | _, seqlen, _, headdim = do.shape 59 | rotary_dim = cos.shape[-1] 60 | rotary_dim *= 2 61 | inplace = ctx.inplace 62 | do_ro = do[..., :rotary_dim] 63 | do1, do2 = ( 64 | do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) 65 | ) 66 | dx = torch.empty_like(do) if not inplace else do 67 | if inplace: 68 | dx1, dx2 = do1, do2 69 | else: 70 | dx_ro = dx[..., :rotary_dim] 71 | dx1, dx2 = ( 72 | dx_ro.chunk(2, dim=-1) 73 | if not ctx.interleaved 74 | else (dx_ro[..., ::2], dx_ro[..., 1::2]) 75 | ) 76 | rotary_emb.apply_rotary( 77 | do1, 78 | do2, 79 | rearrange(cos[:seqlen], "s d -> s 1 d"), 80 | rearrange(sin[:seqlen], "s d -> s 1 d"), 81 | dx1, 82 | dx2, 83 | True, 84 | ) 85 | if not inplace and rotary_dim < headdim: 86 | dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) 87 | return dx, None, None, None, None 88 | 89 | 90 | apply_rotary_emb_func = ApplyRotaryEmb.apply 91 | 92 | -------------------------------------------------------------------------------- /lit_gpt/generate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/lit_gpt/generate/__init__.py -------------------------------------------------------------------------------- /lit_gpt/generate/base.py: -------------------------------------------------------------------------------- 1 | # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
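# Decoding utilities for inference: `sample` filters the last-position logits with top-k, temperature scaling and
# top-p (nucleus) sampling, `next_token` runs a single forward pass against the KV cache, and `generate` loops
# one token at a time until `eos_id` is produced or `max_returned_tokens` is reached.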
2 | 3 | from lightning import Fabric 4 | import torch 5 | import torch._dynamo.config 6 | import torch._inductor.config 7 | 8 | from typing import Any, Optional 9 | 10 | from lit_gpt.model import GPT 11 | 12 | # Referenced from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/generate/base.py 13 | def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor: 14 | if torch._dynamo.is_compiling(): 15 | # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly 16 | distribution = torch.empty_like(probs).exponential_(1) 17 | return torch.argmax(probs / distribution, dim=-1, keepdim=True) 18 | return torch.multinomial(probs, num_samples=1) 19 | 20 | 21 | def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor: 22 | sorted_logits, sorted_indices = torch.sort(logits, descending=False) 23 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 24 | # Example: 25 | # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0] 26 | # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7 27 | sorted_indices_to_remove = cumulative_probs <= (1 - top_p) 28 | # Keep at least 1 token always to prevent the case where no token is selected 29 | # In this case the most probable one is always kept 30 | sorted_indices_to_remove[-1:] = 0 31 | indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) 32 | logits = logits.masked_fill(indices_to_remove, float("-inf")) 33 | return logits 34 | 35 | 36 | def sample( 37 | logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0 38 | ) -> torch.Tensor: 39 | if top_p < 0.0 or top_p > 1.0: 40 | raise ValueError(f"top_p must be in [0, 1], got {top_p}") 41 | logits = logits[0, -1] 42 | # optionally crop the logits to only the top k options 43 | if top_k is not None: 44 | v, i = torch.topk(logits, min(top_k, logits.size(-1))) 45 | # do not use `torch.where` as in nanogpt because it will repeat top-k collisions 46 | logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v) 47 | # optionally scale the logits and sample from a probability distribution 48 | if temperature > 0.0 or top_p > 0.0: 49 | if temperature > 0.0: 50 | logits = logits / temperature 51 | # optionally crop the logits to smallest set of logits with a cumulative probability above top_p 52 | if top_p < 1.0: 53 | logits = sample_top_p(logits, top_p) 54 | probs = torch.nn.functional.softmax(logits, dim=-1) 55 | return multinomial_num_samples_1(probs) 56 | return torch.argmax(logits, dim=-1, keepdim=True) 57 | 58 | 59 | def next_token( 60 | model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any 61 | ) -> torch.Tensor: 62 | logits = model(x, input_pos) 63 | next = sample(logits, **kwargs) 64 | return next.to(dtype=x.dtype) 65 | 66 | 67 | @torch.inference_mode() 68 | def generate( 69 | fabric: Fabric, 70 | model: GPT, 71 | prompt: torch.Tensor, 72 | max_returned_tokens: int, 73 | *, 74 | temperature: float = 1.0, 75 | top_k: Optional[int] = None, 76 | top_p: float = 1.0, 77 | eos_id: Optional[int] = None, 78 | include_prompt: bool = True, 79 | ) -> torch.Tensor: 80 | """ 81 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 82 | The implementation of this function is modified from A. Karpathy's nanoGPT. 83 | 84 | Args: 85 | model: The model to use. 86 | prompt: Tensor of shape (T) with indices of the prompt sequence. 
87 | max_returned_tokens: The maximum number of tokens to return (given plus generated). 88 | temperature: Scales the predicted logits by 1 / temperature. 89 | top_k: If specified, only sample among the tokens with the k highest probabilities. 90 | top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 91 | In top-p sampling, the next token is sampled from the highest probability tokens 92 | whose cumulative probability exceeds the threshold `top_p`. When specified, 93 | it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent 94 | to sampling the most probable token, while `top_p=1` samples from the whole distribution. 95 | It can be used in conjunction with `top_k` and `temperature` with the following order 96 | of application: 97 | 98 | 1. `top_k` sampling 99 | 2. `temperature` scaling 100 | 3. `top_p` sampling 101 | 102 | For more details, see https://arxiv.org/abs/1904.09751 103 | or https://huyenchip.com/2024/01/16/sampling.html#top_p 104 | eos_id: If specified, stop generating any more token once the token is triggered. 105 | include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output. 106 | """ 107 | T = prompt.size(0) 108 | assert max_returned_tokens > T 109 | if model.max_seq_length < max_returned_tokens - 1: 110 | # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a 111 | # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do 112 | # not support it to avoid negatively impacting the overall speed 113 | raise NotImplementedError( 114 | f'max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}' 115 | ) 116 | 117 | model.reset_cache() 118 | # set the max_seq_length to limit the memory usage to what we need 119 | model.max_seq_length = max_returned_tokens 120 | 121 | device = prompt.device 122 | if include_prompt: 123 | tokens = [prompt] 124 | else: 125 | tokens = [] 126 | input_pos = torch.tensor([T], device=device) 127 | token = next_token( 128 | model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p 129 | ).clone() 130 | tokens.append(token) 131 | for _ in range(2, max_returned_tokens - T + 1): 132 | token = next_token( 133 | model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p 134 | ).clone() 135 | tokens.append(token) 136 | if token == eos_id: 137 | break 138 | input_pos = input_pos.add_(1) 139 | return torch.cat(tokens) 140 | -------------------------------------------------------------------------------- /lit_gpt/prompts_test.py: -------------------------------------------------------------------------------- 1 | from lit_gpt.prompts import ChatML 2 | 3 | 4 | def test_chatml(): 5 | chatml = ChatML() 6 | 7 | example = { 8 | 'instruction': 'Hello World!', 9 | } 10 | 11 | formatted = chatml.apply(example['instruction'], **example) 12 | 13 | # <|im_start|>user\n 14 | expected = f'{chatml.special_tokens["start_user"]}\n' 15 | # Hello World!<|im_end|>\n 16 | expected += f'Hello World!{chatml.special_tokens["end"]}\n' 17 | # <|im_start|>assistant 18 | expected += chatml.special_tokens['start_assistant'] + '\n' 19 | 20 | assert formatted == expected 21 | 22 | 23 | def test_chatml_with_system_message(): 24 | chatml = ChatML() 25 | 26 | example = {'instruction': 'Hello World!', 'system': 'You are Tom.'} 27 | 28 | formatted = chatml.apply(example['instruction'], 
**example) 29 | 30 | # <|im_start|>system\n 31 | expected = f'{chatml.special_tokens["start_system"]}\n' 32 | # You are Tom.<|im_end|>\n 33 | expected += f'You are Tom.{chatml.special_tokens["end"]}\n' 34 | # <|im_start|>user\n 35 | expected += f'{chatml.special_tokens["start_user"]}\n' 36 | # Hello World!<|im_end|>\n 37 | expected += f'Hello World!{chatml.special_tokens["end"]}\n' 38 | # <|im_start|>assistant 39 | expected += chatml.special_tokens['start_assistant'] + '\n' 40 | 41 | assert formatted == expected 42 | -------------------------------------------------------------------------------- /lit_gpt/tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | import torch 5 | from os.path import normpath, join 6 | from os import getcwd 7 | 8 | HUGGINGFACE = 'huggingface' 9 | SENTENCEPIECE = 'sentencepiece' 10 | 11 | 12 | class Tokenizer: 13 | def __init__(self, checkpoint_dir: Path) -> None: 14 | # some checkpoints have both files, `.model` takes precedence 15 | # TODO: Deprecate SentencePieceProcessor. It behaves differently from tokenizer.json 16 | # So it's best to avoid and reduce complexity. 17 | # For example, see: https://github.com/google/sentencepiece/issues/667 18 | # if (vocabulary_path := checkpoint_dir / 'tokenizer.model').is_file(): 19 | # print( 20 | # f'Tokenizer class is initialised with `{vocabulary_path}` as this takes precedence.' 21 | # ) 22 | # print( 23 | # 'If you have intended to use `tokenizer.json`, please remove `tokenizer.model`.' 24 | # ) 25 | # from sentencepiece import SentencePieceProcessor 26 | 27 | # self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) 28 | # self.backend = SENTENCEPIECE 29 | # self.bos_id = self.processor.bos_id() 30 | # self.eos_id = self.processor.eos_id() 31 | # self.pad_id = self.processor.pad_id() 32 | 33 | # self.processor.Decode 34 | 35 | if (vocabulary_path := checkpoint_dir / 'tokenizer.json').is_file(): 36 | print(f'Tokenizer class is initalised with `{vocabulary_path}`.') 37 | from tokenizers import Tokenizer as HFTokenizer 38 | 39 | self.processor: HFTokenizer = HFTokenizer.from_file(str(vocabulary_path)) 40 | self.backend = HUGGINGFACE 41 | 42 | with open(checkpoint_dir / 'tokenizer_config.json') as fp: 43 | config = json.load(fp) 44 | 45 | bos_token_config = config.get('bos_token', None) 46 | bos_token = ( 47 | bos_token_config 48 | # bos_token_config in tokenizer_config can be a str or object with 'content' 49 | if isinstance(bos_token_config, str) 50 | else bos_token_config['content'] 51 | ) 52 | self.bos_id = ( 53 | self.token_to_id(bos_token) if bos_token_config is not None else None 54 | ) 55 | 56 | eos_token_config = config.get('eos_token', None) 57 | eos_token = ( 58 | eos_token_config 59 | if isinstance(eos_token_config, str) 60 | else eos_token_config['content'] 61 | ) 62 | self.eos_id = ( 63 | self.token_to_id(eos_token) if eos_token_config is not None else None 64 | ) 65 | 66 | pad_token_config = config.get('pad_token', None) 67 | pad_token = ( 68 | pad_token_config 69 | if pad_token_config is None or isinstance(pad_token_config, str) 70 | else pad_token_config['content'] 71 | ) 72 | self.pad_id = ( 73 | self.token_to_id(pad_token) if pad_token_config is not None else None 74 | ) 75 | 76 | else: 77 | full_tokenizer_path = normpath(join(getcwd(), checkpoint_dir)) 78 | full_tokenizer_path = Path(full_tokenizer_path) 79 | raise NotImplementedError( 80 | f'Cannot find 
tokenizer at {full_tokenizer_path}.' 81 | ) 82 | 83 | @property 84 | def vocab_size(self) -> int: 85 | if self.backend == HUGGINGFACE: 86 | return self.processor.get_vocab_size(with_added_tokens=False) 87 | if self.backend == SENTENCEPIECE: 88 | return self.processor.vocab_size() 89 | raise RuntimeError 90 | 91 | def token_to_id(self, token: str) -> int: 92 | if self.backend == HUGGINGFACE: 93 | id_ = self.processor.token_to_id(token) 94 | elif self.backend == SENTENCEPIECE: 95 | id_ = self.processor.piece_to_id(token) 96 | else: 97 | raise RuntimeError 98 | if id_ is None: 99 | raise ValueError(f'token {token!r} not found in the collection.') 100 | return id_ 101 | 102 | def encode( 103 | self, 104 | string: str, 105 | device: Optional[torch.device] = None, 106 | bos: bool = False, 107 | eos: bool = False, 108 | max_length=-1, 109 | ) -> torch.Tensor: 110 | if self.backend == HUGGINGFACE: 111 | # add_special_tokens=False as we want to manually handle it later. 112 | tokens = self.processor.encode(string, add_special_tokens=False).ids 113 | elif self.backend == SENTENCEPIECE: 114 | raise RuntimeError 115 | tokens = self.processor.encode(string) 116 | else: 117 | raise RuntimeError 118 | 119 | if bos: 120 | bos_id = self.bos_id 121 | if bos_id is None: 122 | raise NotImplementedError( 123 | 'Attempted to add bos, but this tokenizer does not defined a bos token.' 124 | ) 125 | tokens = [bos_id] + tokens 126 | 127 | if max_length > 0: 128 | if eos: 129 | if self.eos_id is None: 130 | raise NotImplementedError( 131 | 'Attempted to add eos, but this tokenizer does not defined an eos token' 132 | ) 133 | if len(tokens) >= max_length: 134 | tokens = tokens[:max_length] 135 | tokens[-1] = self.eos_id 136 | else: 137 | tokens = tokens + [self.eos_id] 138 | else: 139 | if eos: 140 | if self.eos_id is None: 141 | raise NotImplementedError( 142 | 'Attempted to add eos, but this tokenizer does not defined an eos token' 143 | ) 144 | tokens = tokens + [self.eos_id] 145 | 146 | return torch.tensor(tokens, dtype=torch.int, device=device) 147 | 148 | def decode( 149 | self, 150 | tensor: Union[torch.Tensor, list], 151 | skip_special_tokens: Optional[bool] = None, 152 | ) -> str: 153 | if self.backend != HUGGINGFACE and skip_special_tokens is not None: 154 | print(f'WARN: Using {self.backend} does not allow `skip_special_tokens`.') 155 | 156 | tokens = tensor 157 | if isinstance(tensor, torch.Tensor): 158 | tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() 159 | 160 | if skip_special_tokens is not None and self.backend == HUGGINGFACE: 161 | return self.processor.decode(tokens, skip_special_tokens) 162 | 163 | return self.processor.decode(tokens) 164 | -------------------------------------------------------------------------------- /media/pints-1.5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/media/pints-1.5.png -------------------------------------------------------------------------------- /media/pints_ai-banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/media/pints_ai-banner.png -------------------------------------------------------------------------------- /prepare_dataset/__init__.py: -------------------------------------------------------------------------------- 
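The `Tokenizer` wrapper above adds `bos`/`eos` markers and enforces `max_length` itself, because it always calls the underlying HuggingFace tokenizer with `add_special_tokens=False`. A minimal usage sketch follows; it is only an illustration and assumes the snippet is run from the repository root so that `lit_gpt` is importable and the bundled `tokenizer/pints` directory resolves:

from pathlib import Path

from lit_gpt.tokenizer import Tokenizer

tokenizer = Tokenizer(Path('tokenizer/pints'))

# bos/eos are added manually by `encode`, so they appear exactly once.
ids = tokenizer.encode('Hello, world!', bos=True, eos=True)
assert ids[0].item() == tokenizer.bos_id
assert ids[-1].item() == tokenizer.eos_id

# With a max_length, the sequence is truncated and the final position is
# overwritten with eos, so the cap is never exceeded.
capped = tokenizer.encode('Hello, world!', bos=True, eos=True, max_length=4)
assert len(capped) <= 4 and capped[-1].item() == tokenizer.eos_id

# `skip_special_tokens` is only honoured by the HuggingFace backend.
print(tokenizer.decode(ids, skip_special_tokens=True))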
https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/prepare_dataset/__init__.py -------------------------------------------------------------------------------- /prepare_dataset/prepare_dataset_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from unittest import TestCase 4 | 5 | import numpy as np 6 | 7 | from lit_gpt.packed_dataset import PackedDatasetBuilder 8 | from lit_gpt.tokenizer import Tokenizer 9 | 10 | 11 | class TestTokenizer(TestCase): 12 | def setUp(self) -> None: 13 | self.test_string = 'Hello, world!' 14 | self.tokenizer = Tokenizer(Path('../tokenizer/pints')) 15 | 16 | def test_single_bos_eos(self): 17 | """ 18 | Test that the tokenizer correctly encodes a single string with bos and eos tokens. 19 | """ 20 | encoded = self.tokenizer.encode(self.test_string, eos=True, bos=True).tolist() 21 | # bos + encoded string + eos 22 | correct_single = ( 23 | [self.tokenizer.bos_id] 24 | + self.tokenizer.encode(self.test_string, eos=False, bos=False).tolist() 25 | + [self.tokenizer.eos_id] 26 | ) 27 | self.assertEqual(encoded, correct_single) 28 | 29 | def test_packed_with_bos_eos(self): 30 | """ 31 | Test that the packed dataset is correctly formatted with bos and eos tokens, even when multiple sequences are 32 | packed into it. 33 | """ 34 | 35 | chunk_size = 40 36 | pad_token = 0 37 | OUT_DIR = '../tokenizer/pints_chat/FOO' 38 | 39 | os.makedirs(OUT_DIR, exist_ok=True) 40 | 41 | training_dataset_builder = PackedDatasetBuilder( 42 | outdir=OUT_DIR, 43 | # Use process_id to differentiate builders 44 | prefix='BAR', 45 | chunk_size=chunk_size, 46 | pad_token=pad_token, # need to added pad tokens to llama 47 | dtype='auto', 48 | vocab_size=self.tokenizer.vocab_size, 49 | ) 50 | 51 | single_encoded = self.tokenizer.encode( 52 | self.test_string, eos=True, bos=True 53 | ).tolist() 54 | 55 | # Pack the string into dataset 56 | training_dataset_builder.add_array(np.array(single_encoded)) 57 | 58 | # Get list representation of packed dataset 59 | packed_single = training_dataset_builder._arr.tolist() 60 | 61 | # Get the correct packed representation 62 | correct_single = single_encoded + [pad_token] * ( 63 | chunk_size - len(single_encoded) 64 | ) 65 | 66 | # Check that the packed dataset is correct 67 | assert packed_single == correct_single 68 | 69 | # Pack the string into dataset again 70 | training_dataset_builder.add_array(np.array(single_encoded)) 71 | test_packed_double = training_dataset_builder._arr.tolist() 72 | single_encoded.extend(single_encoded) 73 | correct_packed_double = single_encoded + [ 74 | pad_token for i in range(chunk_size - len(single_encoded)) 75 | ] 76 | 77 | print(self.tokenizer.decode(training_dataset_builder._arr)) 78 | 79 | assert test_packed_double == correct_packed_double 80 | -------------------------------------------------------------------------------- /prepare_dataset/shuffle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from pathlib import Path 4 | from typing import Optional, List 5 | import pandas as pd 6 | import pyarrow as pa 7 | import pyarrow.parquet as pq 8 | from tqdm import tqdm 9 | import argparse 10 | 11 | def read_file(filepath: Path, parquet_columns: Optional[List[str]] = None) -> pa.Table: 12 | """Read a single Parquet file.""" 13 | contents = pq.read_table(filepath, columns=parquet_columns) 14 | return contents 15 | 16 | def 
read_and_shuffle_all_files(directory: str) -> pa.Table: 17 | """Read all Parquet files in the directory, concatenate and shuffle them.""" 18 | all_tables = [] 19 | file_list = [f for f in os.listdir(directory) if f.endswith('.parquet')] 20 | 21 | # Iterate over all files in the directory with a progress bar 22 | for filename in tqdm(file_list, desc="Reading and concatenating files", unit="file"): 23 | file_path = Path(directory) / filename 24 | file_contents = read_file(file_path) # Read the entire file 25 | all_tables.append(file_contents) 26 | 27 | print("Concatenate all the tables into one ...") 28 | combined_table = pa.concat_tables(all_tables) 29 | 30 | print("Convert to pandas for shuffling ...") 31 | df = combined_table.to_pandas() 32 | 33 | print("Shuffle the DataFrame ...") 34 | df = df.sample(frac=1).reset_index(drop=True) 35 | 36 | print("Convert back to Arrow Table ...") 37 | shuffled_table = pa.Table.from_pandas(df) 38 | 39 | return shuffled_table 40 | 41 | def save_to_parquet(table: pa.Table, batch_size: int, output_dir: str): 42 | """Save the shuffled table to Parquet files in batches.""" 43 | output_dir = Path(output_dir) 44 | output_dir.mkdir(exist_ok=True) 45 | 46 | print("Convert table to pandas DataFrame for batching ...") 47 | df = table.to_pandas() 48 | 49 | # Split the data into batches and save each batch as a separate parquet file with a progress bar 50 | for i in tqdm(range(0, len(df), batch_size), desc="Saving shuffled data", unit="batch"): 51 | batch_df = df.iloc[i:i + batch_size] 52 | batch_table = pa.Table.from_pandas(batch_df) 53 | output_file = output_dir / f'shuffled_part_{i//batch_size}.parquet' 54 | pq.write_table(batch_table, output_file) 55 | 56 | def main(directory_path: str, output_directory: str, batch_size: int): 57 | """Main function to shuffle and save Parquet files.""" 58 | print(f"Reading from directory: {directory_path}") 59 | print(f"Shuffling and saving to: {output_directory}") 60 | print(f"Using batch size: {batch_size}") 61 | 62 | shuffled_table = read_and_shuffle_all_files(directory_path) 63 | save_to_parquet(shuffled_table, batch_size, output_directory) 64 | 65 | if __name__ == "__main__": 66 | # Parse command line arguments 67 | parser = argparse.ArgumentParser(description="Shuffle Parquet files and save them in batches.") 68 | parser.add_argument("--directory_path", type=str, required=True, help="Path to the input directory containing Parquet files.") 69 | parser.add_argument("--output_directory", type=str, required=True, help="Path to the output directory where shuffled files will be saved.") 70 | parser.add_argument("--batch_size", type=int, default=10000, help="Batch size for saving Parquet files. 
Default is 10000.") 71 | 72 | args = parser.parse_args() 73 | 74 | main(args.directory_path, args.output_directory, args.batch_size) 75 | -------------------------------------------------------------------------------- /prepare_dataset/standard_parquet.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List, TypedDict 3 | from glob import glob 4 | from pandas import read_parquet 5 | from .preparer import DatasetPreparer 6 | 7 | # Name of dataset 8 | DATASET_NAME = 'pints-expository-prose-v1' 9 | 10 | 11 | class StandardParquet(TypedDict): 12 | source: List[str] 13 | source_id: List[str] 14 | text: List[str] 15 | 16 | 17 | class StandardParquetPreparer(DatasetPreparer): 18 | def collect_files(self, full_source_path: Path): 19 | return glob(f'{full_source_path}/**/*.parquet', recursive=True) 20 | 21 | def read_file( 22 | self, filepath: Path, parquet_columns: Optional[List[str]] = None 23 | ) -> StandardParquet: 24 | contents = read_parquet(filepath, engine='pyarrow', columns=parquet_columns) 25 | return contents 26 | 27 | def read_file_contents(self, filepath: Path) -> List[str]: 28 | data = self.read_file(filepath, ['text']) 29 | return data['text'] 30 | 31 | 32 | prepare_dataset = StandardParquetPreparer(DATASET_NAME) 33 | 34 | 35 | def main( 36 | source_path: Optional[Path] = None, 37 | tokenizer_path=prepare_dataset.tokenizer_path, 38 | destination_path=prepare_dataset.destination_path, 39 | chunk_size=prepare_dataset.chunk_size, 40 | percentage=prepare_dataset.percentage, 41 | train_val_split_ratio=prepare_dataset.train_val_split_ratio, 42 | max_cores: Optional[int] = None, 43 | ) -> None: 44 | prepare_dataset.prepare( 45 | source_path, 46 | tokenizer_path, 47 | destination_path, 48 | chunk_size, 49 | percentage, 50 | train_val_split_ratio, 51 | max_cores, 52 | ) 53 | 54 | 55 | if __name__ == '__main__': 56 | from jsonargparse import CLI 57 | 58 | CLI(main) 59 | -------------------------------------------------------------------------------- /pretrain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/pretrain/__init__.py -------------------------------------------------------------------------------- /pretrain/custom_types/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/pretrain/custom_types/__init__.py -------------------------------------------------------------------------------- /pretrain/custom_types/training_params.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict 2 | from pathlib import Path 3 | 4 | class TrainingParams(TypedDict): 5 | devices: int 6 | batch_size: int 7 | gradient_accumulation_steps: int 8 | warmup_iters: int 9 | max_iters: int 10 | lr_decay_iters: int 11 | log_iter_interval: int 12 | model_name: str 13 | out_dir: Path 14 | global_batch_size: int 15 | learning_rate: float 16 | micro_batch_size: int 17 | max_step: int 18 | warmup_steps: int 19 | log_step_interval: int 20 | eval_iters: int 21 | eval_step_interval: int 22 | save_step_interval: int 23 | weight_decay: float 24 | beta1: int 25 | beta2: int 26 | grad_clip: int 27 | decay_lr: bool 28 | min_lr: float 29 | 
-------------------------------------------------------------------------------- /pretrain/custom_types/training_state.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict, Any, Union, Tuple 2 | from lightning.fabric.wrappers import _FabricOptimizer 3 | from .training_params import TrainingParams 4 | 5 | 6 | class TrainingState(TypedDict): 7 | model: Any # fabric doesn't type this... 8 | optimizer: Union[_FabricOptimizer, Tuple[_FabricOptimizer, ...]] 9 | iter_num: int 10 | step_count: int 11 | epoch: int 12 | -------------------------------------------------------------------------------- /pretrain/main_save_checkpoint_test.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | from pathlib import Path 3 | from pretrain.main import save_checkpoint 4 | 5 | 6 | def test_save_checkpoint(): 7 | mocked_fabric = Mock() 8 | mocked_fabric.save = Mock() 9 | 10 | mocked_state = { 11 | 'step_count': 10, # divisible by `save_step_interval` 12 | 'iter_num': 999, 13 | } 14 | 15 | out_dir = Path('foo') 16 | 17 | save_checkpoint( 18 | fabric=mocked_fabric, 19 | is_accumulating=False, 20 | state=mocked_state, 21 | save_step_interval=5, 22 | out_dir=out_dir, 23 | ) 24 | 25 | mocked_fabric.save.assert_called_once_with( 26 | out_dir / f'step-{mocked_state["step_count"]:08d}/lit_model.pth', mocked_state 27 | ) 28 | 29 | 30 | def test_save_checkpoint_should_not_save_due_not_arrived_at_save_step(): 31 | mocked_fabric = Mock() 32 | mocked_fabric.save = Mock() 33 | 34 | mocked_state = { 35 | 'step_count': 9, # NOT divisible by `save_step_interval` 36 | 'iter_num': 999, 37 | } 38 | 39 | out_dir = Path('foo') 40 | 41 | save_checkpoint( 42 | fabric=mocked_fabric, 43 | is_accumulating=False, 44 | state=mocked_state, 45 | save_step_interval=5, 46 | out_dir=out_dir, 47 | ) 48 | 49 | mocked_fabric.save.assert_not_called() 50 | 51 | 52 | def test_save_checkpoint_should_not_save_due_still_accumulating(): 53 | mocked_fabric = Mock() 54 | mocked_fabric.save = Mock() 55 | 56 | mocked_state = { 57 | 'step_count': 10, # divisible by `save_step_interval` 58 | 'iter_num': 999, 59 | } 60 | 61 | out_dir = Path('foo') 62 | 63 | save_checkpoint( 64 | fabric=mocked_fabric, 65 | is_accumulating=True, 66 | state=mocked_state, 67 | save_step_interval=5, 68 | out_dir=out_dir, 69 | ) 70 | 71 | mocked_fabric.save.assert_not_called() 72 | 73 | 74 | def test_save_checkpoint_should_save_if_is_last_iteration(): 75 | mocked_fabric = Mock() 76 | mocked_fabric.save = Mock() 77 | 78 | mocked_state = { 79 | 'step_count': 8888, 80 | 'iter_num': 999999, 81 | } 82 | 83 | out_dir = Path('foo') 84 | 85 | save_checkpoint( 86 | fabric=mocked_fabric, 87 | is_accumulating=False, 88 | state=mocked_state, 89 | save_step_interval=9999999, 90 | out_dir=out_dir, 91 | is_last=True, 92 | ) 93 | 94 | mocked_fabric.save.assert_called_once_with( 95 | out_dir / f'step-{mocked_state["step_count"]:08d}/lit_model.pth', mocked_state 96 | ) 97 | -------------------------------------------------------------------------------- /pretrain/main_train_test.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | from pathlib import Path 3 | from pretrain.main import save_checkpoint 4 | 5 | 6 | def test_save_checkpoint(): 7 | mocked_fabric = Mock() 8 | mocked_fabric.save = Mock() 9 | 10 | mocked_state = { 11 | 'step_count': 10, # divisible by `save_step_interval` 12 | 'iter_num': 
999, 13 | } 14 | 15 | out_dir = Path('foo') 16 | 17 | save_checkpoint( 18 | fabric=mocked_fabric, 19 | is_accumulating=False, 20 | state=mocked_state, 21 | save_step_interval=5, 22 | out_dir=out_dir, 23 | ) 24 | 25 | mocked_fabric.save.assert_called_once_with( 26 | out_dir / f'step-{mocked_state["step_count"]:08d}/lit_model.pth', mocked_state 27 | ) 28 | 29 | 30 | def test_save_checkpoint_should_not_save_due_not_arrived_at_save_step(): 31 | mocked_fabric = Mock() 32 | mocked_fabric.save = Mock() 33 | 34 | mocked_state = { 35 | 'step_count': 9, # NOT divisible by `save_step_interval` 36 | 'iter_num': 999, 37 | } 38 | 39 | out_dir = Path('foo') 40 | 41 | save_checkpoint( 42 | fabric=mocked_fabric, 43 | is_accumulating=False, 44 | state=mocked_state, 45 | save_step_interval=5, 46 | out_dir=out_dir, 47 | ) 48 | 49 | mocked_fabric.save.assert_not_called() 50 | 51 | 52 | def test_save_checkpoint_should_not_save_due_still_accumulating(): 53 | mocked_fabric = Mock() 54 | mocked_fabric.save = Mock() 55 | 56 | mocked_state = { 57 | 'step_count': 10, # divisible by `save_step_interval` 58 | 'iter_num': 999, 59 | } 60 | 61 | out_dir = Path('foo') 62 | 63 | save_checkpoint( 64 | fabric=mocked_fabric, 65 | is_accumulating=True, 66 | state=mocked_state, 67 | save_step_interval=5, 68 | out_dir=out_dir, 69 | ) 70 | 71 | mocked_fabric.save.assert_not_called() 72 | 73 | 74 | def test_save_checkpoint_should_save_if_is_last_iteration(): 75 | mocked_fabric = Mock() 76 | mocked_fabric.save = Mock() 77 | 78 | mocked_state = { 79 | 'step_count': 8888, 80 | 'iter_num': 999999, 81 | } 82 | 83 | out_dir = Path('foo') 84 | 85 | save_checkpoint( 86 | fabric=mocked_fabric, 87 | is_accumulating=False, 88 | state=mocked_state, 89 | save_step_interval=9999999, 90 | out_dir=out_dir, 91 | is_last=True, 92 | ) 93 | 94 | mocked_fabric.save.assert_called_once_with( 95 | out_dir / f'step-{mocked_state["step_count"]:08d}/lit_model.pth', mocked_state 96 | ) 97 | -------------------------------------------------------------------------------- /pretrain/requirements.txt: -------------------------------------------------------------------------------- 1 | rotary_emb @ git+https://github.com/Dao-AILab/flash-attention.git@da11d1b8535c489bb97fd16a253f38217b7d6bd5#subdirectory=csrc/rotary 2 | xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@da11d1b8535c489bb97fd16a253f38217b7d6bd5#subdirectory=csrc/xentropy 3 | dropout_layer_norm @ git+https://github.com/Dao-AILab/flash-attention.git@da11d1b8535c489bb97fd16a253f38217b7d6bd5#subdirectory=csrc/layer_norm 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | extend-select = ["I", "Q"] 3 | 4 | [tool.ruff.lint.flake8-quotes] 5 | inline-quotes = "single" 6 | 7 | [tool.ruff.format] 8 | quote-style = "single" 9 | -------------------------------------------------------------------------------- /requirements.test.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | mock -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonargparse==4.32.0 2 | lightning==2.3.3 3 | ninja==1.11.1.1 4 | numpy==1.26.4 5 | pandas==2.2.2 6 | pyarrow==17.0.0 7 | PyYAML==6.0.1 8 | sentencepiece==0.2.0 9 | tokenizers==0.19.1 10 | torch==2.3.1 11 | tqdm==4.66.5 12 | wandb==0.17.5 13 | xformers==0.0.27 # 0.0.27.post2 requires torch==2.4.0, which is incompatible.
14 | zstandard==0.23.0 15 | -------------------------------------------------------------------------------- /tokenizer/README.md: -------------------------------------------------------------------------------- 1 | # Pints tokenizer 2 | 3 | This tokenizer is based on the [**Mistral**](https://huggingface.co/mistralai/Mistral-7B-v0.1) tokenizer, but with the following modifications made to `tokenizer.json`: 4 | 5 | ## Pad Tokens 6 | 7 | Tokenizers from foundational models commonly lack a padding token, which is often necessary for downstream use cases such as batch processing or model alignment, where sequences need to be padded to equal length. This results in the need to add the padding token retrospectively, which introduces three issues. 8 | 9 | Firstly, it alters the vocabulary size and, consequently, the dimension of the language model head. This alteration requires additional coding logic to extrapolate the weights (embedding layers) of the model head to the new vocabulary size. 10 | 11 | Secondly, if the new vocabulary size is not divisible by 64, there could be a reduction in model throughput of up to 25% (mentioned by Andrej Karpathy [here](https://twitter.com/karpathy/status/1621578354024677377?lang=en)). The vocabulary size could be arbitrarily extrapolated to the nearest multiple of 64, which again requires additional coding logic. 12 | 13 | Thirdly, the absence of a padding token can lead to the common mistake of using the end-of-sequence token as a substitute, which gives the model an inaccurate signal of when to produce the end-of-sequence token and stop generating. Another common workaround is to use the unknown (`unk`) token, which is also fundamentally incorrect. 14 | 15 | Therefore, considering the near-universal necessity of a padding token and the potential downstream inconveniences and pitfalls, we decided to preemptively include the padding token `<|pad|>` and extended the vocabulary size to 32,064 (from Mistral’s original 32,000). The model is pre-trained with this extended tokenizer from the start. 16 | 17 | ## Common chat template tokens 18 | 19 | As part of extending the vocabulary size to accommodate the padding token, we also added commonly-used chat template tokens. This makes the model versatile and ready for instruct fine-tuning *out-of-the-box*. The table below lists the chat template tokens added to our tokenizer. 20 | 21 | | Template | Tokens | 22 | | ---------------------- | ---------- | 23 | | **OpenAI ChatML** | `<\|im_start\|>` | 24 | | | `<\|im_end\|>` | 25 | | | | 26 | | **Llama-2** | `[INST]` | 27 | | | `[/INST]` | 28 | | | `<<SYS>>` | 29 | | | `<</SYS>>` | 30 | | | | 31 | | **Llama-3** | `<\|begin_of_text\|>` | 32 | | | `<\|start_header_id\|>` | 33 | | | `<\|end_header_id\|>` | 34 | | | `<\|eot_id\|>` | 35 | | | | 36 | | **OpenChat** | `<\|end_of_turn\|>` | 37 | | | | 38 | | **Huggingface Zephyr** | `<\|user\|>` | 39 | | | `<\|system\|>` | 40 | | | `<\|assistant\|>` | 41 | 42 | ## Reserved token spaces for future customizability 43 | 44 | The tokenizer contains 49 remaining empty (`<|reserved_n|>`) token spaces. These can easily be replaced with new tokens, which allows for easy experimentation and fine-tuning on new chat templates. In our case, we have replaced some of the empty tokens with Llama-3 tokens.
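To sanity-check these modifications locally, the extended tokenizer can be loaded with `transformers`. This is a quick sketch, assuming `transformers` is installed and the snippet is run from the repository root; the expected ids follow `tokenizer/pints/added_tokens.json`:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('tokenizer/pints')

# The vocabulary is already extended and padded to a multiple of 64.
print(len(tokenizer))        # 32064
print(len(tokenizer) % 64)   # 0

# The padding and chat-template tokens are single tokens with fixed ids.
print(tokenizer.convert_tokens_to_ids('<|pad|>'))       # 32001
print(tokenizer.convert_tokens_to_ids('<|im_start|>'))  # 32002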
45 | -------------------------------------------------------------------------------- /tokenizer/llama2/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/tokenizer/llama2/tokenizer.model -------------------------------------------------------------------------------- /tokenizer/llama2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "clean_up_tokenization_spaces": false, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "", 16 | "lstrip": false, 17 | "normalized": false, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "legacy": false, 22 | "model_max_length": 1000000000000000019884624838656, 23 | "pad_token": null, 24 | "padding_side": "right", 25 | "sp_model_kwargs": {}, 26 | "tokenizer_class": "LlamaTokenizer", 27 | "unk_token": { 28 | "__type": "AddedToken", 29 | "content": "", 30 | "lstrip": false, 31 | "normalized": false, 32 | "rstrip": false, 33 | "single_word": false 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /tokenizer/llama_vs_mistral/dataset_compression_cmp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script used to calculate for the generated token counts of various datasets using the Mistral tokenizer. 3 | """ 4 | 5 | from transformers import AutoTokenizer, PreTrainedTokenizer 6 | from datasets import Dataset, load_dataset 7 | from typing import List, Dict, Callable, Any 8 | 9 | import pandas 10 | 11 | mistral_tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained('tokenizer/mistral') 12 | 13 | datasets = { 14 | 'capybara': load_dataset('LDJnr/Capybara', num_proc=8), 15 | 'llama_instruct': load_dataset('togethercomputer/llama-instruct', num_proc=8), 16 | 'meta_math': load_dataset('meta-math/MetaMathQA', num_proc=8), 17 | 'slim_orca': load_dataset('Open-Orca/SlimOrca-Dedup', num_proc=8), 18 | 'ultrachat_200k': load_dataset('HuggingFaceH4/ultrachat_200k', num_proc=8), 19 | 'deita_10k': load_dataset('HuggingFaceH4/deita-10k-v0-sft', num_proc=8), 20 | 'wizardlm_evol': load_dataset('Leon-Leee/Wizardlm_Evol_Instruct_v2_196K_backuped', num_proc=8) 21 | } 22 | 23 | def encode_capybara(batch: Dict[str, List[Dict[str, str]]]) -> Dict[str, Any]: 24 | """ 25 | Encode the Capybara dataset by joining 'input' and 'output' fields of each message in the conversation. 26 | """ 27 | joined_texts = [' '.join([message['input'] + ' ' + message['output'] for message in conversation]) for conversation in batch['conversation']] 28 | batch['token_counts'] = [len(mistral_tokenizer(text)['input_ids']) for text in joined_texts] 29 | return batch 30 | 31 | def encode_llama_instruct(batch: Dict[str, List[str]]) -> Dict[str, Any]: 32 | """ 33 | Encode the Llama Instruct dataset by tokenizing the 'text' field. 34 | """ 35 | batch['token_counts'] = [len(mistral_tokenizer(text)['input_ids']) for text in batch['text']] 36 | return batch 37 | 38 | def encode_meta_math(batch: Dict[str, List[str]]) -> Dict[str, Any]: 39 | """ 40 | Encode the Meta Math dataset by combining 'query' and 'response' fields. 
41 | """ 42 | combined_texts = [query + ' ' + response for query, response in zip(batch['query'], batch['response'])] 43 | batch['token_counts'] = [len(mistral_tokenizer(text)['input_ids']) for text in combined_texts] 44 | return batch 45 | 46 | def encode_conversations(batch: Dict[str, List[Dict[str, str]]]) -> Dict[str, Any]: 47 | """ 48 | Encode datasets with conversation format by joining 'value' fields of each message. 49 | """ 50 | joined_texts = [' '.join([message['value'] for message in conversation]) for conversation in batch['conversations']] 51 | batch['token_counts'] = [len(mistral_tokenizer(text)['input_ids']) for text in joined_texts] 52 | return batch 53 | 54 | def encode_messages(batch: Dict[str, List[Dict[str, str]]]) -> Dict[str, Any]: 55 | """ 56 | Encode datasets with message format by joining 'content' fields of each message. 57 | """ 58 | joined_texts = [' '.join([message['content'] for message in messages]) for messages in batch['messages']] 59 | batch['token_counts'] = [len(mistral_tokenizer(text)['input_ids']) for text in joined_texts] 60 | return batch 61 | 62 | def process_dataset(name: str, dataset: Dataset, encode_fn: Callable[[Dict[str, Any]], Dict[str, Any]]) -> int: 63 | """ 64 | Process the dataset using the provided encoding function and return the total token count. 65 | 66 | Args: 67 | name (str): The name of the dataset. 68 | dataset (Dataset): The dataset to be processed. 69 | encode_fn (Callable): The encoding function to apply to the dataset. 70 | 71 | Returns: 72 | int: The total token count for the dataset. 73 | """ 74 | processed_data = dataset.map(encode_fn, batched=True, num_proc=4) 75 | dataframe: pandas.DataFrame = processed_data.to_pandas() 76 | token_count_sum = dataframe['token_counts'].sum() 77 | print(f'{name}: {token_count_sum}') 78 | return token_count_sum 79 | 80 | if __name__ == '__main__': 81 | total = 0 82 | total += process_dataset('LDJnr/Capybara', datasets['capybara']['train'], encode_capybara) 83 | total += process_dataset('togethercomputer/llama-instruct', datasets['llama_instruct']['train'], encode_llama_instruct) 84 | total += process_dataset('meta-math/MetaMathQA', datasets['meta_math']['train'], encode_meta_math) 85 | total += process_dataset('Open-Orca/SlimOrca-Dedup', datasets['slim_orca']['train'], encode_conversations) 86 | total += process_dataset('HuggingFaceH4/ultrachat_200k', datasets['ultrachat_200k']['train_sft'], encode_messages) 87 | total += process_dataset('HuggingFaceH4/deita-10k-v0-sft', datasets['deita_10k']['train_sft'], encode_messages) 88 | total += process_dataset('Leon-Leee/Wizardlm_Evol_Instruct_v2_196K_backuped', datasets['wizardlm_evol']['train'], encode_conversations) 89 | print('total tokens:', total) 90 | -------------------------------------------------------------------------------- /tokenizer/llama_vs_mistral/mistral_llama_tokenizer_cmp.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, PreTrainedTokenizer 2 | from datasets import IterableDataset, load_dataset 3 | from typing import Tuple 4 | 5 | import multiprocessing 6 | import time 7 | 8 | # Constants 9 | NUM_CPU = multiprocessing.cpu_count() - 2 # Use available CPUs minus two for processing 10 | TOTAL_ROW = 35000000 # Total number of rows to process 11 | BATCH_SIZE = TOTAL_ROW // NUM_CPU # Calculate batch size based on number of CPUs 12 | 13 | # Load dataset with streaming mode 14 | falcon_dataset = load_dataset("tiiuae/falcon-refinedweb", split='train', 
streaming=True) 15 | 16 | # Initialize tokenizers 17 | mistral_tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained('tokenizer/mistral') 18 | llama_tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained('tokenizer/llama2') 19 | 20 | def encode(iterable_dataset: IterableDataset) -> Tuple[int, int]: 21 | """Encodes data using Mistral and Llama tokenizers.""" 22 | mapped_dataset = iterable_dataset.map(lambda data: { 23 | 'token_count_mistral': len(mistral_tokenizer(data['content'])['input_ids']), 24 | 'token_count_llama': len(llama_tokenizer(data['content'])['input_ids']), 25 | 'word_count': len(data['content'].split(' ')), 26 | 'character_count': len(data['content']) 27 | }) 28 | 29 | # Iterate and sum the token counts for the current batch 30 | mistral_tokens_sum, llama_tokens_sum = 0, 0 31 | word_count_sum, character_count_sum = 0, 0 32 | for row in mapped_dataset: 33 | mistral_tokens_sum += row['token_count_mistral'] 34 | llama_tokens_sum += row['token_count_llama'] 35 | word_count_sum += row['word_count'] 36 | character_count_sum += row['character_count'] 37 | 38 | print("One batch done!") 39 | return mistral_tokens_sum, llama_tokens_sum, word_count_sum, character_count_sum 40 | 41 | if __name__ == '__main__': 42 | start_time = time.time() 43 | 44 | # Split the dataset into chunks for each CPU 45 | batch_data = [falcon_dataset.skip(batch_count * BATCH_SIZE).take(BATCH_SIZE) for batch_count in range(NUM_CPU)] 46 | 47 | # Process batches in parallel 48 | total_mistral_tokens, total_llama_tokens = 0, 0 49 | total_word_counts, total_character_counts = 0, 0 50 | with multiprocessing.Pool(processes=NUM_CPU) as pool: 51 | aggregated_counts = pool.map(encode, batch_data) 52 | for mistral_tokens, llama_tokens, word_counts, character_counts in aggregated_counts: 53 | total_mistral_tokens += mistral_tokens 54 | total_llama_tokens += llama_tokens 55 | total_word_counts += word_counts 56 | total_character_counts += character_counts 57 | 58 | print(f"Total tokens used for Mistral Tokenizer: {total_mistral_tokens}") 59 | print(f"Total tokens used for Llama Tokenizer: {total_llama_tokens}") 60 | print(f"Ratio: {total_mistral_tokens / total_llama_tokens}\n") 61 | 62 | print(f"Total Words tokenized: {total_word_counts}") 63 | print(f"Total Characters tokenized: {total_character_counts}\n") 64 | 65 | end_time = time.time() 66 | elapsed_time = end_time - start_time 67 | hours, rem = divmod(elapsed_time, 3600) 68 | minutes, seconds = divmod(rem, 60) 69 | print(f"Process took {int(hours)}h {int(minutes)}m {seconds:.4f}s to complete.") 70 | -------------------------------------------------------------------------------- /tokenizer/mistral_v1/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MistralForCausalLM" 4 | ], 5 | "bos_token_id": 1, 6 | "eos_token_id": 2, 7 | "hidden_act": "silu", 8 | "hidden_size": 4096, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 14336, 11 | "max_position_embeddings": 32768, 12 | "model_type": "mistral", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 32, 15 | "num_key_value_heads": 8, 16 | "rms_norm_eps": 1e-05, 17 | "rope_theta": 10000.0, 18 | "sliding_window": 4096, 19 | "tie_word_embeddings": false, 20 | "torch_dtype": "bfloat16", 21 | "transformers_version": "4.34.0.dev0", 22 | "use_cache": true, 23 | "vocab_size": 32000 24 | } 25 | -------------------------------------------------------------------------------- /tokenizer/mistral_v1/generation_config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 1, 4 | "eos_token_id": 2, 5 | "transformers_version": "4.34.0.dev0" 6 | } 7 | -------------------------------------------------------------------------------- /tokenizer/mistral_v1/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": "", 3 | "eos_token": "", 4 | "unk_token": "" 5 | } 6 | -------------------------------------------------------------------------------- /tokenizer/mistral_v1/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/tokenizer/mistral_v1/tokenizer.model -------------------------------------------------------------------------------- /tokenizer/mistral_v1/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "added_tokens_decoder": { 5 | "0": { 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false, 11 | "special": true 12 | }, 13 | "1": { 14 | "content": "", 15 | "lstrip": false, 16 | "normalized": false, 17 | "rstrip": false, 18 | "single_word": false, 19 | "special": true 20 | }, 21 | "2": { 22 | "content": "", 23 | "lstrip": false, 24 | "normalized": false, 25 | "rstrip": false, 26 | "single_word": false, 27 | "special": true 28 | } 29 | }, 30 | "additional_special_tokens": [], 31 | "bos_token": "", 32 | "clean_up_tokenization_spaces": false, 33 | "eos_token": "", 34 | "legacy": true, 35 | "model_max_length": 1000000000000000019884624838656, 36 | "pad_token": null, 37 | "sp_model_kwargs": {}, 38 | "spaces_between_special_tokens": false, 39 | "tokenizer_class": "LlamaTokenizer", 40 | "unk_token": "", 41 | "use_default_system_prompt": false 42 | } 43 | -------------------------------------------------------------------------------- /tokenizer/modifier/check.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from pathlib import Path 3 | 4 | NEW_TOKENIZER_FOLDER = Path('./new_tokenizer').absolute().resolve() 5 | MODEL_FILE = NEW_TOKENIZER_FOLDER / 'tokenizer.model' 6 | 7 | new_tokenizer = AutoTokenizer.from_pretrained(NEW_TOKENIZER_FOLDER) 8 | print(new_tokenizer) 9 | 10 | print(new_tokenizer.encode('<|im_start|>user\nHello<|im_end|><|pad|>')) 11 | -------------------------------------------------------------------------------- /tokenizer/modifier/compare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Checks the differences between two tokenizers. 3 | Useful when you want to confirm what changes where made, or what's the difference between them. 
4 | """ 5 | 6 | from transformers import AutoTokenizer 7 | 8 | 9 | def main(path1: str, path2: str, name1='Tokenizer1', name2='Tokenizer2'): 10 | tokenizer1 = AutoTokenizer.from_pretrained(path1) 11 | vocabSize1 = len(tokenizer1) 12 | 13 | print('=' * 80) 14 | print(f'Info of tokenizer [{name1}]:') 15 | print('=' * 80) 16 | print(f'Vocab size of [{name1}]: [{vocabSize1}])') 17 | print('-' * 80) 18 | print(tokenizer1) 19 | print('-' * 80) 20 | 21 | lastTenTokenizer1Tokens = [] 22 | for i in range(vocabSize1 - 10, vocabSize1): 23 | lastTenTokenizer1Tokens.append(tokenizer1.convert_ids_to_tokens(i)) 24 | 25 | print(f'Last 10 tokens: {lastTenTokenizer1Tokens}') 26 | print('=' * 80) 27 | 28 | tokenizer2 = AutoTokenizer.from_pretrained(path2) 29 | vocabSize2 = len(tokenizer2) 30 | 31 | print('\n\n') 32 | print('=' * 80) 33 | print(f'Info of tokenizer [{name2}]:') 34 | print('=' * 80) 35 | print(f'Vocab size of [{name2}]: [{vocabSize2}])') 36 | print('-' * 80) 37 | print(tokenizer2) 38 | print('-' * 80) 39 | 40 | lastTenTokenizer2Tokens = [] 41 | 42 | for i in range(vocabSize2 - 10, vocabSize2): 43 | lastTenTokenizer2Tokens.append(tokenizer2.convert_ids_to_tokens(i)) 44 | 45 | print(f'Last 10 tokens: {lastTenTokenizer2Tokens}') 46 | print('=' * 80) 47 | 48 | longerRange = max(vocabSize1, vocabSize2) 49 | 50 | print('\n\n') 51 | print('Diffing the tokenizers...') 52 | print('\n') 53 | 54 | differencesCount = 0 55 | for i in range(longerRange): 56 | token1 = tokenizer1._convert_id_to_token(i) 57 | token2 = tokenizer2._convert_id_to_token(i) 58 | 59 | if token1 == token2: 60 | continue 61 | 62 | differencesCount += 0 63 | 64 | message = f'At position [{i}],' 65 | 66 | if token1 is not False: 67 | message += f' [{name1}] has [{token1}].' 68 | else: 69 | message += f' [{name1}] does not have a token.' 70 | 71 | if token2 is not False: 72 | message += f' [{name2}] has [{token2}].' 73 | else: 74 | message += f' [{name2}] does not have a token.' 75 | 76 | print(message) 77 | 78 | print('\nTotal differences found:', differencesCount) 79 | 80 | 81 | if __name__ == '__main__': 82 | from jsonargparse import CLI 83 | 84 | CLI(main, as_positional=False) 85 | -------------------------------------------------------------------------------- /tokenizer/modifier/investigate_tokenizer_oddity.py: -------------------------------------------------------------------------------- 1 | """ 2 | Investigate tokenizer oddity. 3 | 4 | Why does "mysql" and " mysql" get encoded to _mysql[24976] that has lower score? 5 | """ 6 | 7 | from transformers.convert_slow_tokenizer import import_protobuf 8 | from sentencepiece import SentencePieceProcessor 9 | 10 | MODEL_FILE = '../mistral/tokenizer.model' 11 | 12 | # We modify the tokenizer.model instead of tokenizer.json 13 | # Because we want to create a base tokenizer.model, which reflects the base model that is used for pretraining. 
14 | tokenizer = SentencePieceProcessor(model_file=MODEL_FILE) 15 | 16 | print('=' * 80) 17 | print('Info of original tokenizer:') 18 | print('=' * 80) 19 | print('Vocab size:', len(tokenizer)) 20 | print('-' * 80) 21 | 22 | lastTenTokens = [] 23 | 24 | for i in range(len(tokenizer) - 10, len(tokenizer)): 25 | lastTenTokens.append(tokenizer.IdToPiece(i)) 26 | 27 | print(f'Last 10 tokens: {lastTenTokens}') 28 | print('=' * 80) 29 | 30 | modelProtobuf2 = import_protobuf() 31 | 32 | model = modelProtobuf2.ModelProto() 33 | model.ParseFromString(open(MODEL_FILE, 'rb').read()) 34 | 35 | 36 | print('_not [459] is:\n', model.pieces[459]) 37 | print('not [1478] is:\n', model.pieces[1478]) 38 | print('[not] encoded to: ', tokenizer.Encode('not')) 39 | print('[ not] encoded to: ', tokenizer.Encode(' not')) 40 | print('[not something] encoded to: ', tokenizer.Encode('not something')) 41 | print('[ something] encoded to: ', tokenizer.Encode('"not" something')) 42 | print('[I not happy] encoded to: ', tokenizer.Encode('I not happy')) 43 | 44 | print('\n\n') 45 | 46 | print('_mysql [24976] is:\n', model.pieces[24976]) 47 | print('mysql [20235] is:\n ', model.pieces[20235]) 48 | print('[mysql] encoded to: ', tokenizer.Encode('mysql')) 49 | print('[ mysql] encoded to: ', tokenizer.Encode(' mysql')) 50 | print('[mysql something] encoded to: ', tokenizer.Encode('mysql something')) 51 | print('[ something] encoded to: ', tokenizer.Encode('"mysql" something')) 52 | print('[I like mysql] encoded to: ', tokenizer.Encode('I like mysql')) 53 | -------------------------------------------------------------------------------- /tokenizer/pints/added_tokens.json: -------------------------------------------------------------------------------- 1 | { 2 | "": 32003, 3 | "<>": 32007, 4 | "<>": 32006, 5 | "<|assistant|>": 32010, 6 | "<|begin_of_text|>": 32011, 7 | "<|end_header_id|>": 32013, 8 | "<|end_of_turn|>": 32000, 9 | "<|eot_id|>": 32014, 10 | "<|im_start|>": 32002, 11 | "<|pad|>": 32001, 12 | "<|reserved_0|>": 32015, 13 | "<|reserved_10|>": 32025, 14 | "<|reserved_11|>": 32026, 15 | "<|reserved_12|>": 32027, 16 | "<|reserved_13|>": 32028, 17 | "<|reserved_14|>": 32029, 18 | "<|reserved_15|>": 32030, 19 | "<|reserved_16|>": 32031, 20 | "<|reserved_17|>": 32032, 21 | "<|reserved_18|>": 32033, 22 | "<|reserved_19|>": 32034, 23 | "<|reserved_1|>": 32016, 24 | "<|reserved_20|>": 32035, 25 | "<|reserved_21|>": 32036, 26 | "<|reserved_22|>": 32037, 27 | "<|reserved_23|>": 32038, 28 | "<|reserved_24|>": 32039, 29 | "<|reserved_25|>": 32040, 30 | "<|reserved_26|>": 32041, 31 | "<|reserved_27|>": 32042, 32 | "<|reserved_28|>": 32043, 33 | "<|reserved_29|>": 32044, 34 | "<|reserved_2|>": 32017, 35 | "<|reserved_30|>": 32045, 36 | "<|reserved_31|>": 32046, 37 | "<|reserved_32|>": 32047, 38 | "<|reserved_33|>": 32048, 39 | "<|reserved_34|>": 32049, 40 | "<|reserved_35|>": 32050, 41 | "<|reserved_36|>": 32051, 42 | "<|reserved_37|>": 32052, 43 | "<|reserved_38|>": 32053, 44 | "<|reserved_39|>": 32054, 45 | "<|reserved_3|>": 32018, 46 | "<|reserved_40|>": 32055, 47 | "<|reserved_41|>": 32056, 48 | "<|reserved_42|>": 32057, 49 | "<|reserved_43|>": 32058, 50 | "<|reserved_44|>": 32059, 51 | "<|reserved_45|>": 32060, 52 | "<|reserved_46|>": 32061, 53 | "<|reserved_47|>": 32062, 54 | "<|reserved_48|>": 32063, 55 | "<|reserved_4|>": 32019, 56 | "<|reserved_5|>": 32020, 57 | "<|reserved_6|>": 32021, 58 | "<|reserved_7|>": 32022, 59 | "<|reserved_8|>": 32023, 60 | "<|reserved_9|>": 32024, 61 | "<|start_header_id|>": 32012, 62 | 
"<|system|>": 32009, 63 | "<|user|>": 32008, 64 | "[/INST]": 32005, 65 | "[INST]": 32004 66 | } 67 | -------------------------------------------------------------------------------- /tokenizer/pints/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/tokenizer/pints/tokenizer.model -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pints-AI/1.5-Pints/2a0cb6261636411b77bb87c777c1ea5449552071/util/__init__.py -------------------------------------------------------------------------------- /util/splitter.py: -------------------------------------------------------------------------------- 1 | class Splitter: 2 | 3 | """ 4 | A class to manage splitting based on a specified ratio, indicating the frequency of False outcomes. 5 | A ratio of 0.9 will return 1 True for every 10 steps. 6 | """ 7 | 8 | ratio: float 9 | numerator = 0 10 | denominator = 0 11 | 12 | def __init__(self, ratio: float) -> None: 13 | self.ratio = ratio 14 | 15 | def should_split(self): 16 | self.denominator += 1 17 | if self.numerator / self.denominator < self.ratio: 18 | self.numerator += 1 19 | return False 20 | return True --------------------------------------------------------------------------------