├── README.md
├── example.py
└── llama
├── __init__.py
├── configuration_llama.py
├── convert_llama_weights_to_hf.py
├── modeling_llama.py
└── tokenization_llama.py
/README.md:
--------------------------------------------------------------------------------
1 | # LLaMA - Single End-to-End Repository
2 | 
3 | This repository is a standalone solution for running the LLaMA model with a Hugging Face interface, using the publicly available weights.
4 | 
5 | Tested on:
6 | 
7 | - OS: Windows 11 Home 22H2 (build 22621.1265)
8 | - Graphics cards: NVIDIA RTX 3090 FE (24 GB), NVIDIA RTX 4090 FE (24 GB)
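9 | 
10 | ## Quickstart
11 | 
12 | The snippet below mirrors `example.py`: it loads the converted LLaMA-7B weights and tokenizer through the bundled `llama` package and runs a short generation. It assumes a CUDA GPU with enough memory for the 7B model in fp16 (roughly 14 GB).
13 | 
14 | ```python
15 | import llama
16 | 
17 | MODEL = 'decapoda-research/llama-7b-hf'
18 | REVISION = '84fd0de2f666324fe13da5642b047be4d55b5982'
19 | 
20 | tokenizer = llama.LLaMATokenizer.from_pretrained(MODEL, revision=REVISION)
21 | model = llama.LLaMAForCausalLM.from_pretrained(MODEL, low_cpu_mem_usage=True, revision=REVISION).half()
22 | model.to('cuda')
23 | 
24 | batch = tokenizer("The capital of France is", return_tensors="pt", add_special_tokens=False)
25 | print(tokenizer.decode(model.generate(batch["input_ids"].cuda(), max_length=50)[0]))
26 | ```
27 | 
28 | To convert the original Meta checkpoints yourself instead of downloading converted ones, see the usage notes at the top of `llama/convert_llama_weights_to_hf.py`.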
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 |
2 | import llama
3 |
4 | MODEL = 'decapoda-research/llama-7b-hf'
5 | REVISION = '84fd0de2f666324fe13da5642b047be4d55b5982'
6 |
7 | tokenizer = llama.LLaMATokenizer.from_pretrained(MODEL, revision=REVISION)
8 | model = llama.LLaMAForCausalLM.from_pretrained(MODEL, low_cpu_mem_usage=True, revision=REVISION).half()
9 | model.to('cuda')
10 |
11 | prompt = """Tweet: "I hate it when my phone battery dies."
12 | Sentiment: Negative
13 | ###
14 | Tweet: "My day has been 👍"
15 | Sentiment: Positive
16 | ###
17 | Tweet: "This is the link to the article"
18 | Sentiment: Neutral
19 | ###
20 | Tweet: "This new music video was incredible"
21 | Sentiment:"""
22 |
23 | batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
24 | print(tokenizer.decode(model.generate(batch["input_ids"].cuda(), max_length=100)[0]))
25 |
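26 | # Note: the 7B model in fp16 occupies roughly 14 GB of GPU memory. If it does not fit on a
27 | # single card, `device_map="auto"` (with the `accelerate` package installed) can be passed to
28 | # `from_pretrained` to spread the weights across the available devices instead of calling
29 | # `model.to('cuda')` explicitly.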
--------------------------------------------------------------------------------
/llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
17 |
18 |
19 | _import_structure = {
20 | "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LLaMAConfig"],
21 | "tokenization_llama": ["LLaMATokenizer"],
22 | }
23 |
24 | try:
25 | if not is_torch_available():
26 | raise OptionalDependencyNotAvailable()
27 | except OptionalDependencyNotAvailable:
28 | pass
29 | else:
30 | _import_structure["modeling_llama"] = [
31 | "LLaMAForCausalLM",
32 | "LLaMAModel",
33 | "LLaMAPreTrainedModel",
34 | ]
35 |
36 |
37 | if TYPE_CHECKING:
38 | from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LLaMAConfig
39 | from .tokenization_llama import LLaMATokenizer
40 |
41 | try:
42 | if not is_torch_available():
43 | raise OptionalDependencyNotAvailable()
44 | except OptionalDependencyNotAvailable:
45 | pass
46 | else:
47 | from .modeling_llama import (
48 | LLaMAForCausalLM,
49 | LLaMAModel,
50 | LLaMAPreTrainedModel,
51 | )
52 |
53 |
54 | else:
55 | import sys
56 |
57 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
--------------------------------------------------------------------------------
/llama/configuration_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The FAIR team of Meta AI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ LLaMA model configuration"""
16 |
17 | from transformers.configuration_utils import PretrainedConfig
18 | from transformers.utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24 |
25 |
26 | class LLaMAConfig(PretrainedConfig):
27 | r"""
28 |     This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate a
29 |     LLaMA model according to the specified arguments, defining the model architecture. Instantiating a configuration
30 |     with the defaults will yield a configuration similar to that of LLaMA-7B.
31 |
32 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33 | documentation from [`PretrainedConfig`] for more information.
34 |
35 |
36 | Args:
37 | vocab_size (`int`, *optional*, defaults to 32000):
38 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
39 |             `input_ids` passed when calling [`~LLaMAModel`].
40 | hidden_size (`int`, *optional*, defaults to 4096):
41 | Dimension of the hidden representations.
42 | intermediate_size (`int`, *optional*, defaults to 11008):
43 | Dimension of the MLP representations.
44 | num_hidden_layers (`int`, *optional*, defaults to 32):
45 |             Number of hidden layers in the Transformer decoder.
46 | num_attention_heads (`int`, *optional*, defaults to 32):
47 |             Number of attention heads for each attention layer in the Transformer decoder.
48 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49 | The non-linear activation function (function or string) in the decoder.
50 | max_sequence_length (`int`, *optional*, defaults to 2048):
51 |             Maximum sequence length the model supports; used to precompute the rotary position embedding (RoPE) frequencies.
52 | initializer_range (`float`, *optional*, defaults to 0.02):
53 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
54 |         rms_norm_eps (`float`, *optional*, defaults to 1e-6):
55 | The epsilon used by the rms normalization layers.
56 | use_cache (`bool`, *optional*, defaults to `True`):
57 | Whether or not the model should return the last key/values attentions (not used by all models). Only
58 | relevant if `config.is_decoder=True`.
59 |         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
60 |             Whether to tie the input and output word embeddings.
61 | Example:
62 |
63 | ```python
64 | >>> from llama import LLaMAModel, LLaMAConfig
65 |
66 | >>> # Initializing a LLaMA llama-7b style configuration
67 | >>> configuration = LLaMAConfig()
68 |
69 | >>> # Initializing a model from the llama-7b style configuration
70 | >>> model = LLaMAModel(configuration)
71 |
72 | >>> # Accessing the model configuration
73 | >>> configuration = model.config
74 | ```"""
75 | model_type = "llama"
76 |
77 | def __init__(
78 | self,
79 | vocab_size=32000,
80 | hidden_size=4096,
81 | intermediate_size=11008,
82 | num_hidden_layers=32,
83 | num_attention_heads=32,
84 | hidden_act="silu",
85 | max_sequence_length=2048,
86 | initializer_range=0.02,
87 | rms_norm_eps=1e-6,
88 | use_cache=True,
89 | pad_token_id=-1,
90 | bos_token_id=0,
91 | eos_token_id=1,
92 | tie_word_embeddings=False,
93 | **kwargs,
94 | ):
95 | self.vocab_size = vocab_size
96 | self.hidden_size = hidden_size
97 | self.intermediate_size = intermediate_size
98 | self.num_hidden_layers = num_hidden_layers
99 | self.num_attention_heads = num_attention_heads
100 | self.hidden_act = hidden_act
101 | self.max_sequence_length = max_sequence_length
102 | self.initializer_range = initializer_range
103 | self.rms_norm_eps = rms_norm_eps
104 | self.use_cache = use_cache
105 | super().__init__(
106 | pad_token_id=pad_token_id,
107 | bos_token_id=bos_token_id,
108 | eos_token_id=eos_token_id,
109 | tie_word_embeddings=tie_word_embeddings,
110 | **kwargs,
111 | )
--------------------------------------------------------------------------------
/llama/convert_llama_weights_to_hf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import shutil
5 |
6 | import torch
7 |
8 |
9 | """
10 | Sample usage:
11 |
12 | ```
13 | python llama/convert_llama_weights_to_hf.py \
14 | --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
15 | ```
16 |
17 | Thereafter, models can be loaded via:
18 |
19 | ```
20 | tokenizer = llama.LLaMATokenizer.from_pretrained("/output/path/tokenizer/")
21 | 
22 | model = llama.LLaMAForCausalLM.from_pretrained("/output/path/llama-7b/")
23 | ```
24 | """
25 |
26 | INTERMEDIATE_SIZE_MAP = {
27 | "7B": 11008,
28 | "13B": 13824,
29 | "30B": 17920,
30 | "65B": 22016,
31 | }
32 | NUM_SHARDS = {
33 | "7B": 1,
34 | "13B": 2,
35 | "30B": 4,
36 | "65B": 8,
37 | }
38 |
39 |
40 | def read_json(path):
41 | with open(path, "r") as f:
42 | return json.loads(f.read())
43 |
44 |
45 | def write_json(text, path):
46 | with open(path, "w") as f:
47 | f.write(json.dumps(text))
48 |
49 |
50 | def write_model(model_path, input_base_path, model_size):
51 | assert model_size in INTERMEDIATE_SIZE_MAP
52 | os.makedirs(model_path, exist_ok=True)
53 |
54 | params = read_json(os.path.join(input_base_path, "params.json"))
55 | num_shards = NUM_SHARDS[model_size]
56 | n_layers = params["n_layers"]
57 | n_heads = params["n_heads"]
58 | n_heads_per_shard = n_heads // num_shards
59 | dim = params["dim"]
60 | dims_per_head = dim // n_heads
61 |
62 | # Load weights
63 | if model_size == "7B":
64 |         # Not sharded
65 | # (The sharded implementation would also work, but this is simpler.)
66 | loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
67 | else:
68 | # Sharded
69 | loaded = [
70 | torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
71 | for i in range(num_shards)
72 | ]
73 | param_count = 0
74 | index_dict = {"weight_map": {}}
75 | for layer_i in range(n_layers):
76 | filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
77 | layer_i,
78 | n_layers + 1,
79 | )
80 | if model_size == "7B":
81 | # Unsharded
82 | state_dict = {
83 | f"model.decoder.layers.{layer_i}.self_attn.q_proj.weight": loaded[
84 | f"layers.{layer_i}.attention.wq.weight"
85 | ],
86 | f"model.decoder.layers.{layer_i}.self_attn.k_proj.weight": loaded[
87 | f"layers.{layer_i}.attention.wk.weight"
88 | ],
89 | f"model.decoder.layers.{layer_i}.self_attn.v_proj.weight": loaded[
90 | f"layers.{layer_i}.attention.wv.weight"
91 | ],
92 | f"model.decoder.layers.{layer_i}.self_attn.o_proj.weight": loaded[
93 | f"layers.{layer_i}.attention.wo.weight"
94 | ],
95 | f"model.decoder.layers.{layer_i}.feed_forward.w1.weight": loaded[
96 | f"layers.{layer_i}.feed_forward.w1.weight"
97 | ],
98 | f"model.decoder.layers.{layer_i}.feed_forward.w2.weight": loaded[
99 | f"layers.{layer_i}.feed_forward.w2.weight"
100 | ],
101 | f"model.decoder.layers.{layer_i}.feed_forward.w3.weight": loaded[
102 | f"layers.{layer_i}.feed_forward.w3.weight"
103 | ],
104 | f"model.decoder.layers.{layer_i}.attention_norm.weight": loaded[
105 | f"layers.{layer_i}.attention_norm.weight"
106 | ],
107 | f"model.decoder.layers.{layer_i}.ffn_norm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
108 | }
109 | else:
110 | # Sharded
111 | state_dict = {
112 | f"model.decoder.layers.{layer_i}.attention_norm.weight": loaded[0][
113 | f"layers.{layer_i}.attention_norm.weight"
114 | ],
115 | f"model.decoder.layers.{layer_i}.ffn_norm.weight": loaded[0][f"layers.{layer_i}.ffn_norm.weight"],
116 | }
117 | state_dict[f"model.decoder.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat(
118 | [
119 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
120 | for i in range(num_shards)
121 | ],
122 | dim=0,
123 | ).reshape(dim, dim)
124 | state_dict[f"model.decoder.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat(
125 | [
126 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
127 | for i in range(num_shards)
128 | ],
129 | dim=0,
130 | ).reshape(dim, dim)
131 | state_dict[f"model.decoder.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
132 | [
133 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
134 | for i in range(num_shards)
135 | ],
136 | dim=0,
137 | ).reshape(dim, dim)
138 |
139 | state_dict[f"model.decoder.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
140 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
141 | )
142 | state_dict[f"model.decoder.layers.{layer_i}.feed_forward.w1.weight"] = torch.cat(
143 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
144 | )
145 | state_dict[f"model.decoder.layers.{layer_i}.feed_forward.w2.weight"] = torch.cat(
146 | [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
147 | )
148 | state_dict[f"model.decoder.layers.{layer_i}.feed_forward.w3.weight"] = torch.cat(
149 | [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
150 | )
151 |
152 | for k, v in state_dict.items():
153 | index_dict["weight_map"][k] = filename
154 | param_count += v.numel()
155 | torch.save(state_dict, os.path.join(model_path, filename))
156 |
157 | filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
158 | n_layers,
159 | n_layers + 1,
160 | )
161 | if model_size == "7B":
162 | # Unsharded
163 | state_dict = {
164 | "model.decoder.embed_tokens.weight": loaded["tok_embeddings.weight"],
165 | "model.decoder.norm.weight": loaded["norm.weight"],
166 | "lm_head.weight": loaded["output.weight"],
167 | }
168 | else:
169 | state_dict = {
170 | "model.decoder.norm.weight": loaded[0]["norm.weight"],
171 | "model.decoder.embed_tokens.weight": torch.cat(
172 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
173 | ),
174 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
175 | }
176 |
177 | for k, v in state_dict.items():
178 | index_dict["weight_map"][k] = filename
179 | param_count += v.numel()
180 | torch.save(state_dict, os.path.join(model_path, filename))
181 |
182 | # Write configs
183 |     index_dict["metadata"] = {"total_size": param_count * 2}  # 2 bytes per parameter (fp16)
184 | write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json"))
185 | config_out = {
186 | "architectures": ["LLaMAForCausalLM"],
187 | "bos_token_id": 0,
188 | "eos_token_id": 1,
189 | "hidden_act": "silu",
190 | "hidden_size": params["dim"],
191 | "intermediate_size": INTERMEDIATE_SIZE_MAP[model_size],
192 | "initializer_range": 0.02,
193 | "max_sequence_length": 2048,
194 | "model_type": "llama",
195 | "num_attention_heads": params["n_heads"],
196 | "num_hidden_layers": params["n_layers"],
197 | "pad_token_id": -1,
198 | "rms_norm_eps": params["norm_eps"],
199 | "torch_dtype": "float16",
200 | "transformers_version": "4.27.0.dev0",
201 | "use_cache": True,
202 | "vocab_size": 32000,
203 | }
204 | write_json(
205 | config_out,
206 | os.path.join(model_path, "config.json"),
207 | )
208 | generation_config = {
209 | "_from_model_config": True,
210 | "bos_token_id": 0,
211 | "eos_token_id": 1,
212 | "pad_token_id": -1,
213 | "transformers_version": "4.27.0.dev0",
214 | }
215 | write_json(
216 | generation_config,
217 | os.path.join(model_path, "generation_config.json"),
218 | )
219 |
220 |
221 | def write_tokenizer(tokenizer_path, input_tokenizer_path):
222 | os.makedirs(tokenizer_path, exist_ok=True)
223 | write_json({}, os.path.join(tokenizer_path, "special_tokens_map.json"))
224 | write_json(
225 | {
226 | "bos_token": "",
227 | "eos_token": "",
228 | "model_max_length": int(1e30),
229 | "tokenizer_class": "LLaMATokenizer",
230 | "unk_token": "",
231 | },
232 | os.path.join(tokenizer_path, "tokenizer_config.json"),
233 | )
234 | shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
235 |
236 |
237 | def main():
238 | parser = argparse.ArgumentParser()
239 | parser.add_argument(
240 | "--input_dir",
241 | help="Location of LLaMA weights, which contains tokenizer.model and model folders",
242 | )
243 | parser.add_argument(
244 | "--model_size",
245 | choices=["7B", "13B", "30B", "65B"],
246 | )
247 | parser.add_argument(
248 | "--output_dir",
249 | help="Location to write HF model and tokenizer",
250 | )
251 | args = parser.parse_args()
252 | write_model(
253 | model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
254 | input_base_path=os.path.join(args.input_dir, args.model_size),
255 | model_size=args.model_size,
256 | )
257 | write_tokenizer(
258 | tokenizer_path=os.path.join(args.output_dir, "tokenizer"),
259 | input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
260 | )
261 |
262 |
263 | if __name__ == "__main__":
264 | main()
--------------------------------------------------------------------------------
/llama/modeling_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch LLaMA model."""
16 | import math
17 | from typing import List, Optional, Tuple, Union
18 |
19 | import torch
20 | import torch.utils.checkpoint
21 | from torch import nn
22 | from torch.nn import CrossEntropyLoss
23 |
24 | from transformers.activations import ACT2FN
25 | from transformers.modeling_outputs import (
26 | BaseModelOutputWithPast,
27 | CausalLMOutputWithPast,
28 | )
29 | from transformers.modeling_utils import PreTrainedModel
30 | from transformers.utils import (
31 | add_code_sample_docstrings,
32 | add_start_docstrings,
33 | add_start_docstrings_to_model_forward,
34 | logging,
35 | replace_return_docstrings,
36 | )
37 | from .configuration_llama import LLaMAConfig
38 |
39 |
40 | logger = logging.get_logger(__name__)
41 |
42 | _CHECKPOINT_FOR_DOC = "llama-7b"
43 | _CONFIG_FOR_DOC = "LLaMAConfig"
44 |
45 |
46 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
47 | """
48 |     Make causal mask used for uni-directional self-attention.
49 | """
50 | bsz, tgt_len = input_ids_shape
51 | mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
52 | mask_cond = torch.arange(mask.size(-1))
53 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
54 | mask = mask.to(dtype)
55 |
56 | if past_key_values_length > 0:
57 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
58 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
59 |
60 |
61 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
62 | """
63 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
64 | """
65 | bsz, src_len = mask.size()
66 | tgt_len = tgt_len if tgt_len is not None else src_len
67 |
68 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
69 |
70 | inverted_mask = 1.0 - expanded_mask
71 |
72 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
73 |
74 |
75 | class RMSNorm(torch.nn.Module):
76 | def __init__(self, dim: int, eps: float = 1e-6):
77 | super().__init__()
78 | self.eps = eps
79 | self.weight = nn.Parameter(torch.ones(dim))
80 |
81 | def _norm(self, x):
82 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
83 |
84 | def forward(self, x):
85 | output = self._norm(x.float()).type_as(x)
86 | return output * self.weight
87 |
88 |
89 | class LLaMAFeedForward(nn.Module):
90 | def __init__(
91 | self,
92 | hidden_size: int,
93 | intermediate_size: int,
94 | hidden_act: str,
95 | ):
96 | super().__init__()
97 | self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
98 | self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
99 | self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
100 | self.act_fn = ACT2FN[hidden_act]
101 |
102 | def forward(self, x):
103 | return self.w2(self.act_fn(self.w1(x)) * self.w3(x))
104 |
105 |
106 | class LLaMAAttention(nn.Module):
107 |     """Multi-headed attention from 'Attention Is All You Need' paper, with rotary position embeddings"""
108 |
109 | def __init__(
110 | self,
111 | hidden_size: int,
112 | num_heads: int,
113 | complex_frequencies: torch.Tensor,
114 | ):
115 | super().__init__()
116 | self.hidden_size = hidden_size
117 | self.num_heads = num_heads
118 | self.head_dim = hidden_size // num_heads
119 |
120 | if (self.head_dim * num_heads) != self.hidden_size:
121 | raise ValueError(
122 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
123 | f" and `num_heads`: {num_heads})."
124 | )
125 | self.q_proj = nn.Linear(
126 | hidden_size,
127 | num_heads * self.head_dim,
128 | bias=False,
129 | )
130 | self.k_proj = nn.Linear(
131 | hidden_size,
132 | num_heads * self.head_dim,
133 | bias=False,
134 | )
135 | self.v_proj = nn.Linear(
136 | hidden_size,
137 | num_heads * self.head_dim,
138 | bias=False,
139 | )
140 | self.o_proj = nn.Linear(
141 | num_heads * self.head_dim,
142 | hidden_size,
143 | bias=False,
144 | )
145 | self.complex_frequencies = complex_frequencies
146 |
147 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
148 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
149 |
150 | def forward(
151 | self,
152 | hidden_states: torch.Tensor,
153 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
154 | attention_mask: Optional[torch.Tensor] = None,
155 | layer_head_mask: Optional[torch.Tensor] = None,
156 | output_attentions: bool = False,
157 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
158 | """Input shape: Batch x Time x Channel"""
159 |
160 | self.complex_frequencies = self.complex_frequencies.to(hidden_states.device)
161 |
162 | bsz, tgt_len, _ = hidden_states.size()
163 |
164 | # get query proj
165 | query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
166 | key_states = self.k_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
167 | value_states = self.v_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
168 |
169 | if past_key_value is not None:
170 | start = past_key_value[0].shape[2]
171 | else:
172 | start = 0
173 |
174 | sliced_complex_frequencies = self.complex_frequencies[start : start + tgt_len]
175 | query_states, key_states = apply_rotary_emb(
176 | query_states=query_states, key_states=key_states, complex_frequencies=sliced_complex_frequencies
177 | )
178 |
179 | # get key, value proj
180 | key_states = self._shape(key_states, -1, bsz)
181 | value_states = self._shape(value_states, -1, bsz)
182 | if past_key_value is not None:
183 | # reuse k, v, self_attention
184 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
185 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
186 |
187 | past_key_value = (key_states, value_states)
188 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
189 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
190 | key_states = key_states.view(*proj_shape)
191 | value_states = value_states.view(*proj_shape)
192 |
193 | src_len = key_states.size(1)
194 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
195 |
196 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
197 | raise ValueError(
198 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
199 | f" {attn_weights.size()}"
200 | )
201 |
202 | if attention_mask is not None:
203 | if attention_mask.size() != (bsz, 1, tgt_len, src_len):
204 | raise ValueError(
205 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
206 | )
207 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
208 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
209 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
210 |
211 | # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
212 | if attn_weights.dtype == torch.float16:
213 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
214 | else:
215 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
216 |
217 | if layer_head_mask is not None:
218 | if layer_head_mask.size() != (self.num_heads,):
219 | raise ValueError(
220 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
221 | f" {layer_head_mask.size()}"
222 | )
223 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
224 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
225 |
226 | if output_attentions:
227 | # this operation is a bit awkward, but it's required to
228 | # make sure that attn_weights keeps its gradient.
229 | # In order to do so, attn_weights have to be reshaped
230 | # twice and have to be reused in the following
231 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
232 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
233 | else:
234 | attn_weights_reshaped = None
235 |
236 | attn_output = torch.bmm(attn_weights, value_states)
237 |
238 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
239 | raise ValueError(
240 |                 f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
241 | f" {attn_output.size()}"
242 | )
243 |
244 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
245 | attn_output = attn_output.transpose(1, 2)
246 |
247 | attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)
248 |
249 | attn_output = self.o_proj(attn_output)
250 |
251 | return attn_output, attn_weights_reshaped, past_key_value
252 |
253 |
254 | class LLaMADecoderLayer(nn.Module):
255 | def __init__(self, config: LLaMAConfig):
256 | super().__init__()
257 | self.hidden_size = config.hidden_size
258 | complex_frequencies = precompute_complex_frequencies(
259 | head_dim=self.hidden_size // config.num_attention_heads,
260 | length=config.max_sequence_length * 2,
261 | )
262 | self.self_attn = LLaMAAttention(
263 | hidden_size=self.hidden_size,
264 | num_heads=config.num_attention_heads,
265 | complex_frequencies=complex_frequencies,
266 | )
267 | self.feed_forward = LLaMAFeedForward(
268 | hidden_size=self.hidden_size,
269 | intermediate_size=config.intermediate_size,
270 | hidden_act=config.hidden_act,
271 | )
272 | self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
273 | self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
274 |
275 | def forward(
276 | self,
277 | hidden_states: torch.Tensor,
278 | attention_mask: Optional[torch.Tensor] = None,
279 | layer_head_mask: Optional[torch.Tensor] = None,
280 | output_attentions: Optional[bool] = False,
281 | use_cache: Optional[bool] = False,
282 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
283 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
284 | """
285 | Args:
286 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
287 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
288 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
289 | layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
290 | `(encoder_attention_heads,)`.
291 | output_attentions (`bool`, *optional*):
292 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
293 | returned tensors for more detail.
294 | use_cache (`bool`, *optional*):
295 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
296 | (see `past_key_values`).
297 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
298 | """
299 |
300 | residual = hidden_states
301 |
302 | hidden_states = self.attention_norm(hidden_states)
303 |
304 | # Self Attention
305 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
306 | hidden_states=hidden_states,
307 | past_key_value=past_key_value,
308 | attention_mask=attention_mask,
309 | layer_head_mask=layer_head_mask,
310 | output_attentions=output_attentions,
311 | )
312 | hidden_states = residual + hidden_states
313 |
314 | # Fully Connected
315 | residual = hidden_states
316 | hidden_states = self.ffn_norm(hidden_states)
317 | hidden_states = self.feed_forward(hidden_states)
318 | hidden_states = residual + hidden_states
319 |
320 | outputs = (hidden_states,)
321 |
322 | if output_attentions:
323 | outputs += (self_attn_weights,)
324 |
325 | if use_cache:
326 | outputs += (present_key_value,)
327 |
328 | return outputs
329 |
330 |
331 | LLAMA_START_DOCSTRING = r"""
332 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
333 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
334 | etc.)
335 |
336 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
337 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
338 | and behavior.
339 |
340 | Parameters:
341 | config ([`LLaMAConfig`]):
342 | Model configuration class with all the parameters of the model. Initializing with a config file does not
343 | load the weights associated with the model, only the configuration. Check out the
344 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
345 | """
346 |
347 |
348 | @add_start_docstrings(
349 |     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
350 | LLAMA_START_DOCSTRING,
351 | )
352 | class LLaMAPreTrainedModel(PreTrainedModel):
353 | config_class = LLaMAConfig
354 | base_model_prefix = "model"
355 | supports_gradient_checkpointing = True
356 | _no_split_modules = ["LLaMADecoderLayer"]
357 | _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
358 |
359 | def _init_weights(self, module):
360 | std = self.config.initializer_range
361 | if isinstance(module, nn.Linear):
362 | module.weight.data.normal_(mean=0.0, std=std)
363 | if module.bias is not None:
364 | module.bias.data.zero_()
365 | elif isinstance(module, nn.Embedding):
366 | module.weight.data.normal_(mean=0.0, std=std)
367 | if module.padding_idx is not None:
368 | module.weight.data[module.padding_idx].zero_()
369 |
370 | def _set_gradient_checkpointing(self, module, value=False):
371 |         if isinstance(module, LLaMADecoder):
372 | module.gradient_checkpointing = value
373 |
374 |
375 | LLAMA_INPUTS_DOCSTRING = r"""
376 | Args:
377 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
378 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
379 | it.
380 |
381 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
382 | [`PreTrainedTokenizer.__call__`] for details.
383 |
384 | [What are input IDs?](../glossary#input-ids)
385 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
386 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
387 |
388 | - 1 for tokens that are **not masked**,
389 | - 0 for tokens that are **masked**.
390 |
391 | [What are attention masks?](../glossary#attention-mask)
392 |
393 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
394 | [`PreTrainedTokenizer.__call__`] for details.
395 |
396 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
397 | `past_key_values`).
398 |
399 |             If you want to change padding behavior, you should read [`LLaMADecoder._prepare_decoder_attention_mask`]
400 |             and modify it to your needs. See diagram 1 in [the BART paper](https://arxiv.org/abs/1910.13461) for more
401 |             information on the default strategy.
402 |         head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
403 |             Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
404 |
405 | - 1 indicates the head is **not masked**,
406 | - 0 indicates the head is **masked**.
407 |
408 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
409 |             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
410 |             `(batch_size, num_heads, sequence_length, embed_size_per_head)`. LLaMA is a decoder-only model, so only
411 |             self-attention key/value states are cached.
412 | 
413 |             Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used (see
414 |             `past_key_values` input) to speed up sequential decoding.
415 |
416 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
417 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
418 | `decoder_input_ids` of shape `(batch_size, sequence_length)`.
419 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
420 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
421 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
422 | model's internal embedding lookup matrix.
423 | use_cache (`bool`, *optional*):
424 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
425 | `past_key_values`).
426 | output_attentions (`bool`, *optional*):
427 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
428 | tensors for more detail.
429 | output_hidden_states (`bool`, *optional*):
430 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
431 | more detail.
432 | return_dict (`bool`, *optional*):
433 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
434 | """
435 |
436 |
437 | class LLaMADecoder(LLaMAPreTrainedModel):
438 | """
439 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LLaMADecoderLayer`]
440 |
441 | Args:
442 | config: LLaMAConfig
443 | """
444 |
445 | def __init__(self, config: LLaMAConfig):
446 | super().__init__(config)
447 | self.padding_idx = config.pad_token_id
448 |
449 | self.vocab_size = config.vocab_size
450 |
451 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
452 |
453 | self.layers = nn.ModuleList([LLaMADecoderLayer(config) for _ in range(config.num_hidden_layers)])
454 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
455 |
456 | self.gradient_checkpointing = False
457 | # Initialize weights and apply final processing
458 | self.post_init()
459 |
460 | def get_input_embeddings(self):
461 | return self.embed_tokens
462 |
463 | def set_input_embeddings(self, value):
464 | self.embed_tokens = value
465 |
466 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
467 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
468 | # create causal mask
469 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
470 | combined_attention_mask = None
471 | if input_shape[-1] > 1:
472 | combined_attention_mask = _make_causal_mask(
473 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
474 | ).to(inputs_embeds.device)
475 |
476 | if attention_mask is not None:
477 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
478 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
479 | inputs_embeds.device
480 | )
481 | combined_attention_mask = (
482 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
483 | )
484 |
485 | return combined_attention_mask
486 |
487 | def forward(
488 | self,
489 | input_ids: torch.LongTensor = None,
490 | attention_mask: Optional[torch.Tensor] = None,
491 | head_mask: Optional[torch.Tensor] = None,
492 | past_key_values: Optional[List[torch.FloatTensor]] = None,
493 | inputs_embeds: Optional[torch.FloatTensor] = None,
494 | use_cache: Optional[bool] = None,
495 | output_attentions: Optional[bool] = None,
496 | output_hidden_states: Optional[bool] = None,
497 | return_dict: Optional[bool] = None,
498 | ) -> Union[Tuple, BaseModelOutputWithPast]:
499 | r"""
500 | Args:
501 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
502 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
503 | provide it.
504 |
505 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
506 | [`PreTrainedTokenizer.__call__`] for details.
507 |
508 | [What are input IDs?](../glossary#input-ids)
509 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
510 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
511 |
512 | - 1 for tokens that are **not masked**,
513 | - 0 for tokens that are **masked**.
514 |
515 | [What are attention masks?](../glossary#attention-mask)
516 | head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
517 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
518 |
519 | - 1 indicates the head is **not masked**,
520 | - 0 indicates the head is **masked**.
521 |
522 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
523 |                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
524 |                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
525 | 
526 |                 Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used
527 |                 (see `past_key_values` input) to speed up sequential decoding.
528 |
529 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
530 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
531 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
532 |
533 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
534 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
535 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors
536 | than the model's internal embedding lookup matrix.
537 | output_attentions (`bool`, *optional*):
538 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
539 | returned tensors for more detail.
540 | output_hidden_states (`bool`, *optional*):
541 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
542 | for more detail.
543 | return_dict (`bool`, *optional*):
544 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
545 | """
546 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
547 | output_hidden_states = (
548 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
549 | )
550 | use_cache = use_cache if use_cache is not None else self.config.use_cache
551 |
552 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
553 |
554 | # retrieve input_ids and inputs_embeds
555 | if input_ids is not None and inputs_embeds is not None:
556 |             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
557 | elif input_ids is not None:
558 | input_shape = input_ids.size()
559 | input_ids = input_ids.view(-1, input_shape[-1])
560 | elif inputs_embeds is not None:
561 | input_shape = inputs_embeds.size()[:-1]
562 | else:
563 |             raise ValueError("You have to specify either input_ids or inputs_embeds")
564 |
565 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
566 |
567 | if inputs_embeds is None:
568 | inputs_embeds = self.embed_tokens(input_ids)
569 |
570 | # embed positions
571 | if attention_mask is None:
572 | attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
573 |
574 | attention_mask = self._prepare_decoder_attention_mask(
575 | attention_mask, input_shape, inputs_embeds, past_key_values_length
576 | )
577 |
578 | hidden_states = inputs_embeds
579 |
580 | if self.gradient_checkpointing and self.training:
581 | if use_cache:
582 | logger.warning_once(
583 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
584 | )
585 | use_cache = False
586 |
587 | # decoder layers
588 | all_hidden_states = () if output_hidden_states else None
589 | all_self_attns = () if output_attentions else None
590 | next_decoder_cache = () if use_cache else None
591 |
592 | # check if head_mask has a correct number of layers specified if desired
593 | for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
594 | if attn_mask is not None:
595 | if attn_mask.size()[0] != (len(self.layers)):
596 | raise ValueError(
597 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
598 | f" {head_mask.size()[0]}."
599 | )
600 |
601 | for idx, decoder_layer in enumerate(self.layers):
602 | if output_hidden_states:
603 | all_hidden_states += (hidden_states,)
604 |
605 | past_key_value = past_key_values[idx] if past_key_values is not None else None
606 |
607 | if self.gradient_checkpointing and self.training:
608 |
609 | def create_custom_forward(module):
610 | def custom_forward(*inputs):
611 | # None for past_key_value
612 | return module(*inputs, output_attentions, None)
613 |
614 | return custom_forward
615 |
616 | layer_outputs = torch.utils.checkpoint.checkpoint(
617 | create_custom_forward(decoder_layer),
618 | hidden_states,
619 | attention_mask,
620 | head_mask[idx] if head_mask is not None else None,
621 | None,
622 | )
623 | else:
624 | layer_outputs = decoder_layer(
625 | hidden_states,
626 | attention_mask=attention_mask,
627 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
628 | past_key_value=past_key_value,
629 | output_attentions=output_attentions,
630 | use_cache=use_cache,
631 | )
632 |
633 | hidden_states = layer_outputs[0]
634 |
635 | if use_cache:
636 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
637 |
638 | if output_attentions:
639 | all_self_attns += (layer_outputs[1],)
640 |
641 | hidden_states = self.norm(hidden_states)
642 |
643 | # add hidden states from the last decoder layer
644 | if output_hidden_states:
645 | all_hidden_states += (hidden_states,)
646 |
647 | next_cache = next_decoder_cache if use_cache else None
648 | if not return_dict:
649 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
650 | return BaseModelOutputWithPast(
651 | last_hidden_state=hidden_states,
652 | past_key_values=next_cache,
653 | hidden_states=all_hidden_states,
654 | attentions=all_self_attns,
655 | )
656 |
657 |
658 | @add_start_docstrings(
659 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
660 | LLAMA_START_DOCSTRING,
661 | )
662 | class LLaMAModel(LLaMAPreTrainedModel):
663 | def __init__(self, config: LLaMAConfig):
664 | super().__init__(config)
665 | self.decoder = LLaMADecoder(config)
666 | # Initialize weights and apply final processing
667 | self.post_init()
668 |
669 | def get_input_embeddings(self):
670 | return self.decoder.embed_tokens
671 |
672 | def set_input_embeddings(self, value):
673 | self.decoder.embed_tokens = value
674 |
675 | def get_decoder(self):
676 | return self.decoder
677 |
678 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
679 | @add_code_sample_docstrings(
680 | checkpoint=_CHECKPOINT_FOR_DOC,
681 | output_type=BaseModelOutputWithPast,
682 | config_class=_CONFIG_FOR_DOC,
683 | )
684 | def forward(
685 | self,
686 | input_ids: torch.LongTensor = None,
687 | attention_mask: Optional[torch.Tensor] = None,
688 | head_mask: Optional[torch.Tensor] = None,
689 | past_key_values: Optional[List[torch.FloatTensor]] = None,
690 | inputs_embeds: Optional[torch.FloatTensor] = None,
691 | use_cache: Optional[bool] = None,
692 | output_attentions: Optional[bool] = None,
693 | output_hidden_states: Optional[bool] = None,
694 | return_dict: Optional[bool] = None,
695 | ) -> Union[Tuple, BaseModelOutputWithPast]:
696 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
697 | output_hidden_states = (
698 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
699 | )
700 | use_cache = use_cache if use_cache is not None else self.config.use_cache
701 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
702 |
703 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
704 | decoder_outputs = self.decoder(
705 | input_ids=input_ids,
706 | attention_mask=attention_mask,
707 | head_mask=head_mask,
708 | past_key_values=past_key_values,
709 | inputs_embeds=inputs_embeds,
710 | use_cache=use_cache,
711 | output_attentions=output_attentions,
712 | output_hidden_states=output_hidden_states,
713 | return_dict=return_dict,
714 | )
715 |
716 | if not return_dict:
717 | return decoder_outputs
718 |
719 | return BaseModelOutputWithPast(
720 | last_hidden_state=decoder_outputs.last_hidden_state,
721 | past_key_values=decoder_outputs.past_key_values,
722 | hidden_states=decoder_outputs.hidden_states,
723 | attentions=decoder_outputs.attentions,
724 | )
725 |
726 |
727 | class LLaMAForCausalLM(LLaMAPreTrainedModel):
728 | _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
729 |
730 | def __init__(self, config):
731 | super().__init__(config)
732 | self.model = LLaMAModel(config)
733 |
734 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
735 |
736 | # Initialize weights and apply final processing
737 | self.post_init()
738 |
739 | def get_input_embeddings(self):
740 | return self.model.decoder.embed_tokens
741 |
742 | def set_input_embeddings(self, value):
743 | self.model.decoder.embed_tokens = value
744 |
745 | def get_output_embeddings(self):
746 | return self.lm_head
747 |
748 | def set_output_embeddings(self, new_embeddings):
749 | self.lm_head = new_embeddings
750 |
751 | def set_decoder(self, decoder):
752 | self.model.decoder = decoder
753 |
754 | def get_decoder(self):
755 | return self.model.decoder
756 |
757 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
758 | def forward(
759 | self,
760 | input_ids: torch.LongTensor = None,
761 | attention_mask: Optional[torch.Tensor] = None,
762 | head_mask: Optional[torch.Tensor] = None,
763 | past_key_values: Optional[List[torch.FloatTensor]] = None,
764 | inputs_embeds: Optional[torch.FloatTensor] = None,
765 | labels: Optional[torch.LongTensor] = None,
766 | use_cache: Optional[bool] = None,
767 | output_attentions: Optional[bool] = None,
768 | output_hidden_states: Optional[bool] = None,
769 | return_dict: Optional[bool] = None,
770 | ) -> Union[Tuple, CausalLMOutputWithPast]:
771 | r"""
772 | Args:
773 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
774 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
775 | provide it.
776 |
777 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
778 | [`PreTrainedTokenizer.__call__`] for details.
779 |
780 | [What are input IDs?](../glossary#input-ids)
781 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
782 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
783 |
784 | - 1 for tokens that are **not masked**,
785 | - 0 for tokens that are **masked**.
786 |
787 | [What are attention masks?](../glossary#attention-mask)
788 | head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
789 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
790 |
791 | - 1 indicates the head is **not masked**,
792 | - 0 indicates the head is **masked**.
793 |
794 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
795 |                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
796 |                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. LLaMA is a decoder-only model,
797 |                 so only self-attention key/value states are cached.
798 | 
799 |                 Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used
800 |                 (see `past_key_values` input) to speed up sequential decoding, since cached positions do not need to
801 |                 be re-processed on every generation step.
802 |
803 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
804 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
805 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
806 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
807 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
808 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors
809 | than the model's internal embedding lookup matrix.
810 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
811 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
812 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
813 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
814 | use_cache (`bool`, *optional*):
815 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
816 | (see `past_key_values`).
817 | output_attentions (`bool`, *optional*):
818 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
819 | returned tensors for more detail.
820 | output_hidden_states (`bool`, *optional*):
821 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
822 | for more detail.
823 | return_dict (`bool`, *optional*):
824 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
825 |
826 | Returns:
827 |
828 | Example:
829 |
830 | ```python
831 | >>> from transformers import AutoTokenizer, LLaMAForCausalLM
832 |
833 | >>> model = LLaMAForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
834 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
835 |
836 |         >>> prompt = "Hey, are you conscious? Can you talk to me?"
837 | >>> inputs = tokenizer(prompt, return_tensors="pt")
838 |
839 | >>> # Generate
840 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
841 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
842 |         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
843 | ```"""
844 |
845 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
846 | output_hidden_states = (
847 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
848 | )
849 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
850 |
851 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
852 | outputs = self.model.decoder(
853 | input_ids=input_ids,
854 | attention_mask=attention_mask,
855 | head_mask=head_mask,
856 | past_key_values=past_key_values,
857 | inputs_embeds=inputs_embeds,
858 | use_cache=use_cache,
859 | output_attentions=output_attentions,
860 | output_hidden_states=output_hidden_states,
861 | return_dict=return_dict,
862 | )
863 |
864 | logits = self.lm_head(outputs[0]).contiguous()
865 |
866 | loss = None
867 | if labels is not None:
868 | # Shift so that tokens < n predict n
869 | shift_logits = logits[..., :-1, :].contiguous()
870 | shift_labels = labels[..., 1:].contiguous()
871 | # Flatten the tokens
872 | loss_fct = CrossEntropyLoss()
873 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
874 |
875 | if not return_dict:
876 | output = (logits,) + outputs[1:]
877 | return (loss,) + output if loss is not None else output
878 |
879 | return CausalLMOutputWithPast(
880 | loss=loss,
881 | logits=logits,
882 | past_key_values=outputs.past_key_values,
883 | hidden_states=outputs.hidden_states,
884 | attentions=outputs.attentions,
885 | )
886 |
887 | def prepare_inputs_for_generation(
888 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
889 | ):
890 | if past_key_values:
891 | input_ids = input_ids[:, -1:]
892 |
893 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
894 | if inputs_embeds is not None and past_key_values is None:
895 | model_inputs = {"inputs_embeds": inputs_embeds}
896 | else:
897 | model_inputs = {"input_ids": input_ids}
898 |
899 | model_inputs.update(
900 | {
901 | "past_key_values": past_key_values,
902 | "use_cache": kwargs.get("use_cache"),
903 | "attention_mask": attention_mask,
904 | }
905 | )
906 | return model_inputs
907 |
908 | @staticmethod
909 | def _reorder_cache(past_key_values, beam_idx):
910 | reordered_past = ()
911 | for layer_past in past_key_values:
912 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
913 | return reordered_past
914 |
915 |
916 | def precompute_complex_frequencies(head_dim: int, length: int, theta: float = 10000.0):
917 | frequencies = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
918 | t = torch.arange(length, device=frequencies.device)
919 | frequencies = torch.outer(t, frequencies).float()
920 | return torch.polar(torch.ones_like(frequencies), frequencies) # complex64
921 |
922 |
923 | def apply_rotary_emb(
924 | query_states: torch.Tensor,
925 | key_states: torch.Tensor,
926 | complex_frequencies: torch.Tensor,
927 | ) -> Tuple[torch.Tensor, torch.Tensor]:
928 |     query_states_complex = torch.view_as_complex(query_states.float().reshape(*query_states.shape[:-1], -1, 2))
929 | key_states_complex = torch.view_as_complex(key_states.float().reshape(*key_states.shape[:-1], -1, 2))
930 | complex_frequencies = reshape_for_broadcast(complex_frequencies, query_states_complex)
931 | output_query_states = torch.view_as_real(query_states_complex * complex_frequencies).flatten(3)
932 | output_key_states = torch.view_as_real(key_states_complex * complex_frequencies).flatten(3)
933 | return output_query_states.type_as(query_states), output_key_states.type_as(key_states)
934 |
935 |
936 | def reshape_for_broadcast(complex_frequencies: torch.Tensor, x: torch.Tensor):
937 | ndim = x.ndim
938 |     assert ndim > 1
939 | assert complex_frequencies.shape == (x.shape[1], x.shape[-1])
940 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
941 | return complex_frequencies.view(*shape)
--------------------------------------------------------------------------------
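
The two module-level helpers above (`precompute_complex_frequencies` and `apply_rotary_emb`) implement rotary position embeddings (RoPE): each query/key channel pair is treated as a complex number and rotated by a position-dependent unit phasor, which leaves vector norms unchanged. The following sanity-check sketch is not part of the repository; the tensor sizes and the `llama.modeling_llama` import path are illustrative assumptions.

```python
import torch

from llama.modeling_llama import apply_rotary_emb, precompute_complex_frequencies

# Illustrative sizes only; LLaMA-7B uses n_heads=32 and head_dim=128.
batch, seq_len, n_heads, head_dim = 2, 16, 32, 128

# (seq_len, head_dim // 2) complex64 phasors: one rotation per position and channel pair.
complex_frequencies = precompute_complex_frequencies(head_dim, seq_len)

query_states = torch.randn(batch, seq_len, n_heads, head_dim)
key_states = torch.randn(batch, seq_len, n_heads, head_dim)

rotated_q, rotated_k = apply_rotary_emb(query_states, key_states, complex_frequencies)

# Shapes are preserved, and so are per-head norms (RoPE is a pure rotation).
assert rotated_q.shape == query_states.shape and rotated_k.shape == key_states.shape
assert torch.allclose(rotated_q.norm(dim=-1), query_states.norm(dim=-1), atol=1e-4)
```
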
/llama/tokenization_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The FAIR team of Meta AI and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for LLaMA."""
16 | import os
17 | import re
18 | from shutil import copyfile
19 | from typing import Any, Dict, List, Optional, Tuple
20 |
21 | import sentencepiece as spm
22 |
23 | from transformers.tokenization_utils import PreTrainedTokenizer
24 | from transformers.utils import logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 | VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
30 |
31 | PRETRAINED_VOCAB_FILES_MAP = {}
32 |
33 |
34 | class LLaMATokenizer(PreTrainedTokenizer):
35 | """
36 |     Construct a LLaMA tokenizer. Based on a SentencePiece Byte-Pair-Encoding model (`tokenizer.model`).
37 |
38 | Args:
39 | vocab_file (`str`):
40 | Path to the vocabulary file.
41 | """
42 |
43 | vocab_files_names = VOCAB_FILES_NAMES
44 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
45 | model_input_names = ["input_ids", "attention_mask"]
46 |
47 | def __init__(
48 | self,
49 | vocab_file,
50 |         unk_token="<unk>",
51 |         bos_token="<s>",
52 |         eos_token="</s>",
53 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
54 | add_bos_token=False,
55 | add_eos_token=False,
56 | **kwargs,
57 | ):
58 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
59 | super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
60 | self.vocab_file = vocab_file
61 | self.add_bos_token = add_bos_token
62 | self.add_eos_token = add_eos_token
63 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
64 | self.sp_model.Load(vocab_file)
65 |
66 | """ Initialisation"""
67 |
68 | @property
69 | def vocab_size(self):
70 | """Returns vocab size"""
71 | return self.sp_model.get_piece_size()
72 |
73 | @property
74 | def bos_token_id(self) -> Optional[int]:
75 | return self.sp_model.bos_id()
76 |
77 | @property
78 | def eos_token_id(self) -> Optional[int]:
79 | return self.sp_model.eos_id()
80 |
81 | def get_vocab(self):
82 | """Returns vocab as a dict"""
83 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
84 | vocab.update(self.added_tokens_encoder)
85 | return vocab
86 |
87 | def _tokenize(self, text):
88 | """Returns a tokenized string."""
89 | return self.sp_model.encode(text, out_type=str)
90 |
91 | def _convert_token_to_id(self, token):
92 | """Converts a token (str) in an id using the vocab."""
93 | return self.sp_model.piece_to_id(token)
94 |
95 | def _convert_id_to_token(self, index):
96 | """Converts an index (integer) in a token (str) using the vocab."""
97 | token = self.sp_model.IdToPiece(index)
98 | return token
99 |
100 | def convert_tokens_to_string(self, tokens):
101 | """Converts a sequence of tokens (string) in a single string."""
102 | current_sub_tokens = []
103 | out_string = ""
104 | prev_is_special = False
105 | for token in tokens:
106 | # make sure that special tokens are not decoded using sentencepiece model
107 | if token in self.all_special_tokens:
108 | if not prev_is_special:
109 | out_string += " "
110 | out_string += self.sp_model.decode(current_sub_tokens) + token
111 | prev_is_special = True
112 | current_sub_tokens = []
113 | else:
114 | current_sub_tokens.append(token)
115 | prev_is_special = False
116 | out_string += self.sp_model.decode(current_sub_tokens)
117 | return out_string.strip()
118 |
119 | def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
120 | """
121 | Save the vocabulary and special tokens file to a directory.
122 |
123 | Args:
124 | save_directory (`str`):
125 | The directory in which to save the vocabulary.
126 |
127 | Returns:
128 |             `Tuple[str]`: Paths to the files saved.
129 | """
130 | if not os.path.isdir(save_directory):
131 | logger.error(f"Vocabulary path ({save_directory}) should be a directory")
132 | return
133 | out_vocab_file = os.path.join(
134 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
135 | )
136 |
137 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
138 | copyfile(self.vocab_file, out_vocab_file)
139 | elif not os.path.isfile(self.vocab_file):
140 | with open(out_vocab_file, "wb") as fi:
141 | content_spiece_model = self.sp_model.serialized_model_proto()
142 | fi.write(content_spiece_model)
143 |
144 | return (out_vocab_file,)
145 |
146 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
147 | if self.add_bos_token:
148 | bos_token_ids = [self.bos_token_id]
149 | else:
150 | bos_token_ids = []
151 |
152 | output = bos_token_ids + token_ids_0
153 |
154 | if token_ids_1 is not None:
155 | output = output + token_ids_1
156 |
157 | if self.add_eos_token:
158 | output = output + [self.eos_token_id]
159 |
160 | return output
161 |
162 | def get_special_tokens_mask(
163 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
164 | ) -> List[int]:
165 | """
166 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
167 | special tokens using the tokenizer `prepare_for_model` method.
168 |
169 | Args:
170 | token_ids_0 (`List[int]`):
171 | List of IDs.
172 | token_ids_1 (`List[int]`, *optional*):
173 | Optional second list of IDs for sequence pairs.
174 | already_has_special_tokens (`bool`, *optional*, defaults to `False`):
175 | Whether or not the token list is already formatted with special tokens for the model.
176 |
177 | Returns:
178 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
179 | """
180 | if already_has_special_tokens:
181 | return super().get_special_tokens_mask(
182 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
183 | )
184 |
185 |         if token_ids_1 is None:
186 |             return ([1] if self.add_bos_token else []) + ([0] * len(token_ids_0)) + ([1] if self.add_eos_token else [])
187 |         return ([1] if self.add_bos_token else []) + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + ([1] if self.add_eos_token else [])
188 |
189 | def create_token_type_ids_from_sequences(
190 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
191 | ) -> List[int]:
192 | """
193 |         Create a mask from the two sequences passed to be used in a sequence-pair classification task. LLaMA does not
194 |         make use of token type ids, therefore a list of zeros is returned.
195 |
196 | Args:
197 | token_ids_0 (`List[int]`):
198 | List of IDs.
199 | token_ids_1 (`List[int]`, *optional*):
200 | Optional second list of IDs for sequence pairs.
201 |
202 | Returns:
203 | `List[int]`: List of zeros.
204 | """
205 |         bos = [self.bos_token_id] if self.add_bos_token else []
206 |         eos = [self.eos_token_id] if self.add_eos_token else []
207 |         if token_ids_1 is None:
208 |             return len(bos + token_ids_0 + eos) * [0]
209 |         return len(bos + token_ids_0 + token_ids_1 + eos) * [0]
--------------------------------------------------------------------------------
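
For reference, a minimal usage sketch of `LLaMATokenizer` (not part of the repository; `path/to/llama-7b-hf` is a placeholder for any local directory containing the SentencePiece `tokenizer.model` from the converted checkpoint):

```python
from llama import LLaMATokenizer

# Loads tokenizer.model from a local directory (placeholder path).
tokenizer = LLaMATokenizer.from_pretrained("path/to/llama-7b-hf")

text = "The capital of France is"
encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)

print(encoded["input_ids"].shape)                                          # torch.Size([1, num_tokens])
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist()))   # SentencePiece pieces, e.g. '▁The', '▁capital', ...
print(tokenizer.decode(encoded["input_ids"][0]))                           # round-trips back to the original text
```

Because `add_bos_token` and `add_eos_token` default to `False` here, no special tokens are inserted unless they are requested explicitly.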