├── .gitignore ├── NOTICE ├── README.md ├── chatllama ├── __init__.py ├── langchain_modules │ ├── __init__.py │ └── prompt_templates.py ├── llama_model.py └── rlhf │ ├── __init__.py │ ├── actor.py │ ├── config.py │ ├── config.yaml │ ├── reward.py │ ├── test.py │ ├── trainer.py │ └── utils.py ├── generate_dataset.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | .idea 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # MacOS DS_Store 133 | .DS_Store 134 | 135 | # Pickle folder 136 | .pkl_memoize_py3 137 | 138 | # Folder where optimized models are stored 139 | optimized_model 140 | 141 | # Config file for tests coverage 142 | .coveragerc 143 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | The code is originally from [nebuly-ai](https://github.com/nebuly-ai/nebullvm/tree/main/apps/accelerate/chatllama) with some changes. More changes will follow up soon. 
And the original license link is here: https://github.com/nebuly-ai/nebullvm/blob/main/LICENSE 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChatLLaMA 2 | 3 | > 📢 Open source implementation for LLaMA-based ChatGPT runnable in a single GPU. 15x faster training process than `ChatGPT` 4 | 5 | - 🔥 Please check [`pyllama`](https://github.com/juncongmoo/pyllama) for `LLaMA` installation and `single GPU inference` setup. 6 | - 🔥 To train ChatGPT in 5 mins - [minichatgpt](https://github.com/juncongmoo/minichatgpt) 7 | 8 | 9 | Meta has recently released LLaMA, a collection of foundational large language models ranging from 7 to 65 billion parameters. 10 | LLaMA is creating a lot of excitement because it is smaller than GPT-3 but has better performance. For example, LLaMA's 13B architecture outperforms GPT-3 despite being 10 times smaller. This new collection of fundamental models opens the door to faster inference performance and chatGPT-like real-time assistants, while being cost-effective and running on a single GPU. 11 | 12 | However, LLaMA was not fine-tuned for instruction task with a Reinforcement Learning from Human Feedback (RLHF) training process. 13 | 14 | The good news is that we introduce `ChatLLaMA`, the first open source implementation of LLaMA based on RLHF: 15 | 16 | - A complete open source implementation that enables you to build a ChatGPT-style service based on pre-trained LLaMA models. 17 | - Compared to the original ChatGPT, the training process and single-GPU inference are much faster and cheaper by taking advantage of the smaller size of LLaMA architectures. 18 | - ChatLLaMA has built-in support for DeepSpeed ZERO to speedup the fine-tuning process. 19 | - The library also supports all LLaMA model architectures (7B, 13B, 33B, 65B), so that you can fine-tune the model according to your preferences for training time and inference performance. 20 | 21 | 22 | Screen Shot 2023-02-26 at 10 56 13 PM 23 | 24 | Image from [OpenAI’s blog](https://openai.com/blog/chatgpt). 25 | 26 | 27 | # Installation 28 | 29 | ``` 30 | pip install chatllama 31 | ``` 32 | 33 | 34 | # Get started with ChatLLaMA 35 | 36 | > :warning: Please note this code represents the algorithmic implementation for RLHF training process of LLaMA and does not contain the model weights. To access the model weights, you need to apply to Meta's [form](https://forms.gle/jk851eBVbX1m5TAv5). 37 | 38 | ChatLLaMA allows you to easily train LLaMA-based architectures in a similar way to ChatGPT, using RLHF. 39 | For example, below is the code to start the training in the case of ChatLLaMA 7B. 40 | 41 | ```python 42 | from chatllama.rlhf.trainer import RLTrainer 43 | from chatllama.rlhf.config import Config 44 | 45 | path = "path_to_config_file.yaml" 46 | config = Config(path=path) 47 | trainer = RLTrainer(config.trainer) 48 | trainer.distillate() 49 | trainer.train() 50 | trainer.training_stats.plot() 51 | ``` 52 | 53 | Note that you should provide Meta's original weights and your custom dataset before starting the fine-tuning process. Alternatively, you can generate your own dataset using LangChain's agents. 
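Whichever way you obtain it, the dataset is a plain JSON file. Judging from the `ActorDataset` and `RewardDataset` classes in `chatllama/rlhf`, each record needs a `user_input` and a `completion` field, and the reward dataset additionally uses a numeric `score` (which `RewardTrainer.distillate()` can also fill in later via an LLM). The snippet below is a minimal sketch that writes such a file; the example texts are illustrative, and the `dataset/sections_dataset.json` path simply mirrors the default used in `config.yaml`.

```python
import json
import os

# Minimal, illustrative dataset: "user_input" and "completion" are required by
# ActorDataset / RewardDataset; "score" (a float) is optional and can be
# assigned later by RewardTrainer.distillate().
records = [
    {
        "user_input": "Explain what RLHF is in one paragraph.",
        "completion": (
            "RLHF fine-tunes a language model with a reward model trained on "
            "human preference data, then optimizes the model against that "
            "reward with reinforcement learning."
        ),
        "score": 4.0,
    },
]

os.makedirs("dataset", exist_ok=True)
with open("dataset/sections_dataset.json", "w") as f:
    json.dump(records, f, indent=2)
```

To have LangChain's agents generate such a file for you, run the bundled helper script: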
54 | 55 | ```bash 56 | python generate_dataset.py 57 | ``` 58 | 59 | -------------------------------------------------------------------------------- /chatllama/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.3' 2 | -------------------------------------------------------------------------------- /chatllama/langchain_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/henrywoo/chatllama/6e9fabf78af50ae03ac61021a27628ae5e98363c/chatllama/langchain_modules/__init__.py -------------------------------------------------------------------------------- /chatllama/langchain_modules/prompt_templates.py: -------------------------------------------------------------------------------- 1 | REWARD_TEMPLATE = dict( 2 | template=( 3 | "Let's pretend that you are a lawyer and you have to " 4 | "evaluate the following completion task from a given " 5 | "assignment with a score between 0 and 5, where 0 represents " 6 | "a bad assignment completion and 5 a perfect completion.\n" 7 | "You MUST evaluate: text quality, content quality and " 8 | "coherence.\n" 9 | "You MUST return only the number that represents your " 10 | "judgment.\n" 11 | "The assignment is:\n{user_input}\n" 12 | "The completion is:\n{completion}\n" 13 | ), 14 | input_variables=["user_input", "completion"], 15 | ) 16 | 17 | 18 | AI_CHATBOT_TEMPLATE = dict( 19 | template=( 20 | "Assistant is a large language model trained by Meta and Nebuly.ai\n" 21 | "Assistant is designed to be able to assist with a wide range of " 22 | "tasks, from answering simple questions to providing in-depth " 23 | "explanations and discussions on a wide range of topics. As a " 24 | "language model, Assistant is able to generate human-like text " 25 | "based on the input it receives, allowing it to engage in " 26 | "natural-sounding conversations and provide responses that are " 27 | "coherent and relevant to the topic at hand.\n\n" 28 | "Assistant is constantly learning and improving, and its capabilities " 29 | "are constantly evolving. It is able to process and understand large " 30 | "amounts of text, and can use this knowledge to provide accurate and " 31 | "informative responses to a wide range of questions. Additionally, " 32 | "Assistant is able to generate its own text based on the input it " 33 | "receives, allowing it to engage in discussions and provide " 34 | "explanations and descriptions on a wide range of topics.\n\n" 35 | "Overall, Assistant is a powerful tool that can help with a wide " 36 | "range of tasks and provide valuable insights and information on a " 37 | "wide range of topics. Whether you need help with a specific " 38 | "question or just want to have a conversation about a particular " 39 | "topic, Assistant is here to assist.\n\n{history}\n\n" 40 | "Human: {human_input}\n" 41 | "Assistant:" 42 | ), 43 | input_variables=["history", "human_input"], 44 | ) 45 | 46 | 47 | PERSON_CHATBOT_TEMPLATE = dict( 48 | template=( 49 | "You are a human chatting with a chatbot. The chatbot is a large " 50 | "language model trained by Meta and Nebuly-ai\n" 51 | "The chatbot is designed to be able to assist you with a wide range " 52 | "of tasks, from answering simple questions to providing in-depth " 53 | "explanations and discussions on a wide range of topics. You are a " 54 | "human and you are testing the chatbot. Ask the chatbot questions and " "see how it responds. 
You can also ask the chatbot to tell you a " 56 | "story." 57 | "\n\n{history}\n\n" 58 | "Chatbot: {chatbot_input}\n" 59 | "Human:" 60 | ), 61 | input_variables=["history", "chatbot_input"], 62 | ) 63 | -------------------------------------------------------------------------------- /chatllama/llama_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Tuple, List, Union 5 | 6 | import torch.distributed 7 | import torch.nn as nn 8 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel 9 | from fairscale.nn.model_parallel.layers import ( 10 | ParallelEmbedding, 11 | ColumnParallelLinear, 12 | ) 13 | from llama import ModelArgs, Tokenizer 14 | from llama.generation import sample_top_p 15 | from llama.model import TransformerBlock, RMSNorm, precompute_freqs_cis 16 | 17 | 18 | class HFLikeTokenizer: 19 | def __init__(self, tokenizer: Tokenizer): 20 | self.tokenizer = tokenizer 21 | 22 | def __call__(self, texts: Union[List[str], str], *args, **kwargs): 23 | if isinstance(texts, str): 24 | text = self.tokenizer.encode(texts, bos=True, eos=True) 25 | tokens = torch.tensor(text).cuda().long() 26 | else: 27 | texts = [ 28 | self.tokenizer.encode(text, bos=True, eos=True) 29 | for text in texts 30 | ] 31 | max_len = max(len(text) for text in texts) 32 | tokens = ( 33 | torch.full((len(texts), max_len), self.tokenizer.pad_id) 34 | .cuda() 35 | .long() 36 | ) 37 | for i, text in enumerate(texts): 38 | tokens[i, : len(text)] = torch.tensor(text).cuda().long() 39 | output = { 40 | "input_ids": tokens, 41 | "attention_mask": (tokens != self.tokenizer.pad_id).long(), 42 | } 43 | return output 44 | 45 | def decode(self, tokens): 46 | return self.tokenizer.decode(tokens) 47 | 48 | 49 | class Transformer(nn.Module): 50 | def __init__(self, params: ModelArgs): 51 | super().__init__() 52 | self.params = params 53 | self.vocab_size = params.vocab_size 54 | self.n_layers = params.n_layers 55 | 56 | self.tok_embeddings = ParallelEmbedding( 57 | params.vocab_size, params.dim, init_method=lambda x: x 58 | ) 59 | 60 | self.layers = torch.nn.ModuleList() 61 | for layer_id in range(params.n_layers): 62 | self.layers.append(TransformerBlock(layer_id, params)) 63 | 64 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 65 | self.output = ColumnParallelLinear( 66 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x 67 | ) 68 | 69 | self.freqs_cis = precompute_freqs_cis( 70 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 71 | ) 72 | 73 | def forward(self, tokens: torch.Tensor, attention_mask: torch.Tensor): 74 | start_pos = int(torch.argmax(attention_mask.detach(), dim=-1).item()) 75 | logits = self._forward(tokens, start_pos) 76 | return logits 77 | 78 | def _forward(self, tokens: torch.Tensor, start_pos: int): 79 | _bsz, seqlen = tokens.shape 80 | h = self.tok_embeddings(tokens) 81 | self.freqs_cis = self.freqs_cis.to(h.device) 82 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] # noqa E203 83 | 84 | mask = None 85 | if seqlen > 1: 86 | mask = torch.full( 87 | (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device 88 | ) 89 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) 90 | 91 | for layer in self.layers: 92 | h = layer(h, start_pos, freqs_cis, mask) 93 | h = self.norm(h) 94 | output = self.output(h[:, -1, :]) # only compute last logits 95 | return output.float() 96 | 97 | @torch.no_grad() 98 | def generate( 
99 | self, 100 | inputs: torch.Tensor, 101 | attention_mask: torch.Tensor, 102 | max_length: int, 103 | temperature: float, 104 | top_p: float = 1.0, 105 | ): 106 | prompt_size = inputs.shape[1] 107 | total_len = min(self.params.max_seq_len, max_length + prompt_size) 108 | start_pos = prompt_size # We assume left padding 109 | prev_pos = 0 110 | generated_tokens = [] 111 | for cur_pos in range(start_pos, total_len): 112 | logits = self._forward(inputs[:, prev_pos:cur_pos], prev_pos) 113 | if temperature > 0: 114 | probs = torch.softmax(logits / temperature, dim=-1) 115 | next_token = sample_top_p(probs, top_p) 116 | else: 117 | next_token = torch.argmax(logits, dim=-1) 118 | next_token = next_token.reshape(-1) 119 | generated_tokens.append(next_token) 120 | prev_pos = cur_pos 121 | return torch.stack(generated_tokens, dim=1) 122 | 123 | 124 | def setup_model_parallel() -> Tuple[int, int]: 125 | local_rank = int(os.environ.get("LOCAL_RANK", -1)) 126 | world_size = int(os.environ.get("WORLD_SIZE", -1)) 127 | 128 | torch.distributed.init_process_group("nccl") 129 | initialize_model_parallel(world_size) 130 | torch.cuda.set_device(local_rank) 131 | 132 | # seed must be the same in all processes 133 | torch.manual_seed(1) 134 | return local_rank, world_size 135 | 136 | 137 | def load_checkpoints( 138 | ckpt_dir: str, local_rank: int, world_size: int 139 | ) -> Tuple[dict, dict]: 140 | checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) 141 | assert world_size == len(checkpoints), ( 142 | f"Loading a checkpoint for MP={len(checkpoints)} but world " 143 | f"size is {world_size}" 144 | ) 145 | ckpt_path = checkpoints[local_rank] 146 | print("Loading") 147 | checkpoint = torch.load(ckpt_path, map_location="cpu") 148 | with open(Path(ckpt_dir) / "params.json", "r") as f: 149 | params = json.loads(f.read()) 150 | return checkpoint, params 151 | 152 | 153 | def load_model( 154 | ckpt_dir: str, 155 | tokenizer_path: str, 156 | local_rank: int, 157 | world_size: int, 158 | max_batch_size: int = 32, 159 | ) -> Tuple[Transformer, HFLikeTokenizer]: 160 | checkpoint, params = load_checkpoints(ckpt_dir, local_rank, world_size) 161 | model_args: ModelArgs = ModelArgs( 162 | max_seq_len=1024, max_batch_size=max_batch_size, **params 163 | ) 164 | tokenizer = Tokenizer(model_path=tokenizer_path) 165 | model_args.vocab_size = tokenizer.n_words 166 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 167 | model = Transformer(model_args) 168 | torch.set_default_tensor_type(torch.FloatTensor) 169 | model.load_state_dict(checkpoint, strict=False) 170 | tokenizer = HFLikeTokenizer(tokenizer) 171 | return model, tokenizer 172 | 173 | 174 | def generate( 175 | model: Transformer, 176 | tokenizer: Tokenizer, 177 | prompts: List[str], 178 | max_gen_len: int, 179 | temperature: float = 0.8, 180 | top_p: float = 0.95, 181 | ) -> List[str]: 182 | bsz = len(prompts) 183 | params = model.params 184 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) 185 | 186 | prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts] 187 | 188 | min_prompt_size = min([len(t) for t in prompt_tokens]) 189 | max_prompt_size = max([len(t) for t in prompt_tokens]) 190 | 191 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) 192 | 193 | tokens = torch.full((bsz, total_len), tokenizer.pad_id).cuda().long() 194 | for k, t in enumerate(prompt_tokens): 195 | tokens[k, : len(t)] = torch.tensor(t).long() 196 | input_text_mask = tokens != tokenizer.pad_id 197 | start_pos = min_prompt_size 198 | 
prev_pos = 0 199 | for cur_pos in range(start_pos, total_len): 200 | logits = model._forward(tokens[:, prev_pos:cur_pos], prev_pos) 201 | if temperature > 0: 202 | probs = torch.softmax(logits / temperature, dim=-1) 203 | next_token = sample_top_p(probs, top_p) 204 | else: 205 | next_token = torch.argmax(logits, dim=-1) 206 | next_token = next_token.reshape(-1) 207 | # only replace token if prompt has already been generated 208 | next_token = torch.where( 209 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token 210 | ) 211 | tokens[:, cur_pos] = next_token 212 | prev_pos = cur_pos 213 | 214 | decoded = [] 215 | for i, t in enumerate(tokens.tolist()): 216 | # cut to max gen len 217 | t = t[: len(prompt_tokens[i]) + max_gen_len] 218 | # cut to eos tok if any 219 | try: 220 | t = t[: t.index(tokenizer.eos_id)] 221 | except ValueError: 222 | pass 223 | decoded.append(tokenizer.decode(t)) 224 | return decoded 225 | -------------------------------------------------------------------------------- /chatllama/rlhf/__init__.py: -------------------------------------------------------------------------------- 1 | """RLHF implementation inspired to Lucidrains' implementation.""" 2 | -------------------------------------------------------------------------------- /chatllama/rlhf/actor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | from beartype import beartype 6 | from beartype.typing import Optional, Tuple 7 | from einops import rearrange 8 | from torch.utils.data import Dataset, DataLoader 9 | from config import ConfigActor 10 | from utils import TrainingStats 11 | 12 | from chatllama.llama_model import load_model 13 | 14 | 15 | class ActorModel(torch.nn.Module): 16 | """Actor model that generates the augmented prompt from the initial 17 | user_input. The aim is to train this model to generate better prompts. 18 | 19 | Attributes: 20 | model: The model from LLaMA to be used 21 | tokenizer: The LLaMA tokenizer 22 | max_model_tokens (int): Maximum number of tokens that the model can 23 | handle 24 | config (ConfigActor): Configuration for the actor model 25 | 26 | Methods: 27 | load: Load the model from a path 28 | save: Save the model to a path 29 | forward: Compute the action logits for a given sequence. 
30 | generate: Generate a sequence from a given prompt 31 | """ 32 | 33 | def __init__(self, config: ConfigActor) -> None: 34 | super().__init__() 35 | # load the model 36 | 37 | self.max_model_tokens = 1024 38 | self.model, self.tokenizer = load_model( 39 | ckpt_dir=config.model_folder, 40 | tokenizer_path=config.tokenizer_folder, 41 | local_rank=int(os.environ.get("LOCAL_RANK", -1)), 42 | world_size=int(os.environ.get("WORLD_SIZE", -1)), 43 | max_batch_size=config.batch_size, 44 | ) 45 | # save config 46 | self.config = config 47 | 48 | def parameters(self, **kwargs): 49 | """Return the parameters of the model 50 | 51 | Args: 52 | **kwargs: 53 | """ 54 | return self.model.parameters() 55 | 56 | @beartype 57 | def forward( 58 | self, sequences: torch.Tensor, sequences_mask: torch.Tensor 59 | ) -> torch.Tensor: 60 | """Generate logits to have probability distribution over the vocabulary 61 | of the actions 62 | 63 | Args: 64 | sequences (torch.Tensor): Sequences of states and actions used to 65 | compute token logits for the whole list of sequences 66 | attention_mask (torch.Tensor): Mask for the sequences attention 67 | 68 | Returns: 69 | logits (torch.Tensor): Logits for the actions taken 70 | """ 71 | model_output = self.model.forward( 72 | sequences, attention_mask=sequences_mask 73 | ) 74 | if self.config.debug: 75 | print("ActorModel.forward") 76 | print("model_output_logits shape", model_output.logits.shape) 77 | print("model_output logits", model_output.logits) 78 | return model_output.logits 79 | 80 | @beartype 81 | @torch.no_grad() 82 | def generate( 83 | self, states: torch.Tensor, state_mask: torch.Tensor 84 | ) -> Tuple: 85 | """Generate actions and sequences=[states, actions] from state 86 | (i.e. input of the prompt generator model) 87 | 88 | Args: 89 | state (torch.Tensor): the input of the user 90 | state_mask (torch.Tensor): Mask for the state input (for padding) 91 | 92 | Returns: 93 | actions (torch.Tensor): Actions generated from the state 94 | sequences (torch.Tensor): Sequences generated from the 95 | state as [states, actions] 96 | """ 97 | max_sequence = states.shape[1] 98 | max_tokens = self.config.max_tokens + max_sequence 99 | temperature = self.config.temperature 100 | # What if the states + completion are longer than the max context of 101 | # the model? 102 | sequences = self.model.generate( 103 | inputs=states, 104 | attention_mask=state_mask, 105 | max_length=max_tokens, 106 | temperature=temperature, 107 | ) 108 | actions = sequences[:, states.shape[1] :] # noqa E203 109 | if self.config.debug: 110 | print("ActorModel.generate") 111 | print("state", states) 112 | print("state shape", states.shape) 113 | print("sequence shape", sequences.shape) 114 | print("sequence", sequences) 115 | print("actions shape", actions.shape) 116 | print("actions", actions) 117 | return actions, sequences 118 | 119 | @beartype 120 | def load(self, path: Optional[str] = None) -> None: 121 | """Load the model from the path 122 | 123 | Args: 124 | path (str): Path to the model 125 | """ 126 | if path is None: 127 | path = self.config.model_folder + "/" + self.config.model + ".pt" 128 | if os.path.exists(self.config.model_folder) is False: 129 | os.mkdir(self.config.model_folder) 130 | print( 131 | f"Impossible to load the model: {path}" 132 | f"The path doesn't exist." 133 | ) 134 | return 135 | # load the model 136 | if os.path.exists(path) is False: 137 | print( 138 | f"Impossible to load the model: {path}" 139 | f"The path doesn't exist." 
140 | ) 141 | return 142 | model_dict = torch.load(path) 143 | self.model.load_state_dict(model_dict["model"]) 144 | 145 | @beartype 146 | def save(self, path: Optional[str] = None) -> None: 147 | """Save the model to the path 148 | 149 | Args: 150 | path (Optional[str], optional): Path to store the model. 151 | Defaults to None. 152 | """ 153 | if path is None: 154 | path = self.config.model_folder + "/" + self.config.model + ".pt" 155 | if os.path.exists(self.config.model_folder) is False: 156 | os.mkdir(self.config.model_folder) 157 | torch.save({"model": self.model.state_dict()}, path) 158 | 159 | 160 | class ActorDataset(Dataset): 161 | """Dataset for the pretraining of the actor model 162 | read a json file with the following format: 163 | [ 164 | { 165 | "user_input": "..." 166 | "completion": "..." 167 | } , 168 | ... 169 | ] 170 | Where: 171 | user_input: the input of the user 172 | completion: the output of the user 173 | """ 174 | 175 | def __init__(self, path: str, device: torch.device) -> None: 176 | self.device = device 177 | self.path = path 178 | with open(path, "r") as f: 179 | data = json.load(f) 180 | self.data = [ 181 | d["user_input"] + "\n\n###\n\n" + d["completion"] for d in data 182 | ] 183 | self.len = len(self.data) 184 | 185 | def __getitem__(self, idx): 186 | return self.data[idx] 187 | 188 | def __len__( 189 | self, 190 | ): 191 | return self.len 192 | 193 | 194 | class ActorTrainer: 195 | """Used to pre-train the actor model to generate better prompts. 196 | 197 | Args: 198 | config (ConfigActor): Configuration for the actor model 199 | 200 | Attributes: 201 | config (ConfigActor): Configuration for the actor model 202 | model (ActorModel): Actor model 203 | loss_function (torch.nn.CrossEntropyLoss): Loss function 204 | optimizer (torch.optim.Adam): Optimizer 205 | validation_flag (bool): Flag to indicate if the validation dataset 206 | is provided 207 | training_stats (TrainingStats): Training statistics 208 | 209 | Methods: 210 | train: Train the actor model 211 | """ 212 | 213 | def __init__(self, config: ConfigActor) -> None: 214 | # load the model 215 | self.config = config 216 | self.model = ActorModel(config) 217 | self.loss_function = torch.nn.CrossEntropyLoss() 218 | self.optimizer = torch.optim.Adam( 219 | self.model.parameters(), lr=config.lr 220 | ) 221 | self.validation_flag = False 222 | self.training_stats = TrainingStats() 223 | if not os.path.exists(config.model_folder): 224 | os.mkdir(config.model_folder) 225 | if config.validation_dataset_path is not None: 226 | self.validation_flag = True 227 | 228 | def train( 229 | self, 230 | ) -> None: 231 | print("Start Actor Model Pretraining") 232 | # get config parameters 233 | train_dataset_path = self.config.train_dataset_path 234 | validation_dataset_path = self.config.validation_dataset_path 235 | batch_size = self.config.batch_size 236 | epochs = self.config.epochs 237 | device = self.config.device 238 | 239 | # create dataloaders 240 | train_dataset = ActorDataset(train_dataset_path, device=device) 241 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size) 242 | if self.validation_flag: 243 | eval_dataset = ActorDataset(validation_dataset_path, device=device) 244 | validation_dataloader = DataLoader( 245 | eval_dataset, batch_size=batch_size 246 | ) 247 | 248 | # compute the number of iterations 249 | n_iter = int(len(train_dataset) / batch_size) 250 | 251 | # traing loop 252 | for epoch in range(epochs): 253 | self.model.train() 254 | for i, input_output in 
enumerate(train_dataloader): 255 | input_output_tokenized = self.model.tokenizer( 256 | input_output, 257 | return_tensors="pt", 258 | padding=True, 259 | truncation=True, 260 | ) 261 | training_output = input_output_tokenized["input_ids"][:, 1:] 262 | training_input = input_output_tokenized["input_ids"][:, :-1] 263 | attention_mask = input_output_tokenized["attention_mask"][ 264 | :, :-1 265 | ] 266 | training_output = training_output.to(device) 267 | training_input = training_input.to(device) 268 | attention_mask = attention_mask.to(device) 269 | 270 | # forward pass 271 | est_output = self.model.forward(training_input, attention_mask) 272 | est_output = rearrange(est_output, "b s v -> (b s) v") 273 | training_output = rearrange(training_output, "b s -> (b s)") 274 | loss = self.loss_function(est_output, training_output) 275 | self.training_stats.training_loss.append(loss.item()) 276 | 277 | # backward pass 278 | self.optimizer.zero_grad() 279 | loss.backward() 280 | self.optimizer.step() 281 | 282 | # print progress 283 | if i % self.config.iteration_per_print == 0: 284 | print( 285 | f"Epoch: {epoch+1}/{epochs}, " 286 | f"Iteration: {i+1}/{n_iter}, " 287 | f"Training Loss: {loss}" 288 | ) 289 | if self.validation_flag: 290 | self.model.eval() 291 | for i, input_output in enumerate(validation_dataloader): 292 | input_output_tokenized = self.model.tokenizer( 293 | input_output, return_tensors="pt", padding=True 294 | ) 295 | validation_output = input_output_tokenized["input_ids"][ 296 | :, 1: 297 | ] 298 | validation_input = input_output_tokenized["input_ids"][ 299 | :, :-1 300 | ] 301 | attention_mask = input_output_tokenized["attention_mask"][ 302 | :, :-1 303 | ] 304 | 305 | # forward pass 306 | est_output = self.model.forward( 307 | validation_input, attention_mask 308 | ) 309 | validation_output = rearrange( 310 | validation_output, "b s -> (b s)" 311 | ) 312 | est_output = rearrange(est_output, "b s v -> (b s) v") 313 | loss = self.loss_function(est_output, validation_output) 314 | self.training_stats.validation_loss.append(loss.item()) 315 | 316 | # print progress 317 | if i % self.config.iteration_per_print == 0: 318 | print( 319 | f"Epoch: {epoch+1}/{epochs}, " 320 | f"Iteration: {i+1}/{n_iter}, " 321 | f"Validation Loss: {loss}" 322 | ) 323 | self.model.save() 324 | print("Training Finished ") 325 | -------------------------------------------------------------------------------- /chatllama/rlhf/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | from beartype import beartype 7 | from beartype.typing import Optional 8 | 9 | 10 | @dataclass 11 | class ConfigReward: 12 | """Config parameters for the reward model 13 | 14 | Attributes: 15 | model (str): Model to be used for the reward model 16 | model_folder (str): Path to the folder where model are stored (used 17 | to load / store finetuned model) 18 | device (torch.device): Device to be used for the reward model 19 | model_head_hidden_size (int): Hidden size of the reward model head 20 | debug (bool): enable prints for Debugging 21 | train_dataset_path (Optional[str]): Path to the training dataset. 22 | Default to None. To be specified only for the reward model trainig. 23 | validation_dataset_path (Optional[str]): Path to the validation 24 | dataset. Default to None. To be specified only for the reward 25 | model trainig. 26 | batch_size (Optional[int]): Batch size to train the reward model. 
27 | Default to None. To be specified only for the reward model 28 | trainig. 29 | epochs (Optional[int]): Number of epochs to train the reward model. 30 | Default to None. To be specified only for the reward model 31 | trainig. 32 | iteration_per_print (Optional[int]): Number of iterations to print 33 | the training loss. Default to None. To be specified only for the 34 | reward model trainig. 35 | lr (Optional[float]): Learning rate for the reward model. Default to 36 | None. To be specified only for the reward model distillation. 37 | llm_model (Optional[str]): Model to be used for the language model 38 | (LLM). Default to None. 39 | llm_max_tokens (Optional[int]): Max tokens for the LLM. Default to 40 | None. 41 | llm_temperature (Optional[float]): Temperature for the LLM. Default 42 | to None. 43 | """ 44 | 45 | model: str 46 | model_folder: str 47 | device: torch.device 48 | model_head_hidden_size: int 49 | debug: bool 50 | train_dataset_path: Optional[str] = None 51 | validation_dataset_path: Optional[str] = None 52 | batch_size: Optional[int] = None 53 | epochs: Optional[int] = None 54 | iteration_per_print: Optional[int] = None 55 | lr: Optional[float] = None 56 | llm_model: Optional[str] = None 57 | llm_max_tokens: Optional[int] = None 58 | llm_temperature: Optional[float] = None 59 | 60 | 61 | @dataclass 62 | class ConfigActor: 63 | """Config parameters for models 64 | 65 | Attributes: 66 | model (str): Model to be used for the actor 67 | model_folder (str): Path to the folder where model are stored (used 68 | to load / store finetuned model) 69 | max_tokens (int): Max tokens for the actor 70 | temperature (float): Temperature for the actor 71 | device (torch.device): Device to be used for the actor 72 | lr (float): Learning rate for the actor 73 | iteration_per_print (int): Number of iterations to print the 74 | training loss 75 | batch_size (int): Batch size to train the actor 76 | epochs (int): Number of epochs to train the actor 77 | debug (bool): Enable prints for debugging 78 | train_dataset_path (str): Path to the training dataset 79 | validation_dataset_path (Optional[str]): Path to the validation dataset 80 | """ 81 | 82 | model: str 83 | model_folder: str 84 | tokenizer_folder: str 85 | max_tokens: int 86 | temperature: float 87 | device: torch.device 88 | lr: float 89 | iteration_per_print: int 90 | batch_size: int 91 | epochs: int 92 | debug: bool 93 | train_dataset_path: str 94 | validation_dataset_path: Optional[str] = None 95 | 96 | 97 | @dataclass 98 | class ConfigTrainer: 99 | """Config parameters for the trainer, used to configure the reinforcement 100 | learning training loop 101 | 102 | Attributes: 103 | update_timesteps (int): Number of timesteps to update the actor 104 | and critic. Every time update_timesteps timesteps are collected, 105 | the training loop for the actor and critic is executed using the 106 | memory buffer to learn the policy. 107 | temperature (float): Temperature for the actor and critic 108 | max_seq_len (int): Max sequence length for the actor and critic 109 | num_examples (int): Number of examples to generate for the actor 110 | and critic. For each iteration of timestep, num_examples are 111 | sampled from the prompt dataset, processed and stored in the 112 | memory buffer. 
113 | actor_lr (float): Learning rate for the actor when training with 114 | reinforcement learning 115 | critic_lr (float): Learning rate for the critic when training with 116 | reinforcement learning 117 | num_episodes (int): Number of episodes, each episodes consist of 118 | a number of timesteps that are used to generate examples 119 | stored in the memory buffer. 120 | max_timesteps (int): Max timesteps for the actor and critic. 121 | for each timestep a set of examples are sampled and used to 122 | generate a completion and a reward. 123 | batch_size (int): Batch size to train the actor and critic. 124 | This batch is used to aggregate the memory from the memory buffer 125 | for the actual training of the actor and critic models. 126 | epochs (int): Number of epochs to train the actor and critic. 127 | actor_eps_clip (float): Epsilon clip for the actor 128 | critic_eps_clip (float): Epsilon clip for the critic 129 | beta_s (float): Beta for the actor and critic 130 | update_checkpoint (int): Number of timesteps to update the checkpoint 131 | llm_model_id (str): Model id for the llm 132 | llm_max_tokens (int): Max tokens for the llm 133 | llm_temperature (float): Temperature for the llm 134 | device (torch.device): Device to be used for the actor and critici 135 | checkpoint_folder (str): Folder to store the checkpoints while training 136 | debug (bool): Enable prints for debugging 137 | """ 138 | 139 | update_timesteps: int 140 | num_examples: int 141 | actor_lr: float 142 | critic_lr: float 143 | num_episodes: int 144 | max_timesteps: int 145 | examples_path: str 146 | batch_size: int 147 | epochs: int 148 | actor_eps_clip: float 149 | critic_eps_clip: float 150 | beta_s: float 151 | update_checkpoint: int 152 | llm_model_id: str 153 | llm_max_tokens: int 154 | llm_temperature: float 155 | device: torch.device 156 | checkpoint_folder: str 157 | debug: bool 158 | 159 | 160 | class Config: 161 | """Store the config parameters for the whole pipeline 162 | 163 | Args: 164 | trainer_dict (Optional[Dict]): Dictionary with the config parameters 165 | for the trainer. Default to None. If None, the config.yaml is 166 | used. 167 | actor_dict (Optional[Dict]): Dictionary with the config parameters 168 | for the actor. Default to None. If None, the config.yaml is 169 | used. 170 | critic_dict (Optional[Dict]): Dictionary with the config parameters 171 | for the critic. Default to None. If None, the config.yaml is 172 | used. 173 | reward_dict (Optional[Dict]): Dictionary with the config parameters 174 | for the reward. Default to None. If None, the config.yaml is 175 | used. 176 | device (Optional[torch.device]): Device to be used for the actor 177 | and critic. Default to None. If None, the device available is 178 | used. 179 | debug (Optional[bool]): Enable prints for debugging. Default to False. 
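Example:
        A minimal usage sketch (the YAML path is illustrative; any file with
        trainer_config, actor_config, critic_config and reward_config
        sections, as in chatllama/rlhf/config.yaml, will do):

            from chatllama.rlhf.config import Config

            config = Config(path="path_to_config_file.yaml")
            print(config.trainer.num_episodes)  # ConfigTrainer
            print(config.actor.model)           # ConfigActor
            print(config.critic.model)          # critic reuses ConfigReward
            print(config.reward.model)          # ConfigReward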
180 | 181 | Attributes: 182 | trainer (ConfigTrainer): Config parameters for the trainer 183 | actor (ConfigActor): Config parameters for the actor 184 | critic (ConfigCritic): Config parameters for the critic 185 | reward (ConfigReward): Config parameters for the reward 186 | """ 187 | 188 | @beartype 189 | def __init__( 190 | self, 191 | path: str, 192 | device: Optional[torch.device] = None, 193 | debug: Optional[bool] = False, 194 | ) -> None: 195 | 196 | # if not specified use the device available 197 | if device is None: 198 | device = torch.device( 199 | "cuda" if torch.cuda.is_available() else "cpu" 200 | ) 201 | print(f"Current device used:{str(device)}") 202 | 203 | if path is None or os.path.exists(path) is False: 204 | raise ValueError("Path to the config.yaml is not valid") 205 | 206 | # Read the config from yaml 207 | with open(path, "r") as c: 208 | config = yaml.safe_load(c) 209 | 210 | trainer_dict = config["trainer_config"] 211 | actor_dict = config["actor_config"] 212 | critic_dict = config["critic_config"] 213 | reward_dict = config["reward_config"] 214 | 215 | # Trainer Config 216 | trainer_dict["device"] = device 217 | trainer_dict["debug"] = debug 218 | self.trainer = ConfigTrainer(**trainer_dict) 219 | # Actor Config 220 | actor_dict["device"] = device 221 | actor_dict["debug"] = debug 222 | self.actor = ConfigActor(**actor_dict) 223 | # Critic Config 224 | critic_dict["device"] = device 225 | critic_dict["debug"] = debug 226 | self.critic = ConfigReward(**critic_dict) 227 | # Reward Config 228 | reward_dict["device"] = device 229 | reward_dict["debug"] = debug 230 | self.reward = ConfigReward(**reward_dict) 231 | -------------------------------------------------------------------------------- /chatllama/rlhf/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer_config: 3 | update_timesteps: 1 4 | num_examples: 2 5 | actor_lr: 0.00001 6 | critic_lr: 0.00001 7 | num_episodes: 10 8 | max_timesteps: 10 9 | examples_path: "dataset/sections_dataset.json" 10 | batch_size: 1 11 | epochs: 5 12 | actor_eps_clip: 0.2 13 | critic_eps_clip: 0.2 14 | beta_s: 0.1 15 | update_checkpoint: 10 16 | llm_model_id: "text-davinci-003" 17 | llm_max_tokens: 1024 18 | llm_temperature: 0.5 19 | checkpoint_folder: "./models/checkpoints" 20 | 21 | actor_config: 22 | model: "llama-7B" 23 | max_tokens: 1024 24 | temperature: 0.9 25 | train_dataset_path: "dataset/sections_dataset.json" 26 | validation_dataset_path: null 27 | batch_size: 16 28 | iteration_per_print: 10 29 | lr: 0.00001 30 | epochs: 1 31 | model_folder: "path-to-checkpoints" 32 | 33 | reward_config: 34 | # model to be chosen are gp2-large, bart-base, longformer-base-4096 35 | model: "longformer-base-4096" 36 | model_head_hidden_size: 2048 37 | model_folder: "./models" 38 | train_dataset_path: "/home/pierpaolo/Documents/optimapi/dataset/sections_dataset.json" 39 | validation_dataset_path: null 40 | batch_size: 64 41 | epochs: 20 42 | iteration_per_print: 10 43 | lr: 0.0001 44 | 45 | critic_config: 46 | # model to be chosen are gp2-large, bart-base, longformer-base-4096 47 | model: "longformer-base-4096" 48 | model_head_hidden_size: 2048 49 | model_folder: "./models" -------------------------------------------------------------------------------- /chatllama/rlhf/reward.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | from beartype import beartype 6 | from beartype.typing import 
Optional, Iterable 7 | from einops.layers.torch import Rearrange 8 | from langchain import OpenAI, LLMChain, PromptTemplate 9 | from torch.utils.data import Dataset, DataLoader 10 | from transformers import GPT2Tokenizer, GPT2Model, BartModel 11 | from transformers import BartTokenizer, BartConfig, AutoModel, AutoTokenizer 12 | 13 | from chatllama.langchain_modules.prompt_templates import REWARD_TEMPLATE 14 | from config import ConfigReward 15 | from utils import TrainingStats 16 | 17 | 18 | class RewardModel(torch.nn.Module): 19 | """Model to be trained to predict the reward for RL. 20 | or to be used as Critic in RL. 21 | 22 | Attributes: 23 | model (torch.nn.Module): Model to be used for the reward model 24 | tokenizer (torch.nn.Module): Tokenizer to be used for the reward model 25 | head (torch.nn.Module): Head to be used for the reward model 26 | config (ConfigReward): Config parameters for the reward model 27 | max_model_tokens (int): Maximum sequence length for the reward model 28 | 29 | Methods: 30 | forward: Forward pass of the model (used by the critic) 31 | save: Save the model 32 | load: Load the model 33 | get_reward: Get the reward for a given input (used by the reward model) 34 | """ 35 | 36 | def __init__(self, config: ConfigReward) -> None: 37 | super().__init__() 38 | # load the model -- add here other models 39 | head_hidden_size = config.model_head_hidden_size 40 | if config.model == "gpt2-large": 41 | self.max_model_tokens = 1024 42 | self.model = GPT2Model.from_pretrained("gpt2-large") 43 | self.tokenizer = GPT2Tokenizer.from_pretrained( 44 | "gpt2-large", 45 | padding_side="left", 46 | truncation_side="left", 47 | model_max_length=self.max_model_tokens, 48 | ) 49 | self.tokenizer.pad_token = self.tokenizer.eos_token 50 | self.head = torch.nn.Sequential( 51 | torch.nn.Linear(self.model.config.n_embd, head_hidden_size), 52 | torch.nn.ReLU(), 53 | torch.nn.Linear(head_hidden_size, 1), 54 | Rearrange("... 1 -> ..."), 55 | ) 56 | elif config.model == "bart-base": 57 | self.max_model_tokens = 1024 58 | bart_config = BartConfig.from_pretrained("facebook/bart-base") 59 | bart_config.max_position_embeddings = 2048 + 1024 60 | self.model = BartModel(bart_config) 61 | self.tokenizer = BartTokenizer.from_pretrained( 62 | "facebook/bart-large", 63 | padding_side="left", 64 | truncation_side="left", 65 | model_max_length=self.max_model_tokens, 66 | ) 67 | self.tokenizer.pad_token = self.tokenizer.eos_token 68 | self.head = torch.nn.Sequential( 69 | torch.nn.Linear(768, head_hidden_size), 70 | torch.nn.ReLU(), 71 | torch.nn.Linear(head_hidden_size, 1), 72 | Rearrange("... 1 -> ..."), 73 | ) 74 | elif config.model == "longformer-base-4096": 75 | self.max_model_tokens = 4096 76 | self.model = AutoModel.from_pretrained( 77 | "allenai/longformer-base-4096" 78 | ) 79 | self.tokenizer = AutoTokenizer.from_pretrained( 80 | "allenai/longformer-base-4096", 81 | padding_side="left", 82 | truncation_side="left", 83 | model_max_length=self.max_model_tokens, 84 | ) 85 | self.tokenizer.eos_token = self.tokenizer.pad_token 86 | self.head = torch.nn.Sequential( 87 | torch.nn.Linear(768, head_hidden_size), 88 | torch.nn.ReLU(), 89 | torch.nn.Linear(head_hidden_size, 1), 90 | Rearrange("... 
1 -> ..."), 91 | ) 92 | else: 93 | raise ValueError(f"model {config.model} not supported") 94 | # store config 95 | self.config = config 96 | if os.path.exists(config.model_folder) is False: 97 | os.mkdir(config.model_folder) 98 | else: 99 | self.load() 100 | # freeze model parameters (only train the head) 101 | for param in self.model.parameters(): 102 | param.requires_grad = False 103 | # move model to device 104 | self.model.to(config.device) 105 | self.head.to(config.device) 106 | 107 | @beartype 108 | def parameters( 109 | self, 110 | ) -> Iterable[torch.nn.Parameter]: 111 | """Return the parameters of the reward model""" 112 | for p in self.model.parameters(): 113 | yield p 114 | for p in self.head.parameters(): 115 | yield p 116 | 117 | @beartype 118 | def forward( 119 | self, output_sequence: torch.Tensor, output_sequence_mask: torch.Tensor 120 | ) -> torch.Tensor: 121 | """Generate the sequence of rewards for the given output sequence 122 | what is the quality of the output sequence tokens? 123 | 124 | Args: 125 | output_sequence (torch.Tensor): The sequence of tokens to be 126 | evaluated 127 | output_sequence_mask (torch.Tensor): Mask for the attention 128 | 129 | Returns: 130 | torch.Tensor: Rewards for the given output sequence 131 | """ 132 | output = self.model( 133 | output_sequence, attention_mask=output_sequence_mask 134 | ) 135 | # What if the output_sequence is longer than the max context of 136 | # the model? 137 | rewards = self.head(output.last_hidden_state) 138 | if self.config.debug: 139 | print("RewardModel.forward") 140 | print("output_sequence.shape", output_sequence.shape) 141 | print("output_sequence", output_sequence) 142 | print("reward.shape", rewards.shape) 143 | print("reward", rewards) 144 | return rewards 145 | 146 | @beartype 147 | def get_reward( 148 | self, output_sequence: torch.Tensor, output_sequence_mask: torch.Tensor 149 | ) -> torch.Tensor: 150 | """Get the reward for the given output sequence 151 | 152 | Args: 153 | output_sequence (torch.Tensor): The concatenation of initial input 154 | and actor output as tokens 155 | output_sequence_mask (torch.Tensor): Mask for the attention 156 | """ 157 | rewards = self.forward(output_sequence, output_sequence_mask) 158 | return rewards[:, -1] 159 | 160 | @beartype 161 | def load(self, path: Optional[str] = None) -> None: 162 | """Load the model from the path 163 | 164 | Args: 165 | path (str): path to the model 166 | """ 167 | if path is None: 168 | path = self.config.model_folder + "/" + self.config.model + ".pt" 169 | if os.path.exists(self.config.model_folder) is False: 170 | os.makedirs(self.config.model_folder) 171 | print( 172 | f"Model folder does not exist. Creating it," 173 | f"and returning without loading the model:\n{path}" 174 | ) 175 | return 176 | # load the model and the tokenizer 177 | if os.path.exists(path) is False: 178 | print( 179 | f"Impossible to load the model:\n{path}\n" 180 | f"The path doesn't exist." 181 | ) 182 | return 183 | model_dict = torch.load(path) 184 | self.model.load_state_dict(model_dict["model"]) 185 | self.head.load_state_dict(model_dict["head"]) 186 | 187 | @beartype 188 | def save(self, path: Optional[str] = None) -> None: 189 | """Save the model to the path 190 | 191 | Args: 192 | path (Optional[str], optional): Path to store the model. 193 | Defaults to None. 
194 | """ 195 | if path is None: 196 | path = self.config.model_folder + "/" + self.config.model + ".pt" 197 | if os.path.exists(self.config.model_folder) is False: 198 | os.makedirs(self.config.model_folder) 199 | torch.save( 200 | {"model": self.model.state_dict(), "head": self.head.state_dict()}, 201 | path, 202 | ) 203 | 204 | 205 | # just to keep namings consistent 206 | CriticModel = RewardModel 207 | 208 | 209 | class RewardDataset(Dataset): 210 | """Dataset class for the reward model 211 | read a json file with the following format: 212 | [ 213 | { 214 | "user_input": "...", 215 | "completion": "...", 216 | "score": ... 217 | }, 218 | ... 219 | ] 220 | Where: 221 | user_input: the initial input of the user 222 | completion: the completion generated by the model 223 | score: the score given by the user to the completion (or by the LLM) 224 | """ 225 | 226 | def __init__(self, path: str) -> None: 227 | print(f"Loading dataset from {path}") 228 | with open(path, "r") as f: 229 | self.data = list(json.load(f)) 230 | print(f"Loaded {len(self.data)} samples") 231 | 232 | def __getitem__(self, idx: int): 233 | user_input = self.data[idx]["user_input"] 234 | completion = self.data[idx]["completion"] 235 | item = tuple([user_input, completion]) 236 | return item 237 | 238 | def __len__( 239 | self, 240 | ): 241 | return len(self.data) 242 | 243 | 244 | class RewardTrainer: 245 | """Reward class to train the reward model 246 | 247 | Args: 248 | config (ConfigModel): Config parameters for the model 249 | 250 | Attributes: 251 | model (RewardModel): Reward model 252 | config (ConfigModel): Config parameters for the model 253 | optimizer (torch.optim): Optimizer for the model 254 | loss (torch.nn): Loss function for the model 255 | 256 | Methods: 257 | train: Train the reward model 258 | generate_user_input: Generate the user input for the LLM to evaluate a 259 | couple, (user_input, completion) and assing a score 260 | distillate: Parse the dataset and assign scores using LLMs 261 | """ 262 | 263 | def __init__(self, config: ConfigReward) -> None: 264 | self.model = RewardModel(config) 265 | self.config = config 266 | self.optimizer = torch.optim.Adam( 267 | self.model.parameters(), lr=config.lr 268 | ) 269 | self.loss_function = torch.nn.MSELoss() 270 | if not os.path.exists("./models"): 271 | os.mkdir("./models") 272 | self.training_stats = TrainingStats() 273 | self.validation_flag = False 274 | if config.validation_dataset_path is not None: 275 | self.validation_flag = True 276 | 277 | openai_llm = OpenAI( 278 | model_name=self.config.llm_model, 279 | temperature=self.config.llm_temperature, 280 | max_tokens=self.config.llm_max_tokens, 281 | ) 282 | prompt_template = PromptTemplate(**REWARD_TEMPLATE) 283 | self.llm = LLMChain(llm=openai_llm, prompt=prompt_template) 284 | 285 | def distillate( 286 | self, 287 | ): 288 | """Parse the dataset and assign scores using LLMs 289 | then save back the dataset with the uploaded scores 290 | """ 291 | # load the dataset 292 | with open(self.config.train_dataset_path, "r") as f: 293 | train_data = json.load(f) 294 | # for each element of the dataset, assing a score. 
295 | for i, data in enumerate(train_data): 296 | if data.get("score", None) is None: 297 | prompt_tokens = ( 298 | data["user_input"] 299 | + data["completion"] 300 | + self.llm.prompt.template 301 | ) 302 | prompt_len = int(len(prompt_tokens.split(" ")) / 0.75) 303 | # 80% of the max length as safety margin 304 | if prompt_len > self.config.llm_max_tokens * 0.8: 305 | print( 306 | f"The prompt of the data {i} is too long\n" 307 | f"tokens: {prompt_len}\n" 308 | f"max_tokens: {self.config.llm_max_tokens * 0.8}" 309 | ) 310 | continue 311 | score = self.llm.run( 312 | user_input=data["user_input"], 313 | completion=data["completion"], 314 | ).strip() 315 | # TODO extract from score the float value with a regex 316 | score = score.split(" ")[0] 317 | try: 318 | score = float(score) 319 | except Exception: 320 | print( 321 | f"The score returned by the LLM for the" 322 | f"data, {i}, is not a float float:\n{score}" 323 | ) 324 | continue 325 | data["score"] = score 326 | # save the dataset back 327 | with open(self.config.train_dataset_path, "w") as f: 328 | json.dump(train_data, f) 329 | 330 | def train( 331 | self, 332 | ) -> None: 333 | """Train the reward model""" 334 | print("Start Training the Reward Model") 335 | # get config parameters 336 | train_dataset_path = self.config.train_dataset_path 337 | validation_dataset_path = self.config.validation_dataset_path 338 | batch_size = self.config.batch_size 339 | epochs = self.config.epochs 340 | device = self.config.device 341 | 342 | # create dataloaders 343 | train_dataset = RewardDataset(train_dataset_path) 344 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size) 345 | if self.validation_flag: 346 | eval_dataset = RewardDataset(validation_dataset_path) 347 | validation_dataloader = DataLoader( 348 | eval_dataset, batch_size=batch_size 349 | ) 350 | iteration_per_print = self.config.iteration_per_print 351 | 352 | # compute the number of iterations 353 | n_iter = int(len(train_dataset) / batch_size) 354 | 355 | # traing loop 356 | for epoch in range(epochs): 357 | self.model.train() 358 | for i, inputs in enumerate(train_dataloader): 359 | 360 | input_text = inputs["user_input"] + inputs["completion"] 361 | # tokenizer (placed here instead of dataset class) 362 | input_tokens = self.model.tokenizer( 363 | input_text, padding=True, truncation=True 364 | ) 365 | 366 | score = None # TODO: load the score 367 | 368 | # TODO: check on the length of the input tokens if they are 369 | # too many it can create problems 370 | output = torch.tensor(score, dtype=torch.float32).to(device) 371 | 372 | # forward pass 373 | est_output = self.model.get_reward( 374 | input_tokens["input_ids"].to(device), 375 | input_tokens["attention_mask"].to(device), 376 | ) 377 | 378 | loss = self.loss_function(est_output, output) 379 | self.training_stats.training_loss.append(loss.item()) 380 | 381 | # backward pass 382 | self.optimizer.zero_grad() 383 | loss.backward() 384 | self.optimizer.step() 385 | 386 | # print progress 387 | if i % iteration_per_print == 0: 388 | print( 389 | f"Epoch: {epoch+1}/{epochs}, " 390 | f"Iteration: {i+1}/{n_iter}, " 391 | f"Training Loss: {loss.item()}" 392 | ) 393 | print( 394 | "prediction", 395 | est_output.cpu().detach().numpy(), 396 | "target", 397 | score.cpu().numpy(), 398 | ) 399 | if self.validation_flag: 400 | self.model.eval() 401 | for i, (text, score) in enumerate(validation_dataloader): 402 | # forward pass 403 | input_tokens = self.model.tokenizer( 404 | text, return_tensors="pt", padding=True 405 | ) 
406 | input_tokens = input_tokens.to(device) 407 | # TODO: check on the length of the input tokens if they are 408 | # too many it can create problems 409 | output = torch.tensor(score, dtype=torch.float32).to( 410 | device 411 | ) 412 | est_output = self.model.get_reward( 413 | input_tokens["input_ids"], 414 | input_tokens["attention_mask"], 415 | ) 416 | loss = self.loss_function(est_output, output) 417 | self.training_stats.validation_loss.append(loss.item()) 418 | 419 | # print progress 420 | if i % iteration_per_print == 0: 421 | print( 422 | f"Epoch: {epoch+1}/{epochs}, " 423 | f"Iteration: {i+1}/{n_iter}, " 424 | f"Validation Loss: {loss.item()}" 425 | ) 426 | self.model.save() 427 | -------------------------------------------------------------------------------- /chatllama/rlhf/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from actor import ActorTrainer 4 | from config import Config 5 | from trainer import RLTrainer 6 | from reward import RewardTrainer 7 | 8 | 9 | def test_actor_training(path=None, device=None, debug=False): 10 | config = Config(path=path, device=device, debug=debug) 11 | trainer = ActorTrainer(config.actor) 12 | trainer.train() 13 | trainer.training_stats.plot() 14 | 15 | 16 | def test_reward_training(path=None, device=None, debug=False): 17 | device = torch.device("cuda:0") 18 | config = Config(path=path, device=device, debug=debug) 19 | trainer = RewardTrainer(config.reward) 20 | trainer.train() 21 | trainer.training_stats.plot() 22 | 23 | 24 | def test_rl_trainig(path=None, device=None, debug=False): 25 | device = torch.device("cuda:0") 26 | config = Config(path=path, device=device, debug=debug) 27 | trainer = RLTrainer(config.trainer) 28 | trainer.distillate() 29 | trainer.train() 30 | trainer.training_stats.plot() 31 | 32 | 33 | if __name__ == "__main__": 34 | reward_training = True 35 | rl_training = False 36 | actor_training = False 37 | 38 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 39 | # place here the path to the config.yaml file 40 | config_path = "/home/pierpaolo/Documents/optimapi/ptuning/config.yaml" 41 | 42 | if reward_training: 43 | test_reward_training(path=config_path, device=device) 44 | if rl_training: 45 | test_rl_trainig(path=config_path, device=device) 46 | if actor_training: 47 | test_actor_training(path=config_path, device=device) 48 | -------------------------------------------------------------------------------- /chatllama/rlhf/trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from collections import deque, namedtuple 5 | 6 | import torch 7 | from beartype import beartype 8 | from beartype.typing import Deque, Tuple, List 9 | from einops import rearrange 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | from actor import ActorModel 13 | from reward import RewardModel, CriticModel 14 | from config import ConfigReward, ConfigActor, Config 15 | from utils import TrainingStats, ConversationLog 16 | 17 | 18 | class ActorCritic(torch.nn.Module): 19 | """Actor Critic class stores both the actor and the critic models 20 | and it generates values and action for given sequences during the training 21 | of the actor. 
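A minimal usage sketch (illustrative only: it assumes `config` is a loaded
    Config object and that `states` / `state_mask` are already-tokenized
    prompt tensors on the right device):

        actor_critic = ActorCritic(config.actor, config.critic)
        actions, action_logits, values, sequences, seq_mask = (
            actor_critic.generate(states, state_mask)
        )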
22 | 23 | Attributes: 24 | actor (ActorModel): Actor model 25 | critic (CriticModel): Critic model 26 | debug (bool): enable prints for Debugging 27 | 28 | Methods: 29 | forward: given a sequence returns action logits and values (used 30 | to evaluate the actor during training) 31 | generate: given a sequence returns action, action logits, values 32 | sequences and sequences masks (used to generate new sequences 33 | during acting phase) 34 | """ 35 | 36 | def __init__( 37 | self, actor_config: ConfigActor, critic_config: ConfigReward 38 | ) -> None: 39 | super().__init__() 40 | self.actor = ActorModel(actor_config) 41 | self.critic = CriticModel(critic_config) 42 | self.debug = actor_config.debug 43 | 44 | @beartype 45 | def forward( 46 | self, 47 | sequences: torch.Tensor, 48 | sequences_mask: torch.Tensor, 49 | action_len: int, 50 | ) -> Tuple: 51 | """Given the whole sequences, use the actor forward to get the logits 52 | for each token in the sequence and the critic forward to get the 53 | values for each generation step. 54 | 55 | Args: 56 | sequences (torch.Tensor): Sequences composed of [states, actions] 57 | sequence_mask (torch.Tensor): Mask for the sequences 58 | action_length (int): Length of the actions in the sequences 59 | 60 | Returns: 61 | action_logits (torch.Tensor): Logits for the actions in the 62 | sequences 63 | values (torch.Tensor): Values for the actions in the sequences 64 | """ 65 | # use a single forward on the whole sequence 66 | # to get pi(y | x) and ignore predicted output 67 | actions_logits = self.actor.forward(sequences, sequences_mask) 68 | values = self.critic.forward(sequences, sequences_mask) 69 | 70 | # return only logits and values for the actions taken 71 | real_actions_logits = actions_logits[:, -action_len:, :] 72 | real_values = values[:, -action_len:] 73 | 74 | if self.debug: 75 | print("ActorCritic.forward") 76 | print("action_len", action_len) 77 | print("sequences.shape", sequences.shape) 78 | print("sequences", sequences) 79 | print("real_action_logits.shape", actions_logits.shape) 80 | print("real_action_logits", actions_logits) 81 | print("real_values.shape", values.shape) 82 | print("real_values", values) 83 | 84 | return ( 85 | real_actions_logits, 86 | real_values, 87 | ) 88 | 89 | @torch.no_grad() 90 | @beartype 91 | def generate( 92 | self, states: torch.Tensor, state_mask: torch.Tensor 93 | ) -> Tuple: 94 | """Generate actions, actions_logits, values and sequences from states 95 | 96 | Args: 97 | states (torch.Tensor): user inputs 98 | state_mask (torch.Tensor): Mask for the states of the environment 99 | 100 | Returns: 101 | actions (torch.Tensor): Actions generated from the states 102 | actions_logits (torch.Tensor): Logits for the actions generated 103 | from the states (i.e. pi(y | x)) 104 | values (torch.Tensor): Values generated by the critic model 105 | for the actions generated by the actor (i.e. 
V(x)) 106 | sequences (torch.Tensor): Sequences generated from the states 107 | as [states, actions] 108 | """ 109 | # generate action sequence 110 | actions, sequence = self.actor.generate(states, state_mask) 111 | sequences_mask = sequence != self.actor.tokenizer.pad_token_id 112 | action_len = actions.shape[1] 113 | 114 | # generate actions_logits and values 115 | actions_logits, values = self.forward( 116 | sequence, sequences_mask, action_len 117 | ) 118 | if self.debug: 119 | print("ActorCritic.generate") 120 | print("actions shape", actions.shape) 121 | print("actions", actions) 122 | print("sequence shape", sequence.shape) 123 | print("sequence", sequence) 124 | print("actions_logits shape", actions_logits.shape) 125 | print("actions_logits", actions_logits) 126 | print("values shape", values.shape) 127 | print("values", values) 128 | 129 | return actions, actions_logits, values, sequence, sequences_mask 130 | 131 | 132 | # structure to store the data for each experience 133 | Memory = namedtuple( 134 | "Memory", 135 | [ 136 | "states", 137 | "actions", 138 | "sequences", 139 | "values", 140 | "rewards", 141 | "actions_log_probs", 142 | "sequences_mask", 143 | ], 144 | ) 145 | 146 | 147 | class ExperienceDataset(Dataset): 148 | """Dataset to train the actor-critic models""" 149 | 150 | def __init__( 151 | self, 152 | memories: Deque[Memory], 153 | device: torch.device, 154 | ) -> None: 155 | super().__init__() 156 | self.data = list(memories) 157 | self.device = device 158 | 159 | def __len__( 160 | self, 161 | ) -> int: 162 | return len(self.data) 163 | 164 | def __getitem__(self, idx) -> Tuple: 165 | # return the idx-th memory element as a tuple of tensors on the device 166 | item = ( 167 | self.data[idx].states.to(self.device), 168 | self.data[idx].actions.to(self.device), 169 | self.data[idx].sequences.to(self.device), 170 | self.data[idx].values.to(self.device), 171 | self.data[idx].rewards.to(self.device), 172 | self.data[idx].actions_log_probs.to(self.device), 173 | self.data[idx].sequences_mask.to(self.device), 174 | ) 175 | return item 176 | 177 | 178 | class ExamplesSampler: 179 | """Store the prompt to be sampled to generate the examples 180 | read a json file with the following format: 181 | [ 182 | { 183 | "user_input" : "", 184 | } , 185 | ... 186 | ] 187 | Where: 188 | user_input: is the input of the user or directly the input of the user 189 | with the memory preappended (i.e. user_input + memory) 190 | """ 191 | 192 | def __init__( 193 | self, 194 | path: str, 195 | ) -> None: 196 | self.path = path 197 | with open(path, "r") as f: 198 | self.data = json.load(f) 199 | 200 | def sample(self, n: int) -> List: 201 | """Sample n examples from the data 202 | 203 | Args: 204 | n (int): Number of examples to sample 205 | """ 206 | return random.sample(self.data, n) 207 | 208 | 209 | class RLTrainer: 210 | """Train the actor-critic model using RL 211 | 212 | Attributes: 213 | config (Config): Configuration of the trainer 214 | debug (bool): Debug mode 215 | actorcritic (ActorCritic): Actor-critic model 216 | actor_optim (torch.optim): Optimizer for the actor 217 | critic_optim (torch.optim): Optimizer for the critic 218 | reward (RewardModel): Reward model 219 | training_stats (TrainingStats): Class to store training stats 220 | Methods: 221 | train: the training loop that calls the learn function after generating 222 | the experiences. 223 | learn: Learn from a batch of experiences and update the actor and the 224 | critic model. 
225 | load_checkpoint: Load the checkpoint of the actor-critic model 226 | save_checkpoint: Save the checkpoint of the actor-critic model 227 | generate_user_input: Generate the user input from the inputs 228 | """ 229 | 230 | def __init__( 231 | self, 232 | config: Config, 233 | ) -> None: 234 | self.config = config 235 | self.debug = config.trainer.debug 236 | 237 | # initialize agent-critic 238 | self.actorcritic = ActorCritic(config.actor, config.critic) 239 | self.actor_optim = torch.optim.Adam( 240 | self.actorcritic.actor.parameters(), lr=config.trainer.actor_lr 241 | ) 242 | self.critic_optim = torch.optim.Adam( 243 | self.actorcritic.critic.parameters(), lr=config.trainer.critic_lr 244 | ) 245 | 246 | # initialize reward model 247 | self.reward = RewardModel(config.reward) 248 | 249 | # initialize class to store training stats 250 | self.training_stats = TrainingStats() 251 | self.conversation_log = ConversationLog() 252 | 253 | # initialize examples sampler 254 | self.example_sampler = ExamplesSampler(config.trainer.examples_path) 255 | 256 | # eps 257 | self.eps = 1e-8 258 | 259 | # make models directory 260 | if not os.path.exists("./models"): 261 | os.mkdir("./models") 262 | 263 | if not os.path.exists(self.config.trainer.checkpoint_folder): 264 | os.mkdir(self.config.trainer.checkpoint_folder) 265 | 266 | def save_checkpoint( 267 | self, 268 | current_episode: int, 269 | ) -> None: 270 | print(f"Saving checkpoint for episode {current_episode+1}..") 271 | file_name = "rltraining_" + str(current_episode) + ".pt" 272 | checkpoint_folder = self.config.trainer.checkpoint_folder 273 | if os.path.exists(checkpoint_folder) is False: 274 | os.mkdir(checkpoint_folder) 275 | path = checkpoint_folder + "/" + file_name 276 | torch.save( 277 | { 278 | "episode": current_episode, 279 | "actor_state_dict": self.actorcritic.actor.state_dict(), 280 | "critic_state_dict": self.actorcritic.critic.state_dict(), 281 | "actor_optim_state_dict": self.actor_optim.state_dict(), 282 | "critic_optim_state_dict": self.critic_optim.state_dict(), 283 | "training_stats": self.training_stats, 284 | }, 285 | path, 286 | ) 287 | 288 | def load_checkpoint( 289 | self, 290 | ) -> int: 291 | # get all the files name in the checkpoint folder and take the one 292 | # with the highest epoch 293 | checkpoint_folder = self.config.trainer.checkpoint_folder 294 | if os.path.exists(checkpoint_folder) is False: 295 | os.mkdir(checkpoint_folder) 296 | print( 297 | f"Checkpoint folder {checkpoint_folder} does not exist.\n" 298 | f"No checkpoint will be loaded." 
299 | )
300 | return 0
301 | files = os.listdir(checkpoint_folder)
302 | episodes = [int(f.split("_")[1].split(".")[0]) for f in files]
303 | if len(episodes) == 0:
304 | return 0
305 | max_episode = max(episodes)
306 | print(f"Loading checkpoint for episode {max_episode+1}..")
307 | file_name = "rltraining_" + str(max_episode) + ".pt"
308 | path = checkpoint_folder + "/" + file_name
309 | checkpoint = torch.load(path)
310 | self.actorcritic.actor.load_state_dict(checkpoint["actor_state_dict"])
311 | self.actorcritic.critic.load_state_dict(
312 | checkpoint["critic_state_dict"]
313 | )
314 | self.actor_optim.load_state_dict(checkpoint["actor_optim_state_dict"])
315 | self.critic_optim.load_state_dict(
316 | checkpoint["critic_optim_state_dict"]
317 | )
318 | self.training_stats = checkpoint["training_stats"]
319 | self.actorcritic.actor.to(self.config.trainer.device)
320 | self.actorcritic.critic.to(self.config.trainer.device)
321 | return max_episode + 1 # return the next episode to train
322 |
323 | @beartype
324 | def learn(self, memories: Deque[Memory]) -> None:
325 | """Train the actor-critic model using RL:
326 | - for each batch of episodes, compute action logits and values
327 | - then compare the new action log probs with the stored ones, and the
328 | values with the rewards, to compute the PPO loss and update the model
329 | """
330 | print("Start to Learn...")
331 |
332 | # get parameters
333 | epochs = self.config.trainer.epochs
334 | actor_eps_clip = self.config.trainer.actor_eps_clip
335 | critic_eps_clip = self.config.trainer.critic_eps_clip
336 | beta_s = self.config.trainer.beta_s
337 | batch_size = self.config.trainer.batch_size
338 | device = self.config.trainer.device
339 |
340 | # create dataset from memories
341 | dataloader = DataLoader(
342 | ExperienceDataset(memories, device), batch_size=batch_size
343 | )
344 |
345 | # train the actor-critic
346 | self.actorcritic.train()
347 | for epoch in range(epochs):
348 | for i, (
349 | states,
350 | old_actions,
351 | sequences,
352 | old_values,
353 | rewards,
354 | old_actions_log_probs,
355 | sequences_mask,
356 | ) in enumerate(dataloader):
357 |
358 | # print progress (len(dataloader) is the number of batches)
359 | print(
360 | "Epoch",
361 | epoch + 1,
362 | "of",
363 | epochs,
364 | "Data",
365 | i + 1,
366 | "of",
367 | len(dataloader),
368 | )
369 |
370 | if self.debug:
371 | print("RLTrainer.learn()")
372 | print("memory states shapes are: ")
373 | print("states shape", states.shape)
374 | print("old_actions shape", old_actions.shape)
375 | print("sequences shape", sequences.shape)
376 | print("old_values shape", old_values.shape)
377 | print("rewards shape", rewards.shape)
378 | print(
379 | "old_actions_log_probs shape",
380 | old_actions_log_probs.shape,
381 | )
382 | # reshaping rewards to match [b, s] shape
383 | rewards = rearrange(rewards, "b -> b 1")
384 |
385 | # get actions len
386 | actions_len = old_actions.shape[-1]
387 |
388 | # get actor critic forward
389 | actions_logits, values = self.actorcritic.forward(
390 | sequences, sequences_mask, actions_len
391 | )
392 |
393 | # get action log prob
394 | actions_prob = (
395 | torch.softmax(actions_logits, dim=-1).max(dim=-1).values
396 | )
397 | actions_log_prob = torch.log(actions_prob + self.eps)
398 |
399 | # compute entropy
400 | entropies = (actions_prob * actions_log_prob).sum(dim=-1)
401 |
402 | # compute KL divergence
403 | kl_div_loss = (
404 | (actions_prob * (old_actions_log_probs - actions_log_prob))
405 | .sum(dim=-1)
406 | .mean()
407 | )
408 |
409 | # compute PPO Loss -- When
dimensions are different 410 | # (especially the values and the probs are 411 | # multiplied directly with the reward) 412 | ratios = (actions_log_prob - old_actions_log_probs).exp() 413 | advantages = rewards - old_values 414 | # normalize advantages 415 | advantages = (advantages - advantages.mean(dim=-1)) / ( 416 | advantages.std() + self.eps 417 | ) 418 | surr1 = advantages * ratios 419 | surr2 = ( 420 | torch.clamp(ratios, 1 - actor_eps_clip, 1 + actor_eps_clip) 421 | * advantages 422 | ) 423 | policy_loss = -torch.min(surr1, surr2) - beta_s * entropies 424 | policy_loss = policy_loss.mean() 425 | loss = policy_loss + kl_div_loss 426 | # check if loss item is nan 427 | if torch.isnan(loss): 428 | raise ValueError("Loss is nan") 429 | print("loss", loss.item()) 430 | 431 | if self.debug: 432 | print("values", values) 433 | print("old_values", old_values) 434 | print("rewards", rewards) 435 | print("ratios", ratios) 436 | print("advantages", advantages) 437 | print("entropies", entropies) 438 | 439 | # update actor with loss 440 | self.actor_optim.zero_grad() 441 | loss.backward() 442 | self.actor_optim.step() 443 | 444 | torch.cuda.synchronize(device) 445 | 446 | # compute value loss 447 | value_loss_clipped = old_values + (values - old_values).clamp( 448 | -critic_eps_clip, critic_eps_clip 449 | ) 450 | value_loss1 = (value_loss_clipped - rewards) ** 2 451 | value_loss2 = (values - rewards) ** 2 452 | value_loss = torch.max(value_loss1, value_loss2).mean() 453 | if torch.isnan(value_loss): 454 | raise ValueError("Value loss is nan") 455 | print("value_loss", value_loss.item()) 456 | 457 | # upate critic with loss 458 | self.critic_optim.zero_grad() 459 | value_loss.backward() 460 | self.critic_optim.step() 461 | 462 | self.training_stats.training_loss.append( 463 | loss.detach().cpu().item() 464 | ) 465 | self.training_stats.value_loss.append( 466 | value_loss.detach().cpu().item() 467 | ) 468 | 469 | self.actorcritic.eval() 470 | print("End Learning") 471 | 472 | def train( 473 | self, 474 | ) -> None: 475 | # initialize settings 476 | num_episodes = self.config.trainer.num_episodes 477 | max_timesteps = self.config.trainer.max_timesteps 478 | num_examples = self.config.trainer.num_examples 479 | update_timesteps = self.config.trainer.update_timesteps 480 | batch_size = self.config.trainer.batch_size 481 | update_checkpoint = self.config.trainer.update_checkpoint 482 | device = self.config.trainer.device 483 | 484 | print("Start RL Training") 485 | # check dimensions consistency 486 | # at each time step num_examples memories are generated 487 | number_of_memories_per_learn_iteration = ( 488 | num_examples * update_timesteps 489 | ) 490 | # the number of memories must be a multiple of the batch size 491 | assert ( 492 | number_of_memories_per_learn_iteration % batch_size == 0 493 | ), "The number of memories must be a multiple of the batch size" 494 | # the total number of timesteps is 495 | total_number_of_timesteps = num_episodes * max_timesteps 496 | # the update_timesteps must be a multiple 497 | # of the total number of timesteps 498 | assert total_number_of_timesteps % update_timesteps == 0, ( 499 | "The number of timesteps (num_episodes*max_timesteps)" 500 | "must be a multiple of the update_timesteps" 501 | ) 502 | 503 | # initialize memories 504 | memories = deque([]) 505 | 506 | # loop over episodes and timesteps 507 | current_time = 0 508 | checkpoint_counter = 0 509 | current_episode = self.load_checkpoint() 510 | current_learn_counter = 0 511 | 512 | 
self.actorcritic.eval() 513 | for eps in range(current_episode, num_episodes): 514 | for timestep in range(max_timesteps): 515 | 516 | print( 517 | f"Episode: {eps + 1} of {num_episodes}, " 518 | f"Timestep: {timestep + 1} of {max_timesteps}", 519 | ) 520 | 521 | # counter used to count timesteps into memory 522 | current_time += 1 523 | 524 | # sample num_examples examples from example dataset 525 | inputs = self.example_sampler.sample(num_examples) 526 | 527 | # tokenize examples 528 | tokenized_inputs = self.actorcritic.actor.tokenizer( 529 | inputs, padding=True, return_tensors="pt" 530 | ) 531 | if self.debug: 532 | print("RLTrainer.train()") 533 | print("tokenized inputs", tokenized_inputs) 534 | # states are [batch_size, seq_len_of_states] 535 | states = tokenized_inputs["input_ids"].to(device) 536 | states_mask = tokenized_inputs["attention_mask"].to(device) 537 | 538 | # generate prompts 539 | # actions --> output produced by the actor head in response 540 | # of the state(input) [batch_size, len_of_actions] 541 | # actions_logits --> logits of the actions 542 | # [batch_size, len_of_actions, vocab_size] 543 | # values --> output produced by the critic for each action 544 | # [batch_size, len_of_actions] 545 | # sequence --> (state, actions) 546 | # [batch_size, len_of_actions + seq_len_of_states] = 547 | # [batch_size, seq_len] 548 | ( 549 | actions, 550 | actions_logits, 551 | values, 552 | sequences, 553 | sequences_mask, 554 | ) = self.actorcritic.generate(states, states_mask) 555 | 556 | # from action logits to action log probs 557 | action_prob = ( 558 | torch.softmax(actions_logits, dim=-1).max(dim=-1).values 559 | ) 560 | actions_log_probs = torch.log(action_prob + self.eps) 561 | 562 | completions = [ 563 | self.actorcritic.actor.tokenizer.decode(action) 564 | for i, action in enumerate(actions) 565 | ] 566 | if self.debug: 567 | print("RLTrainer.train()") 568 | print("completions:") 569 | for i, completion in enumerate(completions): 570 | print(i, completion) 571 | print("") 572 | 573 | # compute reward for the completion 574 | # the reward must take into account the answer quality wrt to 575 | # the initial request given 576 | # and must be tokenized again 577 | task_responses = [] 578 | for input, completion in zip(inputs, completions): 579 | task_response = input + "\n" + completion 580 | task_responses.append(task_response) 581 | if self.debug: 582 | print("RLTrainer.train()") 583 | print("task_responses:") 584 | for i, task_response in enumerate(task_responses): 585 | print(i, task_response) 586 | print("") 587 | tokenized_responses = self.reward.tokenizer( 588 | task_responses, padding=True, return_tensors="pt" 589 | ) 590 | rewards = self.reward.get_reward( 591 | tokenized_responses["input_ids"].to(device), 592 | tokenized_responses["attention_mask"].to(device), 593 | ) 594 | 595 | # store memories of the episode / timestep 596 | for i in range(states.shape[0]): 597 | memories.append( 598 | Memory( 599 | *map( 600 | lambda x: x.detach().cpu(), 601 | ( 602 | states[i, :], 603 | actions[i, :], 604 | sequences[i, :], 605 | values[i, :], 606 | rewards[i], 607 | actions_log_probs[i, :], 608 | sequences_mask[i, :], 609 | ), 610 | ) 611 | ) 612 | ) 613 | 614 | # log the memories in the conversation log 615 | for i in range(states.shape[0]): 616 | self.conversation_log.add_conversation( 617 | inputs[i], 618 | completions[i], 619 | rewards[i].detach().cpu(), 620 | current_learn_counter, 621 | ) 622 | 623 | # learn from memories 624 | print( 625 | f"Learning counter: 
{current_time} of {update_timesteps}" 626 | ) 627 | if (current_time % update_timesteps == 0) and ( 628 | current_time != 0 629 | ): 630 | checkpoint_counter += 1 631 | self.conversation_log.show(current_learn_counter) 632 | self.learn(memories) 633 | memories.clear() 634 | current_time = 0 635 | current_learn_counter += 1 636 | 637 | if (checkpoint_counter % update_checkpoint == 0) and ( 638 | checkpoint_counter != 0 639 | ): 640 | self.save_checkpoint(eps) 641 | checkpoint_counter = 0 642 | 643 | self.actorcritic.critic.save() 644 | self.actorcritic.actor.save() 645 | # print("Show conversations log") 646 | # self.conversation_log.show() 647 | print("End RL Training") 648 | -------------------------------------------------------------------------------- /chatllama/rlhf/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from beartype import beartype 3 | from beartype.typing import Optional 4 | from plotly import graph_objects as go 5 | 6 | 7 | class TrainingStats: 8 | """Training statistics 9 | 10 | Attributes: 11 | training_loss (List): List of training losses 12 | training_accuracy (List): List of training accuracies 13 | value_loss (List): List of value losses 14 | validation_loss (List): List of validation losses 15 | validation_accuracy (List): List of validation accuracies 16 | """ 17 | 18 | def __init__(self): 19 | self.training_loss = [] 20 | self.training_accuracy = [] 21 | self.value_loss = [] 22 | self.validation_loss = [] 23 | self.validation_accuracy = [] 24 | 25 | def plot(self): 26 | """Plot the training statistics using plotly""" 27 | fig = go.Figure() 28 | if len(self.training_loss) > 0: 29 | fig.add_trace( 30 | go.Scatter(y=self.training_loss, name="Training loss") 31 | ) 32 | if len(self.training_accuracy) > 0: 33 | fig.add_trace( 34 | go.Scatter(y=self.training_accuracy, name="Training accuracy") 35 | ) 36 | if len(self.value_loss) > 0: 37 | fig.add_trace(go.Scatter(y=self.value_loss, name="Value loss")) 38 | if len(self.validation_loss) > 0: 39 | fig.add_trace( 40 | go.Scatter(y=self.validation_loss, name="Validation loss") 41 | ) 42 | if len(self.validation_accuracy) > 0: 43 | fig.add_trace( 44 | go.Scatter( 45 | y=self.validation_accuracy, name="Validation accuracy" 46 | ) 47 | ) 48 | fig.update_layout( 49 | showlegend=True, xaxis_type="log", xaxis_title="steps" 50 | ) 51 | fig.show() 52 | 53 | 54 | class ConversationLog: 55 | """Save the conversation: 56 | (user input, model output, rewards and learn_counter) 57 | during the RL training loop. 
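 For reference, each stored record is a plain dict with the keys used by
    add_conversation below; an illustrative entry (all values invented for
    this example) looks like:

        {
            "user_input": "Explain RLHF in one sentence.",
            "model_output": "RLHF fine-tunes a language model with ...",
            "reward": 0.83,
            "learn_counter": 2,
            "previous_reward": 0.41,
            "previous_completion": "RLHF is a technique ...",
        }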
Additionally, in order to be able to compare 58 | the initial dataset of answers to the prompts, we store also the original 59 | performance of the generation: 60 | (generation_input, generation_output, generation_reward) 61 | """ 62 | 63 | def __init__(self): 64 | self.conversation = [] 65 | 66 | @beartype 67 | def add_conversation( 68 | self, 69 | user_input: str, 70 | model_output: str, 71 | reward: float, 72 | learn_counter: int, 73 | previous_reward: float, 74 | previous_completion: str, 75 | ): 76 | """Add a conversation to the log 77 | 78 | Args: 79 | user_input (str): User input / initial prompt 80 | model_output (str): Completion of the LLM model 81 | reward (float): Reward of the reward model assigned to the output 82 | learn_counter (int): Number of the learning iteration to 83 | distinguish the conversations that happens at different 84 | points of the training loop 85 | previous_reward (float): Reward of the reward model assigned to 86 | the output of original dataset 87 | previous_completion (str): Completion of the LLM model of the 88 | original dataset 89 | """ 90 | self.conversation.append( 91 | { 92 | "user_input": user_input, 93 | "model_output": model_output, 94 | "reward": reward, 95 | "learn_counter": learn_counter, 96 | "previous_reward": previous_reward, 97 | "previous_completion": previous_completion, 98 | } 99 | ) 100 | 101 | def save(self, path: Optional[str] = "./conversation.json"): 102 | with open(path, "r") as f: 103 | conversation = json.load(f) 104 | conversation.extend(self.conversation) 105 | with open(path, "w") as f: 106 | json.dump(conversation, f) 107 | 108 | def load(self, path: Optional[str] = "./conversation.json"): 109 | with open(path, "r") as f: 110 | self.conversation = json.load(f) 111 | 112 | def show(self, current_iteration: int = None): 113 | """Show the conversation log 114 | 115 | Args: 116 | current_iteration (int): Current iteration of the training loop, 117 | if not None, print only the conversations that happened at 118 | 119 | """ 120 | for i, c in enumerate(self.conversation): 121 | if current_iteration is None: 122 | print( 123 | f"##########################################\n" 124 | f"Conversation {i} at learn_counter " 125 | f"{c['learn_counter']}\n" 126 | f"##########################################\n" 127 | f"## User Input:\n\n{c['user_input']}\n\n" 128 | f"## Model Output:\n\n{c['model_output']}\n\n" 129 | f"## Reward: {c['reward']}\n\n" 130 | f"## Previous Reward: {c['previous_reward']}\n\n" 131 | ) 132 | else: 133 | if current_iteration == c["learn_counter"]: 134 | print( 135 | f"##########################################\n" 136 | f"Conversation {i} at learn_counter " 137 | f"{c['learn_counter']}\n" 138 | f"##########################################\n" 139 | f"## User Input:\n\n{c['user_input']}\n\n" 140 | f"## Model Output:\n\n{c['model_output']}\n\n" 141 | f"## Reward: {c['reward']}\n\n" 142 | f"## Previous Reward: {c['previous_reward']}\n\n" 143 | ) 144 | -------------------------------------------------------------------------------- /generate_dataset.py: -------------------------------------------------------------------------------- 1 | from langchain import OpenAI, LLMChain, PromptTemplate 2 | from langchain.chains.conversation.memory import ( 3 | ConversationBufferWindowMemory, 4 | ) 5 | 6 | from chatllama.langchain_modules.prompt_templates import ( 7 | PERSON_CHATBOT_TEMPLATE, 8 | AI_CHATBOT_TEMPLATE, 9 | ) 10 | 11 | 12 | CONVERSATION_LENGTH = 20 13 | 14 | 15 | def create_conversation(human_agent: LLMChain, 
bot_agent: LLMChain):
16 | conversation = []
17 | chatbot_output = ""
18 | for i in range(CONVERSATION_LENGTH):
19 | # Human agent goes first
20 | human_output = human_agent.run(chatbot_input=chatbot_output)
21 | conversation.append(f"Human: {human_output}")
22 | chatbot_output = bot_agent.run(human_input=human_output)
23 | conversation.append(f"AI: {chatbot_output}")
24 | return "\n".join(conversation)
25 |
26 |
27 | def build_agents():
28 | # be aware that completions that are too long will not fit the sequence
29 | # length of possible critic or reward models ...
30 | llm = OpenAI(max_tokens=2048, temperature=0.7)
31 | human_template = PromptTemplate(**PERSON_CHATBOT_TEMPLATE)
32 | human_agent = LLMChain(
33 | llm=llm,
34 | prompt=human_template,
35 | memory=ConversationBufferWindowMemory(k=4),
36 | )
37 | bot_template = PromptTemplate(**AI_CHATBOT_TEMPLATE)
38 | bot_agent = LLMChain(
39 | llm=llm,
40 | prompt=bot_template,
41 | memory=ConversationBufferWindowMemory(k=4),
42 | )
43 | return human_agent, bot_agent
44 |
45 |
46 | def main():
47 | from argparse import ArgumentParser
48 |
49 | parser = ArgumentParser()
50 | parser.add_argument("--num_conversations", type=int, default=1000)
51 | parser.add_argument("--output_file", type=str, default="conversations.txt")
52 | args = parser.parse_args()
53 | conversations = []
54 | for conv in range(args.num_conversations):
55 | human_agent, bot_agent = build_agents()
56 | conversation = create_conversation(human_agent, bot_agent)
57 | conversations.append(conversation)
58 | with open(args.output_file, "w") as f:
59 | f.write("\n\nNEW CONVERSATION\n\n".join(conversations))
60 |
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | pyllama
2 | fairscale
3 | langchain
4 | beartype
5 | einops
6 | plotly
7 |
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import os
3 |
4 | here = os.path.dirname(os.path.realpath(__file__))
5 | HAS_CUDA = os.system("nvidia-smi > /dev/null 2>&1") == 0
6 |
7 | VERSION = "0.0.4"
8 | DESCRIPTION = "ChatLLaMA: Open and Efficient Foundation Language Models Runnable In A Single GPU"
9 |
10 | packages = [
11 | "chatllama",
12 | ]
13 |
14 |
15 | def read_file(filename: str):
16 | try:
17 | lines = []
18 | with open(filename) as file:
19 | lines = file.readlines()
20 | lines = [line.rstrip() for line in lines if not line.startswith('#')]
21 | return lines
22 | except Exception:
23 | return []
24 |
25 |
26 | def package_files(ds):
27 | paths = []
28 | for d in ds:
29 | for (path, directories, filenames) in os.walk(d):
30 | for filename in filenames:
31 | if '__pycache__' not in str(filename):
32 | paths.append(str(os.path.join(path, filename))[len('chatllama/'):])
33 | return paths
34 |
35 | extra_files = package_files(['chatllama/'])
36 |
37 |
38 | setup(
39 | name="chatllama",
40 | version=VERSION,
41 | author_email="",
42 | description=DESCRIPTION,
43 | long_description=open("README.md", "r", encoding="utf-8").read(),
44 | long_description_content_type="text/markdown",
45 | install_requires=read_file(f"{here}/requirements.txt"),
46 | keywords=[
47 | "ChatLLaMA", "LLaMA"
48 | ],
49 | classifiers=[
50 | "Development Status :: 3 - Alpha",
51 | "Programming Language :: Python :: 3",
52 | "Programming Language :: Python :: 3.8",
53 | "Programming Language ::
Python :: 3.9", 54 | "Programming Language :: Python :: 3.10", 55 | ], 56 | packages=packages, 57 | package_data={"chatllama": extra_files}, 58 | url="https://github.com/juncongmoo/chatllama" 59 | ) 60 | --------------------------------------------------------------------------------
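
A note on the prompts consumed by the RL loop: RLTrainer samples user prompts from the JSON file at `config.trainer.examples_path` through `ExamplesSampler` (chatllama/rlhf/trainer.py). A minimal sketch of how such a file could be produced is shown below; the prompts and the `examples.json` filename are invented for illustration, and the real path must match the one set in config.yaml.

```python
import json

# Illustrative prompts only; any list of {"user_input": ...} records matches
# the format documented in ExamplesSampler.
examples = [
    {"user_input": "Explain reinforcement learning from human feedback."},
    {"user_input": "Write a short story about an open source community."},
]

with open("examples.json", "w") as f:
    json.dump(examples, f, indent=2)
```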