├── .gitignore
├── NOTICE
├── README.md
├── chatllama
│   ├── __init__.py
│   ├── langchain_modules
│   │   ├── __init__.py
│   │   └── prompt_templates.py
│   ├── llama_model.py
│   └── rlhf
│       ├── __init__.py
│       ├── actor.py
│       ├── config.py
│       ├── config.yaml
│       ├── reward.py
│       ├── test.py
│       ├── trainer.py
│       └── utils.py
├── generate_dataset.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 | .idea
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # MacOS DS_Store
133 | .DS_Store
134 |
135 | # Pickle folder
136 | .pkl_memoize_py3
137 |
138 | # Folder where optimized models are stored
139 | optimized_model
140 |
141 | # Config file for tests coverage
142 | .coveragerc
143 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | The code originally comes from [nebuly-ai](https://github.com/nebuly-ai/nebullvm/tree/main/apps/accelerate/chatllama), with some changes; more changes will follow soon. The original license is available at: https://github.com/nebuly-ai/nebullvm/blob/main/LICENSE
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ChatLLaMA
2 |
3 | > 📢 Open source implementation of a LLaMA-based ChatGPT runnable on a single GPU. 15x faster training process than `ChatGPT`.
4 |
5 | - 🔥 Please check [`pyllama`](https://github.com/juncongmoo/pyllama) for `LLaMA` installation and `single GPU inference` setup.
6 | - 🔥 To train ChatGPT in 5 mins - [minichatgpt](https://github.com/juncongmoo/minichatgpt)
7 |
8 |
9 | Meta has recently released LLaMA, a collection of foundational large language models ranging from 7 to 65 billion parameters.
10 | LLaMA is creating a lot of excitement because it is smaller than GPT-3 but has better performance. For example, LLaMA's 13B architecture outperforms GPT-3 despite being 10 times smaller. This new collection of foundation models opens the door to faster inference and ChatGPT-like real-time assistants, while being cost-effective and running on a single GPU.
11 |
12 | However, LLaMA was not fine-tuned for instruction tasks with a Reinforcement Learning from Human Feedback (RLHF) training process.
13 |
14 | The good news is that we introduce `ChatLLaMA`, the first open source implementation of LLaMA based on RLHF:
15 |
16 | - A complete open source implementation that enables you to build a ChatGPT-style service based on pre-trained LLaMA models.
17 | - Compared to the original ChatGPT, the training process and single-GPU inference are much faster and cheaper by taking advantage of the smaller size of LLaMA architectures.
18 | - ChatLLaMA has built-in support for DeepSpeed ZeRO to speed up the fine-tuning process.
19 | - The library also supports all LLaMA model architectures (7B, 13B, 33B, 65B), so that you can fine-tune the model according to your preferences for training time and inference performance.
20 |
21 |
22 |
23 |
24 | Image from [OpenAI’s blog](https://openai.com/blog/chatgpt).
25 |
26 |
27 | # Installation
28 |
29 | ```
30 | pip install chatllama
31 | ```
32 |
33 |
34 | # Get started with ChatLLaMA
35 |
36 | > :warning: Please note that this code represents the algorithmic implementation of the RLHF training process for LLaMA and does not contain the model weights. To access the model weights, you need to apply through Meta's [request form](https://forms.gle/jk851eBVbX1m5TAv5).
37 |
38 | ChatLLaMA allows you to easily train LLaMA-based architectures in a similar way to ChatGPT, using RLHF.
39 | For example, below is the code to start the training in the case of ChatLLaMA 7B.
40 |
41 | ```python
42 | from chatllama.rlhf.trainer import RLTrainer
43 | from chatllama.rlhf.config import Config
44 |
45 | path = "path_to_config_file.yaml"
46 | config = Config(path=path)
47 | trainer = RLTrainer(config.trainer)
48 | trainer.distillate()
49 | trainer.train()
50 | trainer.training_stats.plot()
51 | ```
52 |
53 | Note that you should provide Meta's original weights and your custom dataset before starting the fine-tuning process. Alternatively, you can generate your own dataset using LangChain's agents.
54 |
55 | ```bash
56 | python generate_dataset.py
57 | ```
58 |
59 |
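60 | The trainers in `chatllama.rlhf` read JSON datasets whose entries follow the format documented in the `ActorDataset` and `RewardDataset` classes. A minimal illustrative example (the field names come from those classes; the texts and the score are placeholders):
61 |
62 | ```json
63 | [
64 |     {
65 |         "user_input": "...",
66 |         "completion": "...",
67 |         "score": 4.0
68 |     }
69 | ]
70 | ```
71 |
72 | The `score` field is only required for the reward dataset; `RewardTrainer.distillate()` can fill it in by asking an LLM to rate each `(user_input, completion)` pair.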
--------------------------------------------------------------------------------
/chatllama/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.3'
2 |
--------------------------------------------------------------------------------
/chatllama/langchain_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/henrywoo/chatllama/6e9fabf78af50ae03ac61021a27628ae5e98363c/chatllama/langchain_modules/__init__.py
--------------------------------------------------------------------------------
/chatllama/langchain_modules/prompt_templates.py:
--------------------------------------------------------------------------------
1 | REWARD_TEMPLATE = dict(
2 | template=(
3 |         "Let's pretend that you are a lawyer and you have to "
4 |         "evaluate the following completion task from a given "
5 |         "assignment with a score between 0 and 5, where 0 represents "
6 |         "a bad assignment completion and 5 a perfect completion.\n"
7 |         "You MUST evaluate: text quality, content quality and "
8 |         "coherence.\n"
9 |         "You MUST return only the number that represents your "
10 |         "judgment.\n"
11 |         "The assignment is:\n{user_input}\n"
12 | "The completion is:\n{completion}\n"
13 | ),
14 | input_variables=["user_input", "completion"],
15 | )
16 |
17 |
18 | AI_CHATBOT_TEMPLATE = dict(
19 | template=(
20 | "Assistant is a large language model trained by Meta and Nebuly.ai\n"
21 | "Assistant is designed to be able to assist with a wide range of "
22 | "tasks, from answering simple questions to providing in-depth "
23 | "explanations and discussions on a wide range of topics. As a "
24 | "language model, Assistant is able to generate human-like text "
25 | "based on the input it receives, allowing it to engage in "
26 | "natural-sounding conversations and provide responses that are "
27 | "coherent and relevant to the topic at hand.\n\n"
28 | "Assistant is constantly learning and improving, and its capabilities "
29 | "are constantly evolving. It is able to process and understand large "
30 | "amounts of text, and can use this knowledge to provide accurate and "
31 | "informative responses to a wide range of questions. Additionally, "
32 | "Assistant is able to generate its own text based on the input it "
33 | "receives, allowing it to engage in discussions and provide "
34 | "explanations and descriptions on a wide range of topics.\n\n"
35 | "Overall, Assistant is a powerful tool that can help with a wide "
36 | "range of tasks and provide valuable insights and information on a "
37 | "wide range of topics. Whether you need help with a specific "
38 | "question or just want to have a conversation about a particular "
39 | "topic, Assistant is here to assist.\n\n{history}\n\n"
40 | "Human: {human_input}\n"
41 | "Assistant:"
42 | ),
43 | input_variables=["history", "human_input"],
44 | )
45 |
46 |
47 | PERSON_CHATBOT_TEMPLATE = dict(
48 | template=(
49 | "You are a human chatting with a chatbot. The chatbot is a large "
50 | "language model trained by Meta and Nebuly-ai\n"
51 | "The chatbot is designed to be able to assist you with a wide range "
52 | "of tasks, from answering simple questions to providing in-depth "
53 | "explanations and discussions on a wide range of topics. You are a "
54 |         "human and you are testing the chatbot. Ask the chatbot questions and "
55 | "see how it responds. You can also ask the chatbot to tell you a "
56 | "story."
57 | "\n\n{history}\n\n"
58 | "Chatbot: {chatbot_input}\n"
59 | "Human:"
60 | ),
61 | input_variables=["history", "chatbot_input"],
62 | )
63 |
--------------------------------------------------------------------------------
/chatllama/llama_model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Tuple, List, Union
5 |
6 | import torch.distributed
7 | import torch.nn as nn
8 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel
9 | from fairscale.nn.model_parallel.layers import (
10 | ParallelEmbedding,
11 | ColumnParallelLinear,
12 | )
13 | from llama import ModelArgs, Tokenizer
14 | from llama.generation import sample_top_p
15 | from llama.model import TransformerBlock, RMSNorm, precompute_freqs_cis
16 |
17 |
18 | class HFLikeTokenizer:
19 | def __init__(self, tokenizer: Tokenizer):
20 | self.tokenizer = tokenizer
21 |
22 | def __call__(self, texts: Union[List[str], str], *args, **kwargs):
23 | if isinstance(texts, str):
24 | text = self.tokenizer.encode(texts, bos=True, eos=True)
25 | tokens = torch.tensor(text).cuda().long()
26 | else:
27 | texts = [
28 | self.tokenizer.encode(text, bos=True, eos=True)
29 | for text in texts
30 | ]
31 | max_len = max(len(text) for text in texts)
32 | tokens = (
33 | torch.full((len(texts), max_len), self.tokenizer.pad_id)
34 | .cuda()
35 | .long()
36 | )
37 | for i, text in enumerate(texts):
38 | tokens[i, : len(text)] = torch.tensor(text).cuda().long()
39 | output = {
40 | "input_ids": tokens,
41 | "attention_mask": (tokens != self.tokenizer.pad_id).long(),
42 | }
43 | return output
44 |
45 | def decode(self, tokens):
46 | return self.tokenizer.decode(tokens)
47 |
48 |
49 | class Transformer(nn.Module):
50 | def __init__(self, params: ModelArgs):
51 | super().__init__()
52 | self.params = params
53 | self.vocab_size = params.vocab_size
54 | self.n_layers = params.n_layers
55 |
56 | self.tok_embeddings = ParallelEmbedding(
57 | params.vocab_size, params.dim, init_method=lambda x: x
58 | )
59 |
60 | self.layers = torch.nn.ModuleList()
61 | for layer_id in range(params.n_layers):
62 | self.layers.append(TransformerBlock(layer_id, params))
63 |
64 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
65 | self.output = ColumnParallelLinear(
66 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x
67 | )
68 |
69 | self.freqs_cis = precompute_freqs_cis(
70 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
71 | )
72 |
73 | def forward(self, tokens: torch.Tensor, attention_mask: torch.Tensor):
74 | start_pos = int(torch.argmax(attention_mask.detach(), dim=-1).item())
75 | logits = self._forward(tokens, start_pos)
76 | return logits
77 |
78 | def _forward(self, tokens: torch.Tensor, start_pos: int):
79 | _bsz, seqlen = tokens.shape
80 | h = self.tok_embeddings(tokens)
81 | self.freqs_cis = self.freqs_cis.to(h.device)
82 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] # noqa E203
83 |
84 | mask = None
85 | if seqlen > 1:
86 | mask = torch.full(
87 | (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
88 | )
89 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
90 |
91 | for layer in self.layers:
92 | h = layer(h, start_pos, freqs_cis, mask)
93 | h = self.norm(h)
94 | output = self.output(h[:, -1, :]) # only compute last logits
95 | return output.float()
96 |
97 | @torch.no_grad()
98 | def generate(
99 | self,
100 | inputs: torch.Tensor,
101 | attention_mask: torch.Tensor,
102 | max_length: int,
103 | temperature: float,
104 | top_p: float = 1.0,
105 | ):
106 | prompt_size = inputs.shape[1]
107 | total_len = min(self.params.max_seq_len, max_length + prompt_size)
108 | start_pos = prompt_size # We assume left padding
109 | prev_pos = 0
110 | generated_tokens = []
111 | for cur_pos in range(start_pos, total_len):
112 | logits = self._forward(inputs[:, prev_pos:cur_pos], prev_pos)
113 | if temperature > 0:
114 | probs = torch.softmax(logits / temperature, dim=-1)
115 | next_token = sample_top_p(probs, top_p)
116 | else:
117 | next_token = torch.argmax(logits, dim=-1)
118 | next_token = next_token.reshape(-1)
119 | generated_tokens.append(next_token)
120 | prev_pos = cur_pos
121 | return torch.stack(generated_tokens, dim=1)
122 |
123 |
124 | def setup_model_parallel() -> Tuple[int, int]:
125 | local_rank = int(os.environ.get("LOCAL_RANK", -1))
126 | world_size = int(os.environ.get("WORLD_SIZE", -1))
127 |
128 | torch.distributed.init_process_group("nccl")
129 | initialize_model_parallel(world_size)
130 | torch.cuda.set_device(local_rank)
131 |
132 | # seed must be the same in all processes
133 | torch.manual_seed(1)
134 | return local_rank, world_size
135 |
136 |
137 | def load_checkpoints(
138 | ckpt_dir: str, local_rank: int, world_size: int
139 | ) -> Tuple[dict, dict]:
140 | checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
141 | assert world_size == len(checkpoints), (
142 | f"Loading a checkpoint for MP={len(checkpoints)} but world "
143 | f"size is {world_size}"
144 | )
145 | ckpt_path = checkpoints[local_rank]
146 | print("Loading")
147 | checkpoint = torch.load(ckpt_path, map_location="cpu")
148 | with open(Path(ckpt_dir) / "params.json", "r") as f:
149 | params = json.loads(f.read())
150 | return checkpoint, params
151 |
152 |
153 | def load_model(
154 | ckpt_dir: str,
155 | tokenizer_path: str,
156 | local_rank: int,
157 | world_size: int,
158 | max_batch_size: int = 32,
159 | ) -> Tuple[Transformer, HFLikeTokenizer]:
160 | checkpoint, params = load_checkpoints(ckpt_dir, local_rank, world_size)
161 | model_args: ModelArgs = ModelArgs(
162 | max_seq_len=1024, max_batch_size=max_batch_size, **params
163 | )
164 | tokenizer = Tokenizer(model_path=tokenizer_path)
165 | model_args.vocab_size = tokenizer.n_words
166 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
167 | model = Transformer(model_args)
168 | torch.set_default_tensor_type(torch.FloatTensor)
169 | model.load_state_dict(checkpoint, strict=False)
170 | tokenizer = HFLikeTokenizer(tokenizer)
171 | return model, tokenizer
172 |
173 |
174 | def generate(
175 | model: Transformer,
176 | tokenizer: Tokenizer,
177 | prompts: List[str],
178 | max_gen_len: int,
179 | temperature: float = 0.8,
180 | top_p: float = 0.95,
181 | ) -> List[str]:
182 | bsz = len(prompts)
183 | params = model.params
184 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
185 |
186 | prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts]
187 |
188 | min_prompt_size = min([len(t) for t in prompt_tokens])
189 | max_prompt_size = max([len(t) for t in prompt_tokens])
190 |
191 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
192 |
193 | tokens = torch.full((bsz, total_len), tokenizer.pad_id).cuda().long()
194 | for k, t in enumerate(prompt_tokens):
195 | tokens[k, : len(t)] = torch.tensor(t).long()
196 | input_text_mask = tokens != tokenizer.pad_id
197 | start_pos = min_prompt_size
198 | prev_pos = 0
199 | for cur_pos in range(start_pos, total_len):
200 | logits = model._forward(tokens[:, prev_pos:cur_pos], prev_pos)
201 | if temperature > 0:
202 | probs = torch.softmax(logits / temperature, dim=-1)
203 | next_token = sample_top_p(probs, top_p)
204 | else:
205 | next_token = torch.argmax(logits, dim=-1)
206 | next_token = next_token.reshape(-1)
207 | # only replace token if prompt has already been generated
208 | next_token = torch.where(
209 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
210 | )
211 | tokens[:, cur_pos] = next_token
212 | prev_pos = cur_pos
213 |
214 | decoded = []
215 | for i, t in enumerate(tokens.tolist()):
216 | # cut to max gen len
217 | t = t[: len(prompt_tokens[i]) + max_gen_len]
218 | # cut to eos tok if any
219 | try:
220 | t = t[: t.index(tokenizer.eos_id)]
221 | except ValueError:
222 | pass
223 | decoded.append(tokenizer.decode(t))
224 | return decoded
225 |
--------------------------------------------------------------------------------
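For readers who want to smoke-test `llama_model.py` on its own, here is a minimal sketch (not part of the repository; the checkpoint and tokenizer paths are placeholders, and it assumes a single-GPU 7B checkpoint launched with `torchrun --nproc_per_node 1` so that `LOCAL_RANK`/`WORLD_SIZE` are set for `setup_model_parallel`):

```python
# Hypothetical smoke test for chatllama/llama_model.py (paths are placeholders).
from chatllama.llama_model import setup_model_parallel, load_model, generate

local_rank, world_size = setup_model_parallel()  # reads LOCAL_RANK / WORLD_SIZE
model, hf_tokenizer = load_model(
    ckpt_dir="path-to-checkpoints/7B",
    tokenizer_path="path-to-tokenizer/tokenizer.model",
    local_rank=local_rank,
    world_size=world_size,
    max_batch_size=1,
)
# generate() expects the raw LLaMA tokenizer, which HFLikeTokenizer wraps.
print(generate(model, hf_tokenizer.tokenizer, ["Hello, my name is"], max_gen_len=32))
```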
/chatllama/rlhf/__init__.py:
--------------------------------------------------------------------------------
1 | """RLHF implementation inspired by Lucidrains' implementation."""
2 |
--------------------------------------------------------------------------------
/chatllama/rlhf/actor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import torch
5 | from beartype import beartype
6 | from beartype.typing import Optional, Tuple
7 | from einops import rearrange
8 | from torch.utils.data import Dataset, DataLoader
9 | from chatllama.rlhf.config import ConfigActor
10 | from chatllama.rlhf.utils import TrainingStats
11 |
12 | from chatllama.llama_model import load_model
13 |
14 |
15 | class ActorModel(torch.nn.Module):
16 | """Actor model that generates the augmented prompt from the initial
17 | user_input. The aim is to train this model to generate better prompts.
18 |
19 | Attributes:
20 | model: The model from LLaMA to be used
21 | tokenizer: The LLaMA tokenizer
22 | max_model_tokens (int): Maximum number of tokens that the model can
23 | handle
24 | config (ConfigActor): Configuration for the actor model
25 |
26 | Methods:
27 | load: Load the model from a path
28 | save: Save the model to a path
29 | forward: Compute the action logits for a given sequence.
30 | generate: Generate a sequence from a given prompt
31 | """
32 |
33 | def __init__(self, config: ConfigActor) -> None:
34 | super().__init__()
35 | # load the model
36 |
37 | self.max_model_tokens = 1024
38 | self.model, self.tokenizer = load_model(
39 | ckpt_dir=config.model_folder,
40 | tokenizer_path=config.tokenizer_folder,
41 | local_rank=int(os.environ.get("LOCAL_RANK", -1)),
42 | world_size=int(os.environ.get("WORLD_SIZE", -1)),
43 | max_batch_size=config.batch_size,
44 | )
45 | # save config
46 | self.config = config
47 |
48 | def parameters(self, **kwargs):
49 | """Return the parameters of the model
50 |
51 | Args:
52 | **kwargs:
53 | """
54 | return self.model.parameters()
55 |
56 | @beartype
57 | def forward(
58 | self, sequences: torch.Tensor, sequences_mask: torch.Tensor
59 | ) -> torch.Tensor:
60 | """Generate logits to have probability distribution over the vocabulary
61 | of the actions
62 |
63 | Args:
64 | sequences (torch.Tensor): Sequences of states and actions used to
65 | compute token logits for the whole list of sequences
66 |             sequences_mask (torch.Tensor): Mask for the sequences attention
67 |
68 | Returns:
69 | logits (torch.Tensor): Logits for the actions taken
70 | """
71 | model_output = self.model.forward(
72 | sequences, attention_mask=sequences_mask
73 | )
74 | if self.config.debug:
75 | print("ActorModel.forward")
76 |             print("model_output logits shape", model_output.shape)
77 |             print("model_output logits", model_output)
78 |         return model_output
79 |
80 | @beartype
81 | @torch.no_grad()
82 | def generate(
83 | self, states: torch.Tensor, state_mask: torch.Tensor
84 | ) -> Tuple:
85 | """Generate actions and sequences=[states, actions] from state
86 | (i.e. input of the prompt generator model)
87 |
88 | Args:
89 | state (torch.Tensor): the input of the user
90 | state_mask (torch.Tensor): Mask for the state input (for padding)
91 |
92 | Returns:
93 | actions (torch.Tensor): Actions generated from the state
94 | sequences (torch.Tensor): Sequences generated from the
95 | state as [states, actions]
96 | """
97 | max_sequence = states.shape[1]
98 | max_tokens = self.config.max_tokens + max_sequence
99 | temperature = self.config.temperature
100 | # What if the states + completion are longer than the max context of
101 | # the model?
102 | sequences = self.model.generate(
103 | inputs=states,
104 | attention_mask=state_mask,
105 | max_length=max_tokens,
106 | temperature=temperature,
107 | )
108 | actions = sequences[:, states.shape[1] :] # noqa E203
109 | if self.config.debug:
110 | print("ActorModel.generate")
111 | print("state", states)
112 | print("state shape", states.shape)
113 | print("sequence shape", sequences.shape)
114 | print("sequence", sequences)
115 | print("actions shape", actions.shape)
116 | print("actions", actions)
117 | return actions, sequences
118 |
119 | @beartype
120 | def load(self, path: Optional[str] = None) -> None:
121 | """Load the model from the path
122 |
123 | Args:
124 | path (str): Path to the model
125 | """
126 | if path is None:
127 | path = self.config.model_folder + "/" + self.config.model + ".pt"
128 | if os.path.exists(self.config.model_folder) is False:
129 | os.mkdir(self.config.model_folder)
130 | print(
131 | f"Impossible to load the model: {path}"
132 | f"The path doesn't exist."
133 | )
134 | return
135 | # load the model
136 | if os.path.exists(path) is False:
137 | print(
138 | f"Impossible to load the model: {path}"
139 | f"The path doesn't exist."
140 | )
141 | return
142 | model_dict = torch.load(path)
143 | self.model.load_state_dict(model_dict["model"])
144 |
145 | @beartype
146 | def save(self, path: Optional[str] = None) -> None:
147 | """Save the model to the path
148 |
149 | Args:
150 | path (Optional[str], optional): Path to store the model.
151 | Defaults to None.
152 | """
153 | if path is None:
154 | path = self.config.model_folder + "/" + self.config.model + ".pt"
155 | if os.path.exists(self.config.model_folder) is False:
156 | os.mkdir(self.config.model_folder)
157 | torch.save({"model": self.model.state_dict()}, path)
158 |
159 |
160 | class ActorDataset(Dataset):
161 | """Dataset for the pretraining of the actor model
162 | read a json file with the following format:
163 | [
164 | {
165 |             "user_input": "...",
166 |             "completion": "..."
167 |         },
168 | ...
169 | ]
170 | Where:
171 | user_input: the input of the user
172 |         completion: the expected completion for the given user input
173 | """
174 |
175 | def __init__(self, path: str, device: torch.device) -> None:
176 | self.device = device
177 | self.path = path
178 | with open(path, "r") as f:
179 | data = json.load(f)
180 | self.data = [
181 | d["user_input"] + "\n\n###\n\n" + d["completion"] for d in data
182 | ]
183 | self.len = len(self.data)
184 |
185 | def __getitem__(self, idx):
186 | return self.data[idx]
187 |
188 | def __len__(
189 | self,
190 | ):
191 | return self.len
192 |
193 |
194 | class ActorTrainer:
195 | """Used to pre-train the actor model to generate better prompts.
196 |
197 | Args:
198 | config (ConfigActor): Configuration for the actor model
199 |
200 | Attributes:
201 | config (ConfigActor): Configuration for the actor model
202 | model (ActorModel): Actor model
203 | loss_function (torch.nn.CrossEntropyLoss): Loss function
204 | optimizer (torch.optim.Adam): Optimizer
205 | validation_flag (bool): Flag to indicate if the validation dataset
206 | is provided
207 | training_stats (TrainingStats): Training statistics
208 |
209 | Methods:
210 | train: Train the actor model
211 | """
212 |
213 | def __init__(self, config: ConfigActor) -> None:
214 | # load the model
215 | self.config = config
216 | self.model = ActorModel(config)
217 | self.loss_function = torch.nn.CrossEntropyLoss()
218 | self.optimizer = torch.optim.Adam(
219 | self.model.parameters(), lr=config.lr
220 | )
221 | self.validation_flag = False
222 | self.training_stats = TrainingStats()
223 | if not os.path.exists(config.model_folder):
224 | os.mkdir(config.model_folder)
225 | if config.validation_dataset_path is not None:
226 | self.validation_flag = True
227 |
228 | def train(
229 | self,
230 | ) -> None:
231 | print("Start Actor Model Pretraining")
232 | # get config parameters
233 | train_dataset_path = self.config.train_dataset_path
234 | validation_dataset_path = self.config.validation_dataset_path
235 | batch_size = self.config.batch_size
236 | epochs = self.config.epochs
237 | device = self.config.device
238 |
239 | # create dataloaders
240 | train_dataset = ActorDataset(train_dataset_path, device=device)
241 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
242 | if self.validation_flag:
243 | eval_dataset = ActorDataset(validation_dataset_path, device=device)
244 | validation_dataloader = DataLoader(
245 | eval_dataset, batch_size=batch_size
246 | )
247 |
248 | # compute the number of iterations
249 | n_iter = int(len(train_dataset) / batch_size)
250 |
251 |         # training loop
252 | for epoch in range(epochs):
253 | self.model.train()
254 | for i, input_output in enumerate(train_dataloader):
255 | input_output_tokenized = self.model.tokenizer(
256 | input_output,
257 | return_tensors="pt",
258 | padding=True,
259 | truncation=True,
260 | )
261 | training_output = input_output_tokenized["input_ids"][:, 1:]
262 | training_input = input_output_tokenized["input_ids"][:, :-1]
263 | attention_mask = input_output_tokenized["attention_mask"][
264 | :, :-1
265 | ]
266 | training_output = training_output.to(device)
267 | training_input = training_input.to(device)
268 | attention_mask = attention_mask.to(device)
269 |
270 | # forward pass
271 | est_output = self.model.forward(training_input, attention_mask)
272 | est_output = rearrange(est_output, "b s v -> (b s) v")
273 | training_output = rearrange(training_output, "b s -> (b s)")
274 | loss = self.loss_function(est_output, training_output)
275 | self.training_stats.training_loss.append(loss.item())
276 |
277 | # backward pass
278 | self.optimizer.zero_grad()
279 | loss.backward()
280 | self.optimizer.step()
281 |
282 | # print progress
283 | if i % self.config.iteration_per_print == 0:
284 | print(
285 | f"Epoch: {epoch+1}/{epochs}, "
286 | f"Iteration: {i+1}/{n_iter}, "
287 | f"Training Loss: {loss}"
288 | )
289 | if self.validation_flag:
290 | self.model.eval()
291 | for i, input_output in enumerate(validation_dataloader):
292 | input_output_tokenized = self.model.tokenizer(
293 | input_output, return_tensors="pt", padding=True
294 | )
295 | validation_output = input_output_tokenized["input_ids"][
296 | :, 1:
297 | ]
298 | validation_input = input_output_tokenized["input_ids"][
299 | :, :-1
300 | ]
301 | attention_mask = input_output_tokenized["attention_mask"][
302 | :, :-1
303 | ]
304 |
305 | # forward pass
306 | est_output = self.model.forward(
307 | validation_input, attention_mask
308 | )
309 | validation_output = rearrange(
310 | validation_output, "b s -> (b s)"
311 | )
312 | est_output = rearrange(est_output, "b s v -> (b s) v")
313 | loss = self.loss_function(est_output, validation_output)
314 | self.training_stats.validation_loss.append(loss.item())
315 |
316 | # print progress
317 | if i % self.config.iteration_per_print == 0:
318 | print(
319 | f"Epoch: {epoch+1}/{epochs}, "
320 | f"Iteration: {i+1}/{n_iter}, "
321 | f"Validation Loss: {loss}"
322 | )
323 | self.model.save()
324 | print("Training Finished ")
325 |
--------------------------------------------------------------------------------
/chatllama/rlhf/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | from beartype import beartype
7 | from beartype.typing import Optional
8 |
9 |
10 | @dataclass
11 | class ConfigReward:
12 | """Config parameters for the reward model
13 |
14 | Attributes:
15 | model (str): Model to be used for the reward model
16 | model_folder (str): Path to the folder where model are stored (used
17 | to load / store finetuned model)
18 | device (torch.device): Device to be used for the reward model
19 | model_head_hidden_size (int): Hidden size of the reward model head
20 | debug (bool): enable prints for Debugging
21 |         train_dataset_path (Optional[str]): Path to the training dataset.
22 |             Default to None. To be specified only for the reward model training.
23 |         validation_dataset_path (Optional[str]): Path to the validation
24 |             dataset. Default to None. To be specified only for the reward
25 |             model training.
26 |         batch_size (Optional[int]): Batch size to train the reward model.
27 |             Default to None. To be specified only for the reward model
28 |             training.
29 |         epochs (Optional[int]): Number of epochs to train the reward model.
30 |             Default to None. To be specified only for the reward model
31 |             training.
32 |         iteration_per_print (Optional[int]): Number of iterations to print
33 |             the training loss. Default to None. To be specified only for the
34 |             reward model training.
35 |         lr (Optional[float]): Learning rate for the reward model. Default to
36 |             None. To be specified only for the reward model distillation.
37 | llm_model (Optional[str]): Model to be used for the language model
38 | (LLM). Default to None.
39 | llm_max_tokens (Optional[int]): Max tokens for the LLM. Default to
40 | None.
41 | llm_temperature (Optional[float]): Temperature for the LLM. Default
42 | to None.
43 | """
44 |
45 | model: str
46 | model_folder: str
47 | device: torch.device
48 | model_head_hidden_size: int
49 | debug: bool
50 | train_dataset_path: Optional[str] = None
51 | validation_dataset_path: Optional[str] = None
52 | batch_size: Optional[int] = None
53 | epochs: Optional[int] = None
54 | iteration_per_print: Optional[int] = None
55 | lr: Optional[float] = None
56 | llm_model: Optional[str] = None
57 | llm_max_tokens: Optional[int] = None
58 | llm_temperature: Optional[float] = None
59 |
60 |
61 | @dataclass
62 | class ConfigActor:
63 | """Config parameters for models
64 |
65 | Attributes:
66 | model (str): Model to be used for the actor
67 | model_folder (str): Path to the folder where model are stored (used
68 | to load / store finetuned model)
69 | max_tokens (int): Max tokens for the actor
70 | temperature (float): Temperature for the actor
71 | device (torch.device): Device to be used for the actor
72 | lr (float): Learning rate for the actor
73 | iteration_per_print (int): Number of iterations to print the
74 | training loss
75 | batch_size (int): Batch size to train the actor
76 | epochs (int): Number of epochs to train the actor
77 | debug (bool): Enable prints for debugging
78 | train_dataset_path (str): Path to the training dataset
79 | validation_dataset_path (Optional[str]): Path to the validation dataset
80 | """
81 |
82 | model: str
83 | model_folder: str
84 | tokenizer_folder: str
85 | max_tokens: int
86 | temperature: float
87 | device: torch.device
88 | lr: float
89 | iteration_per_print: int
90 | batch_size: int
91 | epochs: int
92 | debug: bool
93 | train_dataset_path: str
94 | validation_dataset_path: Optional[str] = None
95 |
96 |
97 | @dataclass
98 | class ConfigTrainer:
99 | """Config parameters for the trainer, used to configure the reinforcement
100 | learning training loop
101 |
102 | Attributes:
103 | update_timesteps (int): Number of timesteps to update the actor
104 | and critic. Every time update_timesteps timesteps are collected,
105 | the training loop for the actor and critic is executed using the
106 | memory buffer to learn the policy.
107 |         examples_path (str): Path to the JSON file with the prompt
108 |             dataset from which examples are sampled
109 | num_examples (int): Number of examples to generate for the actor
110 | and critic. For each iteration of timestep, num_examples are
111 | sampled from the prompt dataset, processed and stored in the
112 | memory buffer.
113 | actor_lr (float): Learning rate for the actor when training with
114 | reinforcement learning
115 | critic_lr (float): Learning rate for the critic when training with
116 | reinforcement learning
117 |         num_episodes (int): Number of episodes; each episode consists of
118 |             a number of timesteps that are used to generate examples
119 |             stored in the memory buffer.
120 |         max_timesteps (int): Max timesteps for the actor and critic.
121 |             For each timestep, a set of examples is sampled and used to
122 |             generate a completion and a reward.
123 | batch_size (int): Batch size to train the actor and critic.
124 | This batch is used to aggregate the memory from the memory buffer
125 | for the actual training of the actor and critic models.
126 | epochs (int): Number of epochs to train the actor and critic.
127 | actor_eps_clip (float): Epsilon clip for the actor
128 | critic_eps_clip (float): Epsilon clip for the critic
129 | beta_s (float): Beta for the actor and critic
130 | update_checkpoint (int): Number of timesteps to update the checkpoint
131 | llm_model_id (str): Model id for the llm
132 | llm_max_tokens (int): Max tokens for the llm
133 | llm_temperature (float): Temperature for the llm
134 |         device (torch.device): Device to be used for the actor and critic
135 | checkpoint_folder (str): Folder to store the checkpoints while training
136 | debug (bool): Enable prints for debugging
137 | """
138 |
139 | update_timesteps: int
140 | num_examples: int
141 | actor_lr: float
142 | critic_lr: float
143 | num_episodes: int
144 | max_timesteps: int
145 | examples_path: str
146 | batch_size: int
147 | epochs: int
148 | actor_eps_clip: float
149 | critic_eps_clip: float
150 | beta_s: float
151 | update_checkpoint: int
152 | llm_model_id: str
153 | llm_max_tokens: int
154 | llm_temperature: float
155 | device: torch.device
156 | checkpoint_folder: str
157 | debug: bool
158 |
159 |
160 | class Config:
161 | """Store the config parameters for the whole pipeline
162 |
163 | Args:
164 | trainer_dict (Optional[Dict]): Dictionary with the config parameters
165 | for the trainer. Default to None. If None, the config.yaml is
166 | used.
167 | actor_dict (Optional[Dict]): Dictionary with the config parameters
168 | for the actor. Default to None. If None, the config.yaml is
169 | used.
170 | critic_dict (Optional[Dict]): Dictionary with the config parameters
171 | for the critic. Default to None. If None, the config.yaml is
172 | used.
173 | reward_dict (Optional[Dict]): Dictionary with the config parameters
174 | for the reward. Default to None. If None, the config.yaml is
175 | used.
176 | device (Optional[torch.device]): Device to be used for the actor
177 | and critic. Default to None. If None, the device available is
178 | used.
179 | debug (Optional[bool]): Enable prints for debugging. Default to False.
180 |
181 | Attributes:
182 | trainer (ConfigTrainer): Config parameters for the trainer
183 | actor (ConfigActor): Config parameters for the actor
184 |         critic (ConfigReward): Config parameters for the critic
185 | reward (ConfigReward): Config parameters for the reward
186 | """
187 |
188 | @beartype
189 | def __init__(
190 | self,
191 | path: str,
192 | device: Optional[torch.device] = None,
193 | debug: Optional[bool] = False,
194 | ) -> None:
195 |
196 | # if not specified use the device available
197 | if device is None:
198 | device = torch.device(
199 | "cuda" if torch.cuda.is_available() else "cpu"
200 | )
201 | print(f"Current device used:{str(device)}")
202 |
203 | if path is None or os.path.exists(path) is False:
204 | raise ValueError("Path to the config.yaml is not valid")
205 |
206 | # Read the config from yaml
207 | with open(path, "r") as c:
208 | config = yaml.safe_load(c)
209 |
210 | trainer_dict = config["trainer_config"]
211 | actor_dict = config["actor_config"]
212 | critic_dict = config["critic_config"]
213 | reward_dict = config["reward_config"]
214 |
215 | # Trainer Config
216 | trainer_dict["device"] = device
217 | trainer_dict["debug"] = debug
218 | self.trainer = ConfigTrainer(**trainer_dict)
219 | # Actor Config
220 | actor_dict["device"] = device
221 | actor_dict["debug"] = debug
222 | self.actor = ConfigActor(**actor_dict)
223 | # Critic Config
224 | critic_dict["device"] = device
225 | critic_dict["debug"] = debug
226 | self.critic = ConfigReward(**critic_dict)
227 | # Reward Config
228 | reward_dict["device"] = device
229 | reward_dict["debug"] = debug
230 | self.reward = ConfigReward(**reward_dict)
231 |
--------------------------------------------------------------------------------
/chatllama/rlhf/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | trainer_config:
3 | update_timesteps: 1
4 | num_examples: 2
5 | actor_lr: 0.00001
6 | critic_lr: 0.00001
7 | num_episodes: 10
8 | max_timesteps: 10
9 | examples_path: "dataset/sections_dataset.json"
10 | batch_size: 1
11 | epochs: 5
12 | actor_eps_clip: 0.2
13 | critic_eps_clip: 0.2
14 | beta_s: 0.1
15 | update_checkpoint: 10
16 | llm_model_id: "text-davinci-003"
17 | llm_max_tokens: 1024
18 | llm_temperature: 0.5
19 | checkpoint_folder: "./models/checkpoints"
20 |
21 | actor_config:
22 | model: "llama-7B"
23 | max_tokens: 1024
24 | temperature: 0.9
25 | train_dataset_path: "dataset/sections_dataset.json"
26 | validation_dataset_path: null
27 | batch_size: 16
28 | iteration_per_print: 10
29 | lr: 0.00001
30 | epochs: 1
31 |   model_folder: "path-to-checkpoints"
32 |   tokenizer_folder: "path-to-tokenizer"
33 | reward_config:
34 |   # models to choose from: gpt2-large, bart-base, longformer-base-4096
35 | model: "longformer-base-4096"
36 | model_head_hidden_size: 2048
37 | model_folder: "./models"
38 |   train_dataset_path: "dataset/sections_dataset.json"
39 | validation_dataset_path: null
40 | batch_size: 64
41 | epochs: 20
42 | iteration_per_print: 10
43 | lr: 0.0001
44 |
45 | critic_config:
46 |   # models to choose from: gpt2-large, bart-base, longformer-base-4096
47 | model: "longformer-base-4096"
48 | model_head_hidden_size: 2048
49 | model_folder: "./models"
--------------------------------------------------------------------------------
/chatllama/rlhf/reward.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import torch
5 | from beartype import beartype
6 | from beartype.typing import Optional, Iterable
7 | from einops.layers.torch import Rearrange
8 | from langchain import OpenAI, LLMChain, PromptTemplate
9 | from torch.utils.data import Dataset, DataLoader
10 | from transformers import GPT2Tokenizer, GPT2Model, BartModel
11 | from transformers import BartTokenizer, BartConfig, AutoModel, AutoTokenizer
12 |
13 | from chatllama.langchain_modules.prompt_templates import REWARD_TEMPLATE
14 | from chatllama.rlhf.config import ConfigReward
15 | from chatllama.rlhf.utils import TrainingStats
16 |
17 |
18 | class RewardModel(torch.nn.Module):
19 | """Model to be trained to predict the reward for RL.
20 | or to be used as Critic in RL.
21 |
22 | Attributes:
23 | model (torch.nn.Module): Model to be used for the reward model
24 | tokenizer (torch.nn.Module): Tokenizer to be used for the reward model
25 | head (torch.nn.Module): Head to be used for the reward model
26 | config (ConfigReward): Config parameters for the reward model
27 | max_model_tokens (int): Maximum sequence length for the reward model
28 |
29 | Methods:
30 | forward: Forward pass of the model (used by the critic)
31 | save: Save the model
32 | load: Load the model
33 | get_reward: Get the reward for a given input (used by the reward model)
34 | """
35 |
36 | def __init__(self, config: ConfigReward) -> None:
37 | super().__init__()
38 | # load the model -- add here other models
39 | head_hidden_size = config.model_head_hidden_size
40 | if config.model == "gpt2-large":
41 | self.max_model_tokens = 1024
42 | self.model = GPT2Model.from_pretrained("gpt2-large")
43 | self.tokenizer = GPT2Tokenizer.from_pretrained(
44 | "gpt2-large",
45 | padding_side="left",
46 | truncation_side="left",
47 | model_max_length=self.max_model_tokens,
48 | )
49 | self.tokenizer.pad_token = self.tokenizer.eos_token
50 | self.head = torch.nn.Sequential(
51 | torch.nn.Linear(self.model.config.n_embd, head_hidden_size),
52 | torch.nn.ReLU(),
53 | torch.nn.Linear(head_hidden_size, 1),
54 | Rearrange("... 1 -> ..."),
55 | )
56 | elif config.model == "bart-base":
57 | self.max_model_tokens = 1024
58 | bart_config = BartConfig.from_pretrained("facebook/bart-base")
59 | bart_config.max_position_embeddings = 2048 + 1024
60 | self.model = BartModel(bart_config)
61 | self.tokenizer = BartTokenizer.from_pretrained(
62 | "facebook/bart-large",
63 | padding_side="left",
64 | truncation_side="left",
65 | model_max_length=self.max_model_tokens,
66 | )
67 | self.tokenizer.pad_token = self.tokenizer.eos_token
68 | self.head = torch.nn.Sequential(
69 | torch.nn.Linear(768, head_hidden_size),
70 | torch.nn.ReLU(),
71 | torch.nn.Linear(head_hidden_size, 1),
72 | Rearrange("... 1 -> ..."),
73 | )
74 | elif config.model == "longformer-base-4096":
75 | self.max_model_tokens = 4096
76 | self.model = AutoModel.from_pretrained(
77 | "allenai/longformer-base-4096"
78 | )
79 | self.tokenizer = AutoTokenizer.from_pretrained(
80 | "allenai/longformer-base-4096",
81 | padding_side="left",
82 | truncation_side="left",
83 | model_max_length=self.max_model_tokens,
84 | )
85 | self.tokenizer.eos_token = self.tokenizer.pad_token
86 | self.head = torch.nn.Sequential(
87 | torch.nn.Linear(768, head_hidden_size),
88 | torch.nn.ReLU(),
89 | torch.nn.Linear(head_hidden_size, 1),
90 | Rearrange("... 1 -> ..."),
91 | )
92 | else:
93 | raise ValueError(f"model {config.model} not supported")
94 | # store config
95 | self.config = config
96 | if os.path.exists(config.model_folder) is False:
97 | os.mkdir(config.model_folder)
98 | else:
99 | self.load()
100 | # freeze model parameters (only train the head)
101 | for param in self.model.parameters():
102 | param.requires_grad = False
103 | # move model to device
104 | self.model.to(config.device)
105 | self.head.to(config.device)
106 |
107 | @beartype
108 | def parameters(
109 | self,
110 | ) -> Iterable[torch.nn.Parameter]:
111 | """Return the parameters of the reward model"""
112 | for p in self.model.parameters():
113 | yield p
114 | for p in self.head.parameters():
115 | yield p
116 |
117 | @beartype
118 | def forward(
119 | self, output_sequence: torch.Tensor, output_sequence_mask: torch.Tensor
120 | ) -> torch.Tensor:
121 | """Generate the sequence of rewards for the given output sequence
122 | what is the quality of the output sequence tokens?
123 |
124 | Args:
125 | output_sequence (torch.Tensor): The sequence of tokens to be
126 | evaluated
127 | output_sequence_mask (torch.Tensor): Mask for the attention
128 |
129 | Returns:
130 | torch.Tensor: Rewards for the given output sequence
131 | """
132 | output = self.model(
133 | output_sequence, attention_mask=output_sequence_mask
134 | )
135 | # What if the output_sequence is longer than the max context of
136 | # the model?
137 | rewards = self.head(output.last_hidden_state)
138 | if self.config.debug:
139 | print("RewardModel.forward")
140 | print("output_sequence.shape", output_sequence.shape)
141 | print("output_sequence", output_sequence)
142 | print("reward.shape", rewards.shape)
143 | print("reward", rewards)
144 | return rewards
145 |
146 | @beartype
147 | def get_reward(
148 | self, output_sequence: torch.Tensor, output_sequence_mask: torch.Tensor
149 | ) -> torch.Tensor:
150 | """Get the reward for the given output sequence
151 |
152 | Args:
153 | output_sequence (torch.Tensor): The concatenation of initial input
154 | and actor output as tokens
155 | output_sequence_mask (torch.Tensor): Mask for the attention
156 | """
157 | rewards = self.forward(output_sequence, output_sequence_mask)
158 | return rewards[:, -1]
159 |
160 | @beartype
161 | def load(self, path: Optional[str] = None) -> None:
162 | """Load the model from the path
163 |
164 | Args:
165 | path (str): path to the model
166 | """
167 | if path is None:
168 | path = self.config.model_folder + "/" + self.config.model + ".pt"
169 | if os.path.exists(self.config.model_folder) is False:
170 | os.makedirs(self.config.model_folder)
171 | print(
172 | f"Model folder does not exist. Creating it,"
173 | f"and returning without loading the model:\n{path}"
174 | )
175 | return
176 | # load the model and the tokenizer
177 | if os.path.exists(path) is False:
178 | print(
179 | f"Impossible to load the model:\n{path}\n"
180 | f"The path doesn't exist."
181 | )
182 | return
183 | model_dict = torch.load(path)
184 | self.model.load_state_dict(model_dict["model"])
185 | self.head.load_state_dict(model_dict["head"])
186 |
187 | @beartype
188 | def save(self, path: Optional[str] = None) -> None:
189 | """Save the model to the path
190 |
191 | Args:
192 | path (Optional[str], optional): Path to store the model.
193 | Defaults to None.
194 | """
195 | if path is None:
196 | path = self.config.model_folder + "/" + self.config.model + ".pt"
197 | if os.path.exists(self.config.model_folder) is False:
198 | os.makedirs(self.config.model_folder)
199 | torch.save(
200 | {"model": self.model.state_dict(), "head": self.head.state_dict()},
201 | path,
202 | )
203 |
204 |
205 | # alias to keep naming consistent
206 | CriticModel = RewardModel
207 |
208 |
209 | class RewardDataset(Dataset):
210 | """Dataset class for the reward model
211 | read a json file with the following format:
212 | [
213 | {
214 | "user_input": "...",
215 | "completion": "...",
216 | "score": ...
217 | },
218 | ...
219 | ]
220 | Where:
221 | user_input: the initial input of the user
222 | completion: the completion generated by the model
223 | score: the score given by the user to the completion (or by the LLM)
224 | """
225 |
226 | def __init__(self, path: str) -> None:
227 | print(f"Loading dataset from {path}")
228 | with open(path, "r") as f:
229 | self.data = list(json.load(f))
230 | print(f"Loaded {len(self.data)} samples")
231 |
232 | def __getitem__(self, idx: int):
233 | user_input = self.data[idx]["user_input"]
234 | completion = self.data[idx]["completion"]
235 | item = tuple([user_input, completion])
236 | return item
237 |
238 | def __len__(
239 | self,
240 | ):
241 | return len(self.data)
242 |
243 |
244 | class RewardTrainer:
245 | """Reward class to train the reward model
246 |
247 | Args:
248 | config (ConfigModel): Config parameters for the model
249 |
250 | Attributes:
251 | model (RewardModel): Reward model
252 | config (ConfigModel): Config parameters for the model
253 | optimizer (torch.optim): Optimizer for the model
254 | loss (torch.nn): Loss function for the model
255 |
256 | Methods:
257 | train: Train the reward model
258 |         generate_user_input: Generate the user input for the LLM to
259 |             evaluate a (user_input, completion) pair and assign a score
260 | distillate: Parse the dataset and assign scores using LLMs
261 | """
262 |
263 | def __init__(self, config: ConfigReward) -> None:
264 | self.model = RewardModel(config)
265 | self.config = config
266 | self.optimizer = torch.optim.Adam(
267 | self.model.parameters(), lr=config.lr
268 | )
269 | self.loss_function = torch.nn.MSELoss()
270 | if not os.path.exists("./models"):
271 | os.mkdir("./models")
272 | self.training_stats = TrainingStats()
273 | self.validation_flag = False
274 | if config.validation_dataset_path is not None:
275 | self.validation_flag = True
276 |
277 | openai_llm = OpenAI(
278 | model_name=self.config.llm_model,
279 | temperature=self.config.llm_temperature,
280 | max_tokens=self.config.llm_max_tokens,
281 | )
282 | prompt_template = PromptTemplate(**REWARD_TEMPLATE)
283 | self.llm = LLMChain(llm=openai_llm, prompt=prompt_template)
284 |
285 | def distillate(
286 | self,
287 | ):
288 | """Parse the dataset and assign scores using LLMs
289 | then save back the dataset with the uploaded scores
290 | """
291 | # load the dataset
292 | with open(self.config.train_dataset_path, "r") as f:
293 | train_data = json.load(f)
294 |         # for each element of the dataset, assign a score
295 | for i, data in enumerate(train_data):
296 | if data.get("score", None) is None:
297 | prompt_tokens = (
298 | data["user_input"]
299 | + data["completion"]
300 | + self.llm.prompt.template
301 | )
302 | prompt_len = int(len(prompt_tokens.split(" ")) / 0.75)
303 | # 80% of the max length as safety margin
304 | if prompt_len > self.config.llm_max_tokens * 0.8:
305 | print(
306 | f"The prompt of the data {i} is too long\n"
307 | f"tokens: {prompt_len}\n"
308 | f"max_tokens: {self.config.llm_max_tokens * 0.8}"
309 | )
310 | continue
311 | score = self.llm.run(
312 | user_input=data["user_input"],
313 | completion=data["completion"],
314 | ).strip()
315 | # TODO extract from score the float value with a regex
316 | score = score.split(" ")[0]
317 | try:
318 | score = float(score)
319 | except Exception:
320 | print(
321 |                         f"The score returned by the LLM for the "
322 |                         f"data {i} is not a float:\n{score}"
323 | )
324 | continue
325 | data["score"] = score
326 | # save the dataset back
327 | with open(self.config.train_dataset_path, "w") as f:
328 | json.dump(train_data, f)
329 |
330 | def train(
331 | self,
332 | ) -> None:
333 | """Train the reward model"""
334 | print("Start Training the Reward Model")
335 | # get config parameters
336 | train_dataset_path = self.config.train_dataset_path
337 | validation_dataset_path = self.config.validation_dataset_path
338 | batch_size = self.config.batch_size
339 | epochs = self.config.epochs
340 | device = self.config.device
341 |
342 | # create dataloaders
343 | train_dataset = RewardDataset(train_dataset_path)
344 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
345 | if self.validation_flag:
346 | eval_dataset = RewardDataset(validation_dataset_path)
347 | validation_dataloader = DataLoader(
348 | eval_dataset, batch_size=batch_size
349 | )
350 | iteration_per_print = self.config.iteration_per_print
351 |
352 | # compute the number of iterations
353 | n_iter = int(len(train_dataset) / batch_size)
354 |
355 |         # training loop
356 | for epoch in range(epochs):
357 | self.model.train()
358 | for i, inputs in enumerate(train_dataloader):
359 |
360 | input_text = inputs["user_input"] + inputs["completion"]
361 | # tokenizer (placed here instead of dataset class)
362 | input_tokens = self.model.tokenizer(
363 | input_text, padding=True, truncation=True
364 | )
365 |
366 | score = None # TODO: load the score
367 |
368 | # TODO: check on the length of the input tokens if they are
369 | # too many it can create problems
370 | output = torch.tensor(score, dtype=torch.float32).to(device)
371 |
372 | # forward pass
373 | est_output = self.model.get_reward(
374 | input_tokens["input_ids"].to(device),
375 | input_tokens["attention_mask"].to(device),
376 | )
377 |
378 | loss = self.loss_function(est_output, output)
379 | self.training_stats.training_loss.append(loss.item())
380 |
381 | # backward pass
382 | self.optimizer.zero_grad()
383 | loss.backward()
384 | self.optimizer.step()
385 |
386 | # print progress
387 | if i % iteration_per_print == 0:
388 | print(
389 | f"Epoch: {epoch+1}/{epochs}, "
390 | f"Iteration: {i+1}/{n_iter}, "
391 | f"Training Loss: {loss.item()}"
392 | )
393 | print(
394 | "prediction",
395 | est_output.cpu().detach().numpy(),
396 | "target",
397 | score.cpu().numpy(),
398 | )
399 | if self.validation_flag:
400 | self.model.eval()
401 | for i, (text, score) in enumerate(validation_dataloader):
402 | # forward pass
403 | input_tokens = self.model.tokenizer(
404 | text, return_tensors="pt", padding=True
405 | )
406 | input_tokens = input_tokens.to(device)
407 | # TODO: check on the length of the input tokens if they are
408 | # too many it can create problems
409 | output = torch.tensor(score, dtype=torch.float32).to(
410 | device
411 | )
412 | est_output = self.model.get_reward(
413 | input_tokens["input_ids"],
414 | input_tokens["attention_mask"],
415 | )
416 | loss = self.loss_function(est_output, output)
417 | self.training_stats.validation_loss.append(loss.item())
418 |
419 | # print progress
420 | if i % iteration_per_print == 0:
421 | print(
422 | f"Epoch: {epoch+1}/{epochs}, "
423 | f"Iteration: {i+1}/{n_iter}, "
424 | f"Validation Loss: {loss.item()}"
425 | )
426 | self.model.save()
427 |
--------------------------------------------------------------------------------
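A note on the `distillate` step above: `RewardTrainer` builds its scoring chain from `config.llm_model`, `config.llm_max_tokens` and `config.llm_temperature`, which the shipped `config.yaml` does not set under `reward_config`, and LangChain's `OpenAI` wrapper additionally expects an `OPENAI_API_KEY` in the environment. A possible `reward_config` addition (values mirror the ones already used in `trainer_config` and are only illustrative):

```yaml
reward_config:
  # existing keys unchanged ...
  llm_model: "text-davinci-003"
  llm_max_tokens: 1024
  llm_temperature: 0.5
```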
/chatllama/rlhf/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from chatllama.rlhf.actor import ActorTrainer
4 | from chatllama.rlhf.config import Config
5 | from chatllama.rlhf.trainer import RLTrainer
6 | from chatllama.rlhf.reward import RewardTrainer
7 |
8 |
9 | def test_actor_training(path=None, device=None, debug=False):
10 | config = Config(path=path, device=device, debug=debug)
11 | trainer = ActorTrainer(config.actor)
12 | trainer.train()
13 | trainer.training_stats.plot()
14 |
15 |
16 | def test_reward_training(path=None, device=None, debug=False):
17 | device = torch.device("cuda:0")
18 | config = Config(path=path, device=device, debug=debug)
19 | trainer = RewardTrainer(config.reward)
20 | trainer.train()
21 | trainer.training_stats.plot()
22 |
23 |
24 | def test_rl_training(path=None, device=None, debug=False):
25 | device = torch.device("cuda:0")
26 | config = Config(path=path, device=device, debug=debug)
27 | trainer = RLTrainer(config.trainer)
28 | trainer.distillate()
29 | trainer.train()
30 | trainer.training_stats.plot()
31 |
32 |
33 | if __name__ == "__main__":
34 | reward_training = True
35 | rl_training = False
36 | actor_training = False
37 |
38 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39 | # place here the path to the config.yaml file
40 |     config_path = "chatllama/rlhf/config.yaml"
41 |
42 | if reward_training:
43 | test_reward_training(path=config_path, device=device)
44 | if rl_training:
45 |         test_rl_training(path=config_path, device=device)
46 | if actor_training:
47 | test_actor_training(path=config_path, device=device)
48 |
--------------------------------------------------------------------------------
/chatllama/rlhf/trainer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | from collections import deque, namedtuple
5 |
6 | import torch
7 | from beartype import beartype
8 | from beartype.typing import Deque, Tuple, List
9 | from einops import rearrange
10 | from torch.utils.data import Dataset, DataLoader
11 |
12 | from chatllama.rlhf.actor import ActorModel
13 | from chatllama.rlhf.reward import RewardModel, CriticModel
14 | from chatllama.rlhf.config import ConfigReward, ConfigActor, Config
15 | from chatllama.rlhf.utils import TrainingStats, ConversationLog
16 |
17 |
18 | class ActorCritic(torch.nn.Module):
19 |     """Actor Critic class that stores both the actor and the critic models
20 |     and generates values and actions for given sequences during the training
21 |     of the actor.
22 |
23 | Attributes:
24 | actor (ActorModel): Actor model
25 | critic (CriticModel): Critic model
26 | debug (bool): enable prints for Debugging
27 |
28 | Methods:
29 | forward: given a sequence returns action logits and values (used
30 | to evaluate the actor during training)
31 | generate: given a sequence returns action, action logits, values
32 | sequences and sequences masks (used to generate new sequences
33 | during acting phase)
34 | """
35 |
36 | def __init__(
37 | self, actor_config: ConfigActor, critic_config: ConfigReward
38 | ) -> None:
39 | super().__init__()
40 | self.actor = ActorModel(actor_config)
41 | self.critic = CriticModel(critic_config)
42 | self.debug = actor_config.debug
43 |
44 | @beartype
45 | def forward(
46 | self,
47 | sequences: torch.Tensor,
48 | sequences_mask: torch.Tensor,
49 | action_len: int,
50 | ) -> Tuple:
51 | """Given the whole sequences, use the actor forward to get the logits
52 | for each token in the sequence and the critic forward to get the
53 | values for each generation step.
54 |
55 | Args:
56 | sequences (torch.Tensor): Sequences composed of [states, actions]
57 |             sequences_mask (torch.Tensor): Mask for the sequences
58 |             action_len (int): Length of the actions in the sequences
59 |
60 | Returns:
61 | action_logits (torch.Tensor): Logits for the actions in the
62 | sequences
63 | values (torch.Tensor): Values for the actions in the sequences
64 | """
65 | # use a single forward on the whole sequence
66 | # to get pi(y | x) and ignore predicted output
67 | actions_logits = self.actor.forward(sequences, sequences_mask)
68 | values = self.critic.forward(sequences, sequences_mask)
69 |
70 | # return only logits and values for the actions taken
71 | real_actions_logits = actions_logits[:, -action_len:, :]
72 | real_values = values[:, -action_len:]
73 |
74 | if self.debug:
75 | print("ActorCritic.forward")
76 | print("action_len", action_len)
77 | print("sequences.shape", sequences.shape)
78 | print("sequences", sequences)
79 |             print("real_actions_logits.shape", real_actions_logits.shape)
80 |             print("real_actions_logits", real_actions_logits)
81 |             print("real_values.shape", real_values.shape)
82 |             print("real_values", real_values)
83 |
84 | return (
85 | real_actions_logits,
86 | real_values,
87 | )
88 |
89 | @torch.no_grad()
90 | @beartype
91 | def generate(
92 | self, states: torch.Tensor, state_mask: torch.Tensor
93 | ) -> Tuple:
94 | """Generate actions, actions_logits, values and sequences from states
95 |
96 | Args:
97 | states (torch.Tensor): user inputs
98 | state_mask (torch.Tensor): Mask for the states of the environment
99 |
100 | Returns:
101 | actions (torch.Tensor): Actions generated from the states
102 | actions_logits (torch.Tensor): Logits for the actions generated
103 | from the states (i.e. pi(y | x))
104 | values (torch.Tensor): Values generated by the critic model
105 | for the actions generated by the actor (i.e. V(x))
106 | sequences (torch.Tensor): Sequences generated from the states
107 | as [states, actions]
108 | """
109 | # generate action sequence
110 | actions, sequence = self.actor.generate(states, state_mask)
111 | sequences_mask = sequence != self.actor.tokenizer.pad_token_id
112 | action_len = actions.shape[1]
113 |
114 | # generate actions_logits and values
115 | actions_logits, values = self.forward(
116 | sequence, sequences_mask, action_len
117 | )
118 | if self.debug:
119 | print("ActorCritic.generate")
120 | print("actions shape", actions.shape)
121 | print("actions", actions)
122 | print("sequence shape", sequence.shape)
123 | print("sequence", sequence)
124 | print("actions_logits shape", actions_logits.shape)
125 | print("actions_logits", actions_logits)
126 | print("values shape", values.shape)
127 | print("values", values)
128 |
129 | return actions, actions_logits, values, sequence, sequences_mask
130 |
131 |
132 | # structure to store the data for each experience
133 | Memory = namedtuple(
134 | "Memory",
135 | [
136 | "states",
137 | "actions",
138 | "sequences",
139 | "values",
140 | "rewards",
141 | "actions_log_probs",
142 | "sequences_mask",
143 | ],
144 | )
145 |
146 |
147 | class ExperienceDataset(Dataset):
148 | """Dataset to train the actor-critic models"""
149 |
150 | def __init__(
151 | self,
152 | memories: Deque[Memory],
153 | device: torch.device,
154 | ) -> None:
155 | super().__init__()
156 | self.data = list(memories)
157 | self.device = device
158 |
159 | def __len__(
160 | self,
161 | ) -> int:
162 | return len(self.data)
163 |
164 | def __getitem__(self, idx) -> Tuple:
165 | # return the idx-th memory element as a tuple of tensors on the device
166 | item = (
167 | self.data[idx].states.to(self.device),
168 | self.data[idx].actions.to(self.device),
169 | self.data[idx].sequences.to(self.device),
170 | self.data[idx].values.to(self.device),
171 | self.data[idx].rewards.to(self.device),
172 | self.data[idx].actions_log_probs.to(self.device),
173 | self.data[idx].sequences_mask.to(self.device),
174 | )
175 | return item
176 |
177 |
178 | class ExamplesSampler:
179 |     """Store the prompts to be sampled when generating the examples.
180 |     Reads a json file with the following format:
181 | [
182 | {
183 | "user_input" : "",
184 | } ,
185 | ...
186 | ]
187 | Where:
188 |         user_input: the input of the user, or the user input with the
189 |         conversation memory prepended (i.e. user_input + memory)
190 | """
191 |
192 | def __init__(
193 | self,
194 | path: str,
195 | ) -> None:
196 | self.path = path
197 | with open(path, "r") as f:
198 | self.data = json.load(f)
199 |
200 |     def sample(self, n: int) -> List:
201 |         """Sample n examples and return their user_input strings
202 |
203 |         Args:
204 |             n (int): Number of examples to sample
205 |         """
206 |         return [d["user_input"] for d in random.sample(self.data, n)]
207 |
208 |
209 | class RLTrainer:
210 | """Train the actor-critic model using RL
211 |
212 | Attributes:
213 | config (Config): Configuration of the trainer
214 | debug (bool): Debug mode
215 | actorcritic (ActorCritic): Actor-critic model
216 | actor_optim (torch.optim): Optimizer for the actor
217 | critic_optim (torch.optim): Optimizer for the critic
218 | reward (RewardModel): Reward model
219 | training_stats (TrainingStats): Class to store training stats
220 | Methods:
221 | train: the training loop that calls the learn function after generating
222 | the experiences.
223 | learn: Learn from a batch of experiences and update the actor and the
224 | critic model.
225 |         load_checkpoint: Load the latest checkpoint of the actor-critic
226 |             model and return the next episode to train
227 |         save_checkpoint: Save the checkpoint of the actor-critic model
228 | """
229 |
230 | def __init__(
231 | self,
232 | config: Config,
233 | ) -> None:
234 | self.config = config
235 | self.debug = config.trainer.debug
236 |
237 |         # initialize the actor-critic
238 | self.actorcritic = ActorCritic(config.actor, config.critic)
239 | self.actor_optim = torch.optim.Adam(
240 | self.actorcritic.actor.parameters(), lr=config.trainer.actor_lr
241 | )
242 | self.critic_optim = torch.optim.Adam(
243 | self.actorcritic.critic.parameters(), lr=config.trainer.critic_lr
244 | )
245 |
246 | # initialize reward model
247 | self.reward = RewardModel(config.reward)
248 |
249 | # initialize class to store training stats
250 | self.training_stats = TrainingStats()
251 | self.conversation_log = ConversationLog()
252 |
253 | # initialize examples sampler
254 | self.example_sampler = ExamplesSampler(config.trainer.examples_path)
255 |
256 | # eps
257 | self.eps = 1e-8
258 |
259 | # make models directory
260 | if not os.path.exists("./models"):
261 | os.mkdir("./models")
262 |
263 | if not os.path.exists(self.config.trainer.checkpoint_folder):
264 | os.mkdir(self.config.trainer.checkpoint_folder)
265 |
266 | def save_checkpoint(
267 | self,
268 | current_episode: int,
269 | ) -> None:
270 | print(f"Saving checkpoint for episode {current_episode+1}..")
271 | file_name = "rltraining_" + str(current_episode) + ".pt"
272 | checkpoint_folder = self.config.trainer.checkpoint_folder
273 |         if not os.path.exists(checkpoint_folder):
274 | os.mkdir(checkpoint_folder)
275 | path = checkpoint_folder + "/" + file_name
276 | torch.save(
277 | {
278 | "episode": current_episode,
279 | "actor_state_dict": self.actorcritic.actor.state_dict(),
280 | "critic_state_dict": self.actorcritic.critic.state_dict(),
281 | "actor_optim_state_dict": self.actor_optim.state_dict(),
282 | "critic_optim_state_dict": self.critic_optim.state_dict(),
283 | "training_stats": self.training_stats,
284 | },
285 | path,
286 | )
287 |
288 | def load_checkpoint(
289 | self,
290 | ) -> int:
291 |         # get all the file names in the checkpoint folder and take the
292 |         # one with the highest episode number
293 | checkpoint_folder = self.config.trainer.checkpoint_folder
294 |         if not os.path.exists(checkpoint_folder):
295 | os.mkdir(checkpoint_folder)
296 | print(
297 | f"Checkpoint folder {checkpoint_folder} does not exist.\n"
298 | f"No checkpoint will be loaded."
299 | )
300 |             return 0  # no checkpoint to resume from, start from episode 0
301 | files = os.listdir(checkpoint_folder)
302 | episodes = [int(f.split("_")[1].split(".")[0]) for f in files]
303 | if len(episodes) == 0:
304 | return 0
305 | max_episode = max(episodes)
306 | print(f"Loading checkpoint for episode {max_episode+1}..")
307 | file_name = "rltraining_" + str(max_episode) + ".pt"
308 | path = checkpoint_folder + "/" + file_name
309 | checkpoint = torch.load(path)
310 | self.actorcritic.actor.load_state_dict(checkpoint["actor_state_dict"])
311 | self.actorcritic.critic.load_state_dict(
312 | checkpoint["critic_state_dict"]
313 | )
314 | self.actor_optim.load_state_dict(checkpoint["actor_optim_state_dict"])
315 | self.critic_optim.load_state_dict(
316 | checkpoint["critic_optim_state_dict"]
317 | )
318 |         self.training_stats = checkpoint["training_stats"]
319 | self.actorcritic.actor.to(self.config.trainer.device)
320 | self.actorcritic.critic.to(self.config.trainer.device)
321 | return max_episode + 1 # return the next episode to train
322 |
323 | @beartype
324 | def learn(self, memories: Deque[Memory]) -> None:
325 |         """Train the actor-critic model using RL:
326 | - for each batch of episodes, compute action logits and values
327 | - then compare action logits probs with memories one and values with
328 | rewards to compute the PPO loss and update the actor-critic model
329 | """
330 | print("Start to Learn...")
331 |
332 | # get parameters
333 | epochs = self.config.trainer.epochs
334 | actor_eps_clip = self.config.trainer.actor_eps_clip
335 | critic_eps_clip = self.config.trainer.critic_eps_clip
336 | beta_s = self.config.trainer.beta_s
337 | batch_size = self.config.trainer.batch_size
338 | device = self.config.trainer.device
339 |
340 | # create dataset from memories
341 | dataloader = DataLoader(
342 | ExperienceDataset(memories, device), batch_size=batch_size
343 | )
344 |
345 |         # train the actor-critic
346 | self.actorcritic.train()
347 | for epoch in range(epochs):
348 | for i, (
349 | states,
350 | old_actions,
351 | sequences,
352 | old_values,
353 | rewards,
354 | old_actions_log_probs,
355 | sequences_mask,
356 | ) in enumerate(dataloader):
357 |
358 | # print
359 | print(
360 | "Epoch",
361 | epoch + 1,
362 | "of",
363 | epochs,
364 | "Data",
365 | i + 1,
366 | "of",
367 |                     len(dataloader),
368 | )
369 |
370 | if self.debug:
371 | print("RLTrainer.learn()")
372 | print("memory states shapes are: ")
373 | print("states shape", states.shape)
374 | print("old_actions shape", old_actions.shape)
375 | print("sequences shape", sequences.shape)
376 | print("old_values shape", old_values.shape)
377 | print("rewards shape", rewards.shape)
378 | print(
379 | "old_actions_log_probs shape",
380 | old_actions_log_probs.shape,
381 | )
382 |                 # reshape rewards to [b, 1] so they broadcast over values
383 | rewards = rearrange(rewards, "b -> b 1")
384 |
385 | # get actions len
386 | actions_len = old_actions.shape[-1]
387 |
388 | # get actor critic forward
389 | actions_logits, values = self.actorcritic.forward(
390 | sequences, sequences_mask, actions_len
391 | )
392 |
393 | # get action log prob
394 | actions_prob = (
395 | torch.softmax(actions_logits, dim=-1).max(dim=-1).values
396 | )
397 | actions_log_prob = torch.log(actions_prob + self.eps)
398 |
399 |                 # compute (approximate) entropy: -sum(p * log(p))
400 |                 entropies = -(actions_prob * actions_log_prob).sum(dim=-1)
401 |
402 | # compute KL divergence
403 | kl_div_loss = (
404 | (actions_prob * (old_actions_log_probs - actions_log_prob))
405 | .sum(dim=-1)
406 | .mean()
407 | )
408 |
409 |                 # compute PPO Loss -- note that the shapes differ here:
410 |                 # the [b, 1] rewards are broadcast directly against the
411 |                 # per-token values and probabilities
412 | ratios = (actions_log_prob - old_actions_log_probs).exp()
413 | advantages = rewards - old_values
414 |                 # normalize advantages (per-sample mean, global std)
415 |                 advantages = (
416 |                     advantages - advantages.mean(dim=-1, keepdim=True)
417 |                 ) / (advantages.std() + self.eps)
418 | surr1 = advantages * ratios
419 | surr2 = (
420 | torch.clamp(ratios, 1 - actor_eps_clip, 1 + actor_eps_clip)
421 | * advantages
422 | )
423 | policy_loss = -torch.min(surr1, surr2) - beta_s * entropies
424 | policy_loss = policy_loss.mean()
425 | loss = policy_loss + kl_div_loss
426 | # check if loss item is nan
427 | if torch.isnan(loss):
428 | raise ValueError("Loss is nan")
429 | print("loss", loss.item())
430 |
431 | if self.debug:
432 | print("values", values)
433 | print("old_values", old_values)
434 | print("rewards", rewards)
435 | print("ratios", ratios)
436 | print("advantages", advantages)
437 | print("entropies", entropies)
438 |
439 | # update actor with loss
440 | self.actor_optim.zero_grad()
441 | loss.backward()
442 | self.actor_optim.step()
443 |
444 | torch.cuda.synchronize(device)
445 |
446 | # compute value loss
447 | value_loss_clipped = old_values + (values - old_values).clamp(
448 | -critic_eps_clip, critic_eps_clip
449 | )
450 | value_loss1 = (value_loss_clipped - rewards) ** 2
451 | value_loss2 = (values - rewards) ** 2
452 | value_loss = torch.max(value_loss1, value_loss2).mean()
453 | if torch.isnan(value_loss):
454 | raise ValueError("Value loss is nan")
455 | print("value_loss", value_loss.item())
456 |
457 |                 # update the critic with the value loss
458 | self.critic_optim.zero_grad()
459 | value_loss.backward()
460 | self.critic_optim.step()
461 |
462 | self.training_stats.training_loss.append(
463 | loss.detach().cpu().item()
464 | )
465 | self.training_stats.value_loss.append(
466 | value_loss.detach().cpu().item()
467 | )
468 |
469 | self.actorcritic.eval()
470 | print("End Learning")
471 |
472 | def train(
473 | self,
474 | ) -> None:
475 | # initialize settings
476 | num_episodes = self.config.trainer.num_episodes
477 | max_timesteps = self.config.trainer.max_timesteps
478 | num_examples = self.config.trainer.num_examples
479 | update_timesteps = self.config.trainer.update_timesteps
480 | batch_size = self.config.trainer.batch_size
481 | update_checkpoint = self.config.trainer.update_checkpoint
482 | device = self.config.trainer.device
483 |
484 | print("Start RL Training")
485 | # check dimensions consistency
486 | # at each time step num_examples memories are generated
487 | number_of_memories_per_learn_iteration = (
488 | num_examples * update_timesteps
489 | )
490 | # the number of memories must be a multiple of the batch size
491 | assert (
492 | number_of_memories_per_learn_iteration % batch_size == 0
493 | ), "The number of memories must be a multiple of the batch size"
494 | # the total number of timesteps is
495 | total_number_of_timesteps = num_episodes * max_timesteps
496 |         # the total number of timesteps must be a multiple
497 |         # of update_timesteps
498 | assert total_number_of_timesteps % update_timesteps == 0, (
499 |             "The number of timesteps (num_episodes*max_timesteps) "
500 | "must be a multiple of the update_timesteps"
501 | )
502 |
503 | # initialize memories
504 | memories = deque([])
505 |
506 | # loop over episodes and timesteps
507 | current_time = 0
508 | checkpoint_counter = 0
509 | current_episode = self.load_checkpoint()
510 | current_learn_counter = 0
511 |
512 | self.actorcritic.eval()
513 | for eps in range(current_episode, num_episodes):
514 | for timestep in range(max_timesteps):
515 |
516 | print(
517 | f"Episode: {eps + 1} of {num_episodes}, "
518 | f"Timestep: {timestep + 1} of {max_timesteps}",
519 | )
520 |
521 | # counter used to count timesteps into memory
522 | current_time += 1
523 |
524 | # sample num_examples examples from example dataset
525 | inputs = self.example_sampler.sample(num_examples)
526 |
527 | # tokenize examples
528 | tokenized_inputs = self.actorcritic.actor.tokenizer(
529 | inputs, padding=True, return_tensors="pt"
530 | )
531 | if self.debug:
532 | print("RLTrainer.train()")
533 | print("tokenized inputs", tokenized_inputs)
534 | # states are [batch_size, seq_len_of_states]
535 | states = tokenized_inputs["input_ids"].to(device)
536 | states_mask = tokenized_inputs["attention_mask"].to(device)
537 |
538 | # generate prompts
539 | # actions --> output produced by the actor head in response
540 | # of the state(input) [batch_size, len_of_actions]
541 | # actions_logits --> logits of the actions
542 | # [batch_size, len_of_actions, vocab_size]
543 | # values --> output produced by the critic for each action
544 | # [batch_size, len_of_actions]
545 | # sequence --> (state, actions)
546 | # [batch_size, len_of_actions + seq_len_of_states] =
547 | # [batch_size, seq_len]
548 | (
549 | actions,
550 | actions_logits,
551 | values,
552 | sequences,
553 | sequences_mask,
554 | ) = self.actorcritic.generate(states, states_mask)
555 |
556 | # from action logits to action log probs
557 | action_prob = (
558 | torch.softmax(actions_logits, dim=-1).max(dim=-1).values
559 | )
560 | actions_log_probs = torch.log(action_prob + self.eps)
561 |
562 | completions = [
563 | self.actorcritic.actor.tokenizer.decode(action)
564 | for i, action in enumerate(actions)
565 | ]
566 | if self.debug:
567 | print("RLTrainer.train()")
568 | print("completions:")
569 | for i, completion in enumerate(completions):
570 | print(i, completion)
571 | print("")
572 |
573 | # compute reward for the completion
574 |                 # the reward must take into account the quality of the
575 |                 # answer with respect to the initial request, so prompt
576 |                 # and completion are tokenized again for the reward model
577 | task_responses = []
578 | for input, completion in zip(inputs, completions):
579 | task_response = input + "\n" + completion
580 | task_responses.append(task_response)
581 | if self.debug:
582 | print("RLTrainer.train()")
583 | print("task_responses:")
584 | for i, task_response in enumerate(task_responses):
585 | print(i, task_response)
586 | print("")
587 | tokenized_responses = self.reward.tokenizer(
588 | task_responses, padding=True, return_tensors="pt"
589 | )
590 | rewards = self.reward.get_reward(
591 | tokenized_responses["input_ids"].to(device),
592 | tokenized_responses["attention_mask"].to(device),
593 | )
594 |
595 | # store memories of the episode / timestep
596 | for i in range(states.shape[0]):
597 | memories.append(
598 | Memory(
599 | *map(
600 | lambda x: x.detach().cpu(),
601 | (
602 | states[i, :],
603 | actions[i, :],
604 | sequences[i, :],
605 | values[i, :],
606 | rewards[i],
607 | actions_log_probs[i, :],
608 | sequences_mask[i, :],
609 | ),
610 | )
611 | )
612 | )
613 |
614 | # log the memories in the conversation log
615 | for i in range(states.shape[0]):
616 | self.conversation_log.add_conversation(
617 | inputs[i],
618 | completions[i],
619 |                     rewards[i].detach().cpu().item(),
620 | current_learn_counter,
621 | )
622 |
623 | # learn from memories
624 | print(
625 | f"Learning counter: {current_time} of {update_timesteps}"
626 | )
627 | if (current_time % update_timesteps == 0) and (
628 | current_time != 0
629 | ):
630 | checkpoint_counter += 1
631 | self.conversation_log.show(current_learn_counter)
632 | self.learn(memories)
633 | memories.clear()
634 | current_time = 0
635 | current_learn_counter += 1
636 |
637 | if (checkpoint_counter % update_checkpoint == 0) and (
638 | checkpoint_counter != 0
639 | ):
640 | self.save_checkpoint(eps)
641 | checkpoint_counter = 0
642 |
643 | self.actorcritic.critic.save()
644 | self.actorcritic.actor.save()
645 | # print("Show conversations log")
646 | # self.conversation_log.show()
647 | print("End RL Training")
648 |
--------------------------------------------------------------------------------
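
The RL loop in trainer.py feeds ExamplesSampler from the JSON file referenced by
config.trainer.examples_path, using the layout described in the class docstring. A small
illustrative sketch (the file name and prompts below are made up):

# examples.json
# [
#     {"user_input": "Explain RLHF in two sentences."},
#     {"user_input": "Write a short poem about gradient descent."}
# ]

from trainer import ExamplesSampler  # flat import, as used inside chatllama/rlhf

sampler = ExamplesSampler("examples.json")
prompts = sampler.sample(2)  # two user_input strings, sampled without replacement
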
/chatllama/rlhf/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from beartype import beartype
3 | from beartype.typing import Optional
4 | from plotly import graph_objects as go
5 |
6 |
7 | class TrainingStats:
8 | """Training statistics
9 |
10 | Attributes:
11 | training_loss (List): List of training losses
12 | training_accuracy (List): List of training accuracies
13 | value_loss (List): List of value losses
14 | validation_loss (List): List of validation losses
15 | validation_accuracy (List): List of validation accuracies
16 | """
17 |
18 | def __init__(self):
19 | self.training_loss = []
20 | self.training_accuracy = []
21 | self.value_loss = []
22 | self.validation_loss = []
23 | self.validation_accuracy = []
24 |
25 | def plot(self):
26 | """Plot the training statistics using plotly"""
27 | fig = go.Figure()
28 | if len(self.training_loss) > 0:
29 | fig.add_trace(
30 | go.Scatter(y=self.training_loss, name="Training loss")
31 | )
32 | if len(self.training_accuracy) > 0:
33 | fig.add_trace(
34 | go.Scatter(y=self.training_accuracy, name="Training accuracy")
35 | )
36 | if len(self.value_loss) > 0:
37 | fig.add_trace(go.Scatter(y=self.value_loss, name="Value loss"))
38 | if len(self.validation_loss) > 0:
39 | fig.add_trace(
40 | go.Scatter(y=self.validation_loss, name="Validation loss")
41 | )
42 | if len(self.validation_accuracy) > 0:
43 | fig.add_trace(
44 | go.Scatter(
45 | y=self.validation_accuracy, name="Validation accuracy"
46 | )
47 | )
48 | fig.update_layout(
49 | showlegend=True, xaxis_type="log", xaxis_title="steps"
50 | )
51 | fig.show()
52 |
53 |
54 | class ConversationLog:
55 | """Save the conversation:
56 | (user input, model output, rewards and learn_counter)
57 | during the RL training loop. Additionally, in order to be able to compare
58 | the initial dataset of answers to the prompts, we store also the original
59 | performance of the generation:
60 | (generation_input, generation_output, generation_reward)
61 | """
62 |
63 | def __init__(self):
64 | self.conversation = []
65 |
66 | @beartype
67 | def add_conversation(
68 | self,
69 | user_input: str,
70 | model_output: str,
71 | reward: float,
72 | learn_counter: int,
73 |         previous_reward: Optional[float] = None,
74 |         previous_completion: Optional[str] = None,
75 | ):
76 | """Add a conversation to the log
77 |
78 | Args:
79 | user_input (str): User input / initial prompt
80 | model_output (str): Completion of the LLM model
81 | reward (float): Reward of the reward model assigned to the output
82 | learn_counter (int): Number of the learning iteration to
83 | distinguish the conversations that happens at different
84 | points of the training loop
85 | previous_reward (float): Reward of the reward model assigned to
86 | the output of original dataset
87 | previous_completion (str): Completion of the LLM model of the
88 | original dataset
89 | """
90 | self.conversation.append(
91 | {
92 | "user_input": user_input,
93 | "model_output": model_output,
94 | "reward": reward,
95 | "learn_counter": learn_counter,
96 | "previous_reward": previous_reward,
97 | "previous_completion": previous_completion,
98 | }
99 | )
100 |
101 | def save(self, path: Optional[str] = "./conversation.json"):
102 | with open(path, "r") as f:
103 | conversation = json.load(f)
104 | conversation.extend(self.conversation)
105 | with open(path, "w") as f:
106 | json.dump(conversation, f)
107 |
108 | def load(self, path: Optional[str] = "./conversation.json"):
109 | with open(path, "r") as f:
110 | self.conversation = json.load(f)
111 |
112 |     def show(self, current_iteration: Optional[int] = None):
113 | """Show the conversation log
114 |
115 | Args:
116 |             current_iteration (Optional[int]): current learning iteration;
117 |                 if given, print only the conversations logged at that
118 |                 iteration, otherwise print the whole log
119 |         """
120 | for i, c in enumerate(self.conversation):
121 | if current_iteration is None:
122 | print(
123 | f"##########################################\n"
124 | f"Conversation {i} at learn_counter "
125 | f"{c['learn_counter']}\n"
126 | f"##########################################\n"
127 | f"## User Input:\n\n{c['user_input']}\n\n"
128 | f"## Model Output:\n\n{c['model_output']}\n\n"
129 | f"## Reward: {c['reward']}\n\n"
130 | f"## Previous Reward: {c['previous_reward']}\n\n"
131 | )
132 | else:
133 | if current_iteration == c["learn_counter"]:
134 | print(
135 | f"##########################################\n"
136 | f"Conversation {i} at learn_counter "
137 | f"{c['learn_counter']}\n"
138 | f"##########################################\n"
139 | f"## User Input:\n\n{c['user_input']}\n\n"
140 | f"## Model Output:\n\n{c['model_output']}\n\n"
141 | f"## Reward: {c['reward']}\n\n"
142 | f"## Previous Reward: {c['previous_reward']}\n\n"
143 | )
144 |
--------------------------------------------------------------------------------
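
Both helpers in utils.py are self-contained, which makes them easy to smoke-test outside
the trainers. A minimal sketch with made-up values (run from inside chatllama/rlhf so the
flat import resolves):

from utils import TrainingStats, ConversationLog

stats = TrainingStats()
stats.training_loss.extend([1.2, 0.9, 0.7])
stats.plot()  # plotly figure with one trace per non-empty list

log = ConversationLog()
log.add_conversation(
    user_input="What does the reward model score?",
    model_output="It scores how well a completion answers the prompt.",
    reward=0.8,
    learn_counter=0,
)
log.show(current_iteration=0)  # print only the conversations from iteration 0
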
/generate_dataset.py:
--------------------------------------------------------------------------------
1 | from langchain import OpenAI, LLMChain, PromptTemplate
2 | from langchain.chains.conversation.memory import (
3 | ConversationBufferWindowMemory,
4 | )
5 |
6 | from chatllama.langchain_modules.prompt_templates import (
7 | PERSON_CHATBOT_TEMPLATE,
8 | AI_CHATBOT_TEMPLATE,
9 | )
10 |
11 |
12 | CONVERSATION_LENGTH = 20
13 |
14 |
15 | def create_conversation(human_agent: LLMChain, bot_agent: LLMChain):
16 | conversation = []
17 | chatbot_output = ""
18 | for i in range(CONVERSATION_LENGTH):
19 | # Human agent goes first
20 | human_output = human_agent.run(chatbot_input=chatbot_output)
21 | conversation.append(f"Human: {human_output}")
22 | chatbot_output = bot_agent.run(human_input=human_output)
23 | conversation.append(f"AI: {chatbot_output}")
24 | return "\n".join(conversation)
25 |
26 |
27 | def build_agents():
28 |     # be aware that overly long completions may not fit the sequence
29 |     # length of the critic or reward models used downstream
30 | llm = OpenAI(max_tokens=2048, temperature=0.7)
31 | human_template = PromptTemplate(**PERSON_CHATBOT_TEMPLATE)
32 | human_agent = LLMChain(
33 | llm=llm,
34 | prompt=human_template,
35 | memory=ConversationBufferWindowMemory(k=4),
36 | )
37 | bot_template = PromptTemplate(**AI_CHATBOT_TEMPLATE)
38 | bot_agent = LLMChain(
39 | llm=llm,
40 | prompt=bot_template,
41 | memory=ConversationBufferWindowMemory(k=4),
42 | )
43 | return human_agent, bot_agent
44 |
45 |
46 | def main():
47 | from argparse import ArgumentParser
48 |
49 | parser = ArgumentParser()
50 | parser.add_argument("--num_conversations", type=int, default=1000)
51 | parser.add_argument("--output_file", type=str, default="conversations.txt")
52 | args = parser.parse_args()
53 | conversations = []
54 | for conv in range(args.num_conversations):
55 | human_agent, bot_agent = build_agents()
56 | conversation = create_conversation(human_agent, bot_agent)
57 | conversations.append(conversation)
58 | with open(args.output_file, "w") as f:
59 |         f.write("\n\nNEW CONVERSATION\n\n".join(conversations))
60 |
61 |
62 | if __name__ == "__main__":
63 |     main()
64 |
--------------------------------------------------------------------------------
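
generate_dataset.py drives the two LangChain agents against the OpenAI API, so an API key
must be available before the agents are built (the openai client package is needed as
well). A sketch of producing a single synthetic conversation, with a placeholder key:

import os

from generate_dataset import build_agents, create_conversation

# LangChain's OpenAI wrapper reads the key from the environment
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

human_agent, bot_agent = build_agents()
print(create_conversation(human_agent, bot_agent))
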
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyllama
2 | fairscale
3 | langchain
4 | openai  # required by generate_dataset.py (LangChain's OpenAI wrapper)
5 | beartype
6 | einops
7 | plotly
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import os
3 |
4 | here = os.path.dirname(os.path.realpath(__file__))
5 | HAS_CUDA = os.system("nvidia-smi > /dev/null 2>&1") == 0
6 |
7 | VERSION = "0.0.4"
8 | DESCRIPTION = "ChatLLaMA: Open and Efficient Foundation Language Models Runnable on a Single GPU"
9 |
10 | packages = [
11 | "chatllama",
12 | ]
13 |
14 |
15 | def read_file(filename: str):
16 | try:
17 | lines = []
18 | with open(filename) as file:
19 | lines = file.readlines()
20 | lines = [line.rstrip() for line in lines if not line.startswith('#')]
21 | return lines
22 |     except Exception:
23 | return []
24 |
25 |
26 | def package_files(ds):
27 | paths = []
28 | for d in ds:
29 | for (path, directories, filenames) in os.walk(d):
30 | for filename in filenames:
31 | if '__pycache__' not in str(filename):
32 | paths.append(str(os.path.join(path, filename))[len('chatllama/'):])
33 | return paths
34 |
35 | extra_files = package_files(['chatllama/'])
36 |
37 |
38 | setup(
39 | name="chatllama",
40 | version=VERSION,
41 | author_email="",
42 | description=DESCRIPTION,
43 | long_description=open("README.md", "r", encoding="utf-8").read(),
44 | long_description_content_type="text/markdown",
45 | install_requires=read_file(f"{here}/requirements.txt"),
46 | keywords=[
47 | "ChatLLaMA", "LLaMA"
48 | ],
49 | classifiers=[
50 | "Development Status :: 3 - Alpha",
51 | "Programming Language :: Python :: 3",
52 | "Programming Language :: Python :: 3.8",
53 | "Programming Language :: Python :: 3.9",
54 | "Programming Language :: Python :: 3.10",
55 | ],
56 | packages=packages,
57 | package_data={"chatllama": extra_files},
58 | url="https://github.com/juncongmoo/chatllama"
59 | )
60 |
--------------------------------------------------------------------------------
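
For development, the package is typically installed in editable mode from the repository
root (pip install -e .), after which the RL stage can be launched much as test.py does.
A minimal sketch, assuming a valid config.yaml and, as above, running from inside
chatllama/rlhf so the flat imports resolve:

import torch

from config import Config
from trainer import RLTrainer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config = Config(path="path/to/config.yaml", device=device, debug=False)

# RLTrainer takes the full Config: it needs the actor, critic, reward
# and trainer sections to assemble the PPO loop
rl_trainer = RLTrainer(config)
rl_trainer.train()
rl_trainer.training_stats.plot()
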