├── .gitignore
├── LICENSE
├── MCSD
│   ├── README.md
│   ├── evaluation.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── generate.py
│   │   └── strategies.py
│   └── model
│       ├── __init__.py
│       └── llama_tree_attn
│           ├── __init__.py
│           ├── configuration_llama.py
│           ├── convert_llama_weights_to_hf.py
│           ├── modeling_llama.py
│           ├── tokenization_llama.py
│           └── tokenization_llama_fast.py
├── README.md
└── dataset
    └── wmt_ende.json
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 NJUNLP
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MCSD/README.md:
--------------------------------------------------------------------------------
1 | # Source Code for Multi-Candidate Speculative Decoding
2 |
3 | We provide Python application interfaces for inference, as well as command-line interfaces for evaluation.
4 |
5 | ## Dependencies
6 |
7 | PyTorch version >= 1.11.0
8 |
9 | Python version >= 3.8
10 |
11 | transformers >= 4.34.0
12 |
13 | ## Evaluation CLI
14 | Run the following script for evaluation:
15 | ```
16 | python evaluation.py \
17 | --draft-model PATH_TO_DRAFT_MODEL \
18 | --target-model PATH_TO_TARGET_MODEL \
19 | --fp16 \
20 | --k-config 4,2,2 \
21 | --datapath PATH_TO_DATA \
22 | --sampling-type sampling
23 | ```
24 |
25 | ### Options
26 | ```
27 | -h, --help show this help message and exit
28 | --draft-model Draft model path.
29 | --target-model Target model path.
30 | --tokenizer Tokenizer path. If not provided, use the Target model path.
31 | --fp16 Use float16 dtype.
32 | --k-config Use comma separations, e.g. `--k-config 4,2,2`.
33 | --datapath The json data file.
34 | --max-new-tokens
35 | --replacement Sampling with replacement.
36 | --naive-sampling Use multi-candidate naive sampling.
37 | --disable-tree-attn
38 | --sampling-type {argmax,sampling}
39 | --disable-tqdm
40 | --auto-model Use AutoModelForCausalLM and AutoTokenizer to load the model and tokenizer; this will disable Tree Attn.
41 | ```
42 |
43 | Note:
44 | * Tree Attn is currently not supported for models other than LLaMA. Therefore, when using '--auto-model', Tree Attn will be disabled.
45 | * Since flash-attn does not support custom attention masks, it is currently incompatible with Tree Attn.
46 |
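The `--k-config` value fixes the shape of the candidate tree: the i-th entry is the number of candidates drafted per node at depth i, so the draft length equals the number of entries. As a rough sketch (not part of the original README), the arithmetic below mirrors the `cumprod`/`cumsum` bookkeeping in `inference/strategies.py` and shows how many draft tokens `--k-config 4,2,2` sends to the target model per verification step:

```python
import torch

k_config = (4, 2, 2)  # the value passed as `--k-config 4,2,2`

# Nodes per tree depth: 4, then 4*2 = 8, then 4*2*2 = 16.
prod_size = torch.cumprod(torch.tensor(k_config), dim=0)
print(prod_size.tolist())    # [4, 8, 16]
print(int(prod_size.sum()))  # 28 draft tokens verified per target-model call
print(len(k_config))         # max draft length (tree depth): 3
```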
47 | ## Python application interfaces
48 | Here is an example of inference using our generator; see here for the function of each argument.
49 | ```python
50 | import torch
51 | from model.llama_tree_attn import LlamaForCausalLM, LlamaTokenizer
52 | from inference.generate import SpeculativeGenerator
53 |
54 | draft_model = LlamaForCausalLM.from_pretrained(
55 | "PATH_TO_DRAFT_MODEL",
56 | torch_dtype=torch.float16,
57 | device_map=0,
58 | )
59 | target_model = LlamaForCausalLM.from_pretrained(
60 | "PATH_TO_TARGET_MODEL",
61 | torch_dtype=torch.float16,
62 | device_map="auto",
63 | )
64 | tokenizer = LlamaTokenizer.from_pretrained("PATH_TO_TARGET_MODEL")
65 |
66 | generator = SpeculativeGenerator(
67 | draft_model,
68 | target_model,
69 | eos_token_id=tokenizer.eos_token_id,
70 | k_config=(4, 2, 2),
71 | max_new_tokens=128,
72 | draft_model_temp=1,
73 | target_model_temp=1,
74 | replacement=False,
75 | speculative_sampling=True,
76 | tree_attn=True,
77 | )
78 |
79 | prompt_text = "Hey, are you conscious? Can you talk to me?"
80 | inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
81 | input_ids = inputs.input_ids
82 | with torch.no_grad():
83 | output = generator.generate(input_ids)
84 | output_text = tokenizer.batch_decode(
85 | output.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=False
86 | )[0]
87 |
88 | print("Output:\n{}".format(output_text))
89 |
90 | ```
91 |
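For comparison, the snippet below is a minimal sketch (not in the original README) of plain, non-speculative generation with `BaseGenerator` from `inference/generate.py`; it reuses `target_model`, `tokenizer`, and `input_ids` from the example above.

```python
from inference.generate import BaseGenerator

# Autoregressive baseline: the target model generates one token per forward pass.
baseline_generator = BaseGenerator(
    target_model,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=128,
    temp=1,
)

with torch.no_grad():
    baseline_output = baseline_generator.generate(input_ids)

print(tokenizer.batch_decode(baseline_output.sequences, skip_special_tokens=True)[0])
```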
--------------------------------------------------------------------------------
/MCSD/evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import time
5 | from typing import Literal, Tuple
6 |
7 | import torch
8 | from inference.generate import Generator, BaseGenerator, SpeculativeGenerator
9 | from model.llama_tree_attn import LlamaForCausalLM, LlamaTokenizer
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 | from tqdm import tqdm
12 |
13 | # Setup logging
14 | logging.basicConfig(
15 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
16 | datefmt="%m/%d/%Y %H:%M:%S",
17 | level=logging.INFO,
18 | )
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class JsonData:
24 | def __init__(self, path) -> None:
25 | with open(path) as fin:
26 | self.data = json.load(fin)
27 |
28 | def __getitem__(self, index) -> Tuple[str, str]:
29 | return self.data[index]
30 |
31 | def __len__(self):
32 | return len(self.data)
33 |
34 |
35 | def run_eval(
36 | draft_model,
37 | target_model,
38 | tokenizer,
39 | k_config: Tuple[int],
40 | datapath: str,
41 | max_new_tokens: int = 128,
42 | replacement=False,
43 | speculative_sampling=True,
44 | tree_attn=True,
45 | sampling_type: Literal["argmax", "sampling"] = "sampling",
46 | disable_tqdm: bool = False,
47 | ):
48 | if sampling_type not in ["argmax", "sampling"]:
49 | raise ValueError(
50 | f'`sampling_type` can be either `"argmax"` or `"sampling"`, but received "{sampling_type}"'
51 | )
52 | if sampling_type == "argmax":
53 | target_model_temp = 0
54 | draft_model_temp = 0
55 | else:
56 | target_model_temp = 1
57 | draft_model_temp = 1
58 |
59 | dataloader = JsonData(datapath)
60 | generator = SpeculativeGenerator(
61 | draft_model,
62 | target_model,
63 | eos_token_id=tokenizer.eos_token_id,
64 | k_config=k_config,
65 | max_new_tokens=max_new_tokens,
66 | draft_model_temp=draft_model_temp,
67 | target_model_temp=target_model_temp,
68 | replacement=replacement,
69 | speculative_sampling=speculative_sampling,
70 | tree_attn=tree_attn,
71 | )
72 |
73 | draft_model.eval()
74 | target_model.eval()
75 |
76 | logger.info("evaluation start.")
77 | start_time = time.time()
78 |
79 | acceptance_count = 0
80 | draft_token_count = 0
81 | invocation_count = 0
82 |
83 | iterator = range(len(dataloader))
84 | with torch.no_grad():
85 | for sample_idx in iterator if disable_tqdm else tqdm(iterator):
86 | prompt_text = dataloader[sample_idx]
87 | inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
88 | input_ids = inputs.input_ids
89 | output = generator.generate(input_ids)
90 |
91 | acceptance_count += output.acceptance_count
92 | draft_token_count += output.draft_token_count
93 | invocation_count += output.invocation_count
94 | end_time = time.time()
95 |
96 | logger.info("evaluation complete.")
97 |
98 | run_time = end_time - start_time
99 |
100 | latency = run_time / (acceptance_count + invocation_count)
101 | acceptance_rate = acceptance_count / draft_token_count
102 | block_efficiency = 1 + acceptance_count / invocation_count
103 |
104 | logger.info("Running time: {:.2f} s".format(run_time))
105 | logger.info("Token latency: {:.2f} ms".format(latency * 1000))
106 | logger.info("Acceptance rate: {:.2f}".format(acceptance_rate))
107 | logger.info("Block efficiency: {:.2f}".format(block_efficiency))
108 |
109 |
110 | def run_baseline_eval(
111 | target_model,
112 | tokenizer,
113 | datapath: str,
114 | max_new_tokens: int = 128,
115 | sampling_type: Literal["argmax", "sampling"] = "sampling",
116 | disable_tqdm: bool = False,
117 | ):
118 | if sampling_type not in ["argmax", "sampling"]:
119 | raise ValueError(
120 | f'`sampling_type` can be either `"argmax"` or `"sampling"`, but received "{sampling_type}"'
121 | )
122 | if sampling_type == "argmax":
123 | target_model_temp = 0
124 | else:
125 | target_model_temp = 1
126 |
127 | dataloader = JsonData(datapath)
128 | generator = BaseGenerator(
129 | target_model,
130 | eos_token_id=tokenizer.eos_token_id,
131 | max_new_tokens=max_new_tokens,
132 | temp=target_model_temp,
133 | )
134 |
135 | target_model.eval()
136 |
137 | logger.info("evaluation start.")
138 | start_time = time.time()
139 |
140 | invocation_count = 0
141 |
142 | iterator = range(len(dataloader))
143 | with torch.no_grad():
144 | for sample_idx in iterator if disable_tqdm else tqdm(iterator):
145 | prompt_text = dataloader[sample_idx]
146 | inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
147 | input_ids = inputs.input_ids
148 | output = generator.generate(input_ids)
149 |
150 | invocation_count += output.invocation_count
151 | end_time = time.time()
152 |
153 | logger.info("evaluation complete.")
154 |
155 | run_time = end_time - start_time
156 |
157 | latency = run_time / invocation_count
158 |
159 | logger.info("Running time: {:.2f} s".format(run_time))
160 | logger.info("Token latency: {:.2f} ms".format(latency * 1000))
161 |
162 |
163 | def main(args):
164 | torch_dtype = torch.float16 if args.fp16 else torch.float32
165 |
166 | logger.info("The full evaluation configuration:\n" + repr(args))
167 |
168 | if args.auto_model and not args.disable_tree_attn:
169 | logger.warning(
170 | "Tree Attn is currently not supported for models other than LLaMA. Therefore, "
171 | "when using '--auto-model', Tree Attn will be disabled."
172 | )
173 | args.disable_tree_attn = True
174 |
175 | ModelLoader = AutoModelForCausalLM if args.auto_model else LlamaForCausalLM
176 | TokenizerLoader = AutoTokenizer if args.auto_model else LlamaTokenizer
177 |
178 | logger.info("Loading draft model: {}".format(args.draft_model))
179 | draft_model = ModelLoader.from_pretrained(
180 | args.draft_model,
181 | torch_dtype=torch_dtype,
182 | device_map=0,
183 | use_flash_attention_2=True if args.flash_attn else False,
184 | )
185 |
186 | logger.info("Loading target model: {}".format(args.target_model))
187 | target_model = ModelLoader.from_pretrained(
188 | args.target_model,
189 | torch_dtype=torch_dtype,
190 | device_map="auto",
191 | use_flash_attention_2=True if args.flash_attn else False,
192 | )
193 |
194 | tokenizer = TokenizerLoader.from_pretrained(args.tokenizer)
195 |
196 | if args.run_baseline:
197 | run_baseline_eval(
198 | target_model,
199 | tokenizer=tokenizer,
200 | datapath=args.datapath,
201 | max_new_tokens=args.max_new_tokens,
202 | sampling_type=args.sampling_type,
203 | disable_tqdm=args.disable_tqdm,
204 | )
205 | else:
206 | run_eval(
207 | draft_model,
208 | target_model,
209 | tokenizer=tokenizer,
210 | k_config=args.k_config,
211 | datapath=args.datapath,
212 | max_new_tokens=args.max_new_tokens,
213 | replacement=args.replacement,
214 | speculative_sampling=not args.naive_sampling,
215 | tree_attn=not args.disable_tree_attn,
216 | sampling_type=args.sampling_type,
217 | disable_tqdm=args.disable_tqdm,
218 | )
219 |
220 |
221 | if __name__ == "__main__":
222 | parser = argparse.ArgumentParser()
223 | parser.add_argument(
224 | "--draft-model", type=str, required=True, help="Draft model path."
225 | )
226 | parser.add_argument(
227 | "--target-model", type=str, required=True, help="Target model path."
228 | )
229 | parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path.")
230 | parser.add_argument("--fp16", action="store_true", help="use float16 dtype.")
231 |
232 | parser.add_argument(
233 | "--k-config",
234 | type=lambda x: tuple(map(int, x.split(","))),
235 | required=True,
236 | help="Use comma separations, e.g. `--k-config 4,2,2`.",
237 | )
238 |
239 | parser.add_argument(
240 | "--datapath", type=str, required=True, help="The json data file."
241 | )
242 | parser.add_argument("--max-new-tokens", type=int, default=128)
243 | parser.add_argument(
244 | "--replacement",
245 | action="store_true",
246 | help="Sampling with replacement.",
247 | )
248 | parser.add_argument(
249 | "--naive-sampling",
250 | action="store_true",
251 | help="Use multi-candidate naive sampling.",
252 | )
253 |
254 | parser.add_argument("--disable-tree-attn", action="store_true")
255 |
256 | parser.add_argument(
257 | "--sampling-type", type=str, default="sampling", choices=["argmax", "sampling"]
258 | )
259 |
260 | parser.add_argument("--disable-tqdm", action="store_true")
261 |
262 | parser.add_argument("--auto-model", action="store_true")
263 | parser.add_argument("--run-baseline", action="store_true")
264 |
265 | parser.add_argument("--flash-attn", action="store_true")
266 |
267 | args = parser.parse_args()
268 |
269 | if args.tokenizer is None:
270 | args.tokenizer = args.target_model
271 | main(args)
272 |
--------------------------------------------------------------------------------
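To make the metrics logged by `run_eval` above concrete, here is a small worked example with made-up counts; the numbers are illustrative only, not measured results.

```python
# Hypothetical counts, for illustration only.
acceptance_count = 180    # draft tokens accepted by the target model
draft_token_count = 300   # draft tokens proposed in total
invocation_count = 100    # target-model forward passes (verification steps)
run_time = 12.0           # seconds

# Each verification step emits the accepted tokens plus one token sampled from the
# target model, so acceptance_count + invocation_count is the total generated length.
latency = run_time / (acceptance_count + invocation_count)  # 12.0 / 280 ≈ 0.0429 s/token
acceptance_rate = acceptance_count / draft_token_count      # 180 / 300 = 0.60
block_efficiency = 1 + acceptance_count / invocation_count  # 1 + 1.8 = 2.8 tokens per call
print(f"{latency * 1000:.2f} ms, {acceptance_rate:.2f}, {block_efficiency:.2f}")
```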
/MCSD/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NJUNLP/MCSD/8aadd6501a9e987ba5fca6cc8f9ad5949e480ec7/MCSD/inference/__init__.py
--------------------------------------------------------------------------------
/MCSD/inference/generate.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Tuple, Union
3 |
4 | import torch
5 | from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast
6 |
7 | from . import strategies
8 |
9 |
10 | @dataclass
11 | class DecoderOnlyOutput(ModelOutput):
12 | """
13 | Base class for outputs of decoder-only generation models using MCSD.
14 | """
15 |
16 | sequences: torch.LongTensor
17 | acceptance_count: int = None
18 | draft_token_count: int = None
19 | invocation_count: int = None
20 |
21 |
22 | class Generator:
23 | def __init__(self) -> None:
24 | pass
25 |
26 | def generate(
27 | self,
28 | input_ids: Optional[torch.Tensor] = None,
29 | ) -> DecoderOnlyOutput:
30 | raise NotImplementedError
31 |
32 |
33 | class BaseGenerator:
34 | def __init__(
35 | self,
36 | model,
37 | eos_token_id: int,
38 | max_new_tokens: int = 128,
39 | temp: float = 1,
40 | ) -> None:
41 | self.model = model
42 | self.eos_token_id = eos_token_id
43 | self.max_new_tokens = max_new_tokens
44 | self.temp = temp
45 |
46 | def generate(
47 | self,
48 | input_ids: Optional[torch.Tensor] = None,
49 | ) -> DecoderOnlyOutput:
50 | past_key_values = None
51 | invocation_count = 0
52 |
53 | init_input_len = input_ids.size(-1)
54 |
55 | while True:
56 | if past_key_values is not None:
57 | pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :]
58 | else:
59 | pruned_input_ids = input_ids
60 |
61 | outputs: CausalLMOutputWithPast = self.model(
62 | input_ids=pruned_input_ids,
63 | use_cache=True,
64 | past_key_values=past_key_values,
65 | return_dict=True,
66 | output_attentions=False,
67 | output_hidden_states=False,
68 | )
69 |
70 | logits = outputs.logits
71 | past_key_values = outputs.past_key_values
72 |
73 | batch_num, seq_len, _ = logits.size()
74 |
75 | if self.temp == 0:
76 | _, ground_tokens = logits.topk(k=1, dim=-1) # batch x seq_len x 1
77 | else:
78 | ground_probs = torch.softmax(
79 | logits / self.temp, dim=-1
80 | ) # batch x seq_len x hidden_dim
81 |
82 | ground_tokens = torch.multinomial(
83 | ground_probs.view(batch_num * seq_len, -1), num_samples=1
84 | ) # batch*seq_len x 1
85 | ground_tokens = ground_tokens.view(batch_num, seq_len)
86 |
87 | input_ids = torch.cat(
88 | (input_ids, ground_tokens[:, -1:].to(input_ids)), dim=1
89 | )
90 |
91 | invocation_count += 1
92 |
93 | if (
94 | self.eos_token_id in input_ids[0, -1:]
95 | or input_ids.size(-1) - init_input_len >= self.max_new_tokens
96 | ):
97 | break
98 | return DecoderOnlyOutput(sequences=input_ids, invocation_count=invocation_count)
99 |
100 |
101 | class SpeculativeGenerator:
102 | def __init__(
103 | self,
104 | draft_model,
105 | target_model,
106 | eos_token_id: int,
107 | k_config: Tuple[int],
108 | max_new_tokens: int = 128,
109 | draft_model_temp: float = 1,
110 | target_model_temp: float = 1,
111 | replacement: bool = False,
112 | speculative_sampling: bool = True,
113 | tree_attn: bool = True,
114 | ) -> None:
115 | self.eos_token_id = eos_token_id
116 | self.max_new_tokens = max_new_tokens
117 | self.strategy: strategies.Strategy = None
118 |
119 | if tree_attn:
120 | self.strategy = strategies.TreeStrategy(
121 | draft_model=draft_model,
122 | target_model=target_model,
123 | k_config=k_config,
124 | draft_model_temp=draft_model_temp,
125 | target_model_temp=target_model_temp,
126 | replacement=replacement,
127 | speculative_sampling=speculative_sampling,
128 | )
129 | else:
130 | self.strategy = strategies.BatchStrategy(
131 | draft_model=draft_model,
132 | target_model=target_model,
133 | k_config=k_config,
134 | draft_model_temp=draft_model_temp,
135 | target_model_temp=target_model_temp,
136 | replacement=replacement,
137 | speculative_sampling=speculative_sampling,
138 | )
139 |
140 | def generate(
141 | self,
142 | input_ids: Optional[torch.Tensor] = None,
143 | ) -> DecoderOnlyOutput:
144 | target_model_past_key_values = None
145 | draft_model_past_key_values = None
146 |
147 | invocation_count = 0
148 | acceptance_count = 0
149 |
150 | init_input_len = input_ids.size(-1)
151 |
152 | while True:
153 | draft_output = self.strategy.generate_draft(
154 | input_ids,
155 | past_key_values=draft_model_past_key_values,
156 | )
157 |
158 | draft_model_past_key_values = draft_output.past_key_values
159 |
160 | verification_output = self.strategy.verify(
161 | input_ids=draft_output.sequences,
162 | target_model_past_key_values=target_model_past_key_values,
163 | draft_model_past_key_values=draft_output.past_key_values,
164 | cand_probs=draft_output.cand_probs,
165 | )
166 |
167 | input_ids = verification_output.sequences
168 |
169 | draft_model_past_key_values = (
170 | verification_output.draft_model_past_key_values
171 | )
172 | target_model_past_key_values = (
173 | verification_output.target_model_past_key_values
174 | )
175 |
176 | invocation_count += 1
177 | acceptance_count += verification_output.acceptance_count
178 |
179 | if (
180 | self.eos_token_id in input_ids[0, -self.strategy.max_draft_len :]
181 | or input_ids.size(-1) - init_input_len >= self.max_new_tokens
182 | ):
183 | break
184 | return DecoderOnlyOutput(
185 | sequences=input_ids,
186 | acceptance_count=acceptance_count,
187 | draft_token_count=invocation_count * self.strategy.max_draft_len,
188 | invocation_count=invocation_count,
189 | )
190 |
--------------------------------------------------------------------------------
/MCSD/inference/strategies.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Literal, Optional, Tuple, Union
4 |
5 | import torch
6 | from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
7 |
8 |
9 | @dataclass
10 | class DecoderOnlyDraftOutput(ModelOutput):
11 | """
12 | Base class for draft outputs of decoder-only generation models using speculative decoding.
13 | """
14 |
15 | sequences: torch.LongTensor = None
16 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
17 | cand_probs: Optional[Tuple[torch.FloatTensor]] = None
18 |
19 |
20 | @dataclass
21 | class DecoderOnlyVerificationOutput(ModelOutput):
22 | """
23 | Base class for verification outputs of decoder-only generation models using speculative decoding.
24 | """
25 |
26 | sequences: torch.LongTensor = None
27 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
28 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
29 | acceptance_count: Optional[int] = None
30 |
31 |
32 | def _MCNS(
33 | ground_probs: torch.FloatTensor,
34 | cand_probs: Tuple[torch.FloatTensor],
35 | cand_tokens: torch.LongTensor,
36 | ) -> Optional[int]:
37 | ground_token = torch.multinomial(ground_probs, num_samples=1).item()
38 |
39 | for check_idx, cand_token in enumerate(cand_tokens):
40 | if ground_token == cand_token:
41 | return check_idx
42 | ground_probs[:] = 0
43 | ground_probs[ground_token] = 1
44 | return None
45 |
46 |
47 | def _MCSSwoReplacement(
48 | ground_probs: torch.FloatTensor,
49 | cand_probs: Tuple[torch.FloatTensor],
50 | cand_tokens: torch.LongTensor,
51 | ) -> Optional[int]:
52 | cand_probs = cand_probs.to(ground_probs.device)
53 | for check_idx, cand_token in enumerate(cand_tokens):
54 | accept_threshold = ground_probs[cand_token] / cand_probs[cand_token]
55 | if torch.rand(1, device=accept_threshold.device) <= accept_threshold:
56 | return check_idx
57 | else:
58 | ground_probs -= cand_probs
59 | ground_probs = torch.nn.functional.relu(ground_probs, inplace=True)
60 | ground_probs /= ground_probs.sum()
61 | cand_probs[cand_token] = 0
62 | cand_probs = cand_probs / cand_probs.sum()
63 | return None
64 |
65 |
66 | def _MCSSwReplacement(
67 | ground_probs: torch.FloatTensor,
68 | cand_probs: Tuple[torch.FloatTensor],
69 | cand_tokens: torch.LongTensor,
70 | ) -> Optional[int]:
71 | cand_probs = cand_probs.to(ground_probs.device)
72 | for check_idx, cand_token in enumerate(cand_tokens):
73 | accept_threshold = ground_probs[cand_token] / cand_probs[cand_token]
74 | if torch.rand(1, device=accept_threshold.device) <= accept_threshold:
75 | return check_idx
76 | else:
77 | ground_probs -= cand_probs
78 | ground_probs = torch.nn.functional.relu(ground_probs, inplace=True)
79 | ground_probs /= ground_probs.sum()
80 | return None
81 |
82 |
83 | class Strategy:
84 | def __init__(
85 | self,
86 | draft_model,
87 | target_model,
88 | k_config: Tuple[int],
89 | draft_model_temp: float = 1,
90 | target_model_temp: float = 1,
91 | replacement: bool = False,
92 | speculative_sampling: bool = True,
93 | ) -> None:
94 | self.k_config = k_config
95 | self.draft_model = draft_model
96 | self.target_model = target_model
97 | self.draft_model_device = draft_model.model.get_input_embeddings().weight.device
98 | self.target_model_device = (
99 | target_model.model.get_input_embeddings().weight.device
100 | )
101 | self.max_draft_len = len(k_config)
102 | self.draft_model_temp = draft_model_temp
103 | self.target_model_temp = target_model_temp
104 | self.replacement = replacement
105 | self.speculative_sampling = speculative_sampling
106 |
107 | self.acceptance_check: Callable[
108 | [torch.FloatTensor, Tuple[torch.FloatTensor], torch.LongTensor],
109 | Optional[int],
110 | ] = None
111 | if speculative_sampling:
112 | if replacement:
113 | self.acceptance_check = _MCSSwReplacement
114 | if draft_model_temp == 0:
115 | warnings.warn(
116 | (
117 | "You have set Temp=0 and are using sampling with replacement. "
118 | "As a result, all the candidates obtained are the same, causing "
119 | "the MCSD algorithm to degenerate into the vanilla SD."
120 | ),
121 | category=UserWarning,
122 | stacklevel=3,
123 | )
124 | else:
125 | self.acceptance_check = _MCSSwoReplacement
126 | else:
127 | if replacement:
128 | warnings.warn(
129 | (
130 | "`replacement` is not applicable when `speculative_sampling` is False."
131 | "The acceptance check algorithm defaults to MCNS (Multi-Candidate Naive Sampling)"
132 | " when `speculative_sampling=False`."
133 | ),
134 | category=UserWarning,
135 | stacklevel=3,
136 | )
137 | self.acceptance_check = _MCNS
138 |
139 | def generate_draft(
140 | self,
141 | input_ids: torch.LongTensor,
142 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
143 | ) -> DecoderOnlyDraftOutput:
144 | raise NotImplementedError
145 |
146 | def acceptance_check(self, ground_probs, cand_probs, cand_tokens) -> Optional[int]:
147 | raise NotImplementedError
148 |
149 | def verify(
150 | self,
151 | input_ids: torch.LongTensor,
152 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
153 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
154 | cand_probs: Optional[Tuple[torch.FloatTensor]],
155 | ) -> DecoderOnlyVerificationOutput:
156 | raise NotImplementedError
157 |
158 |
159 | class BatchStrategy(Strategy):
160 | def __init__(
161 | self,
162 | draft_model,
163 | target_model,
164 | k_config: Tuple[int],
165 | draft_model_temp=1,
166 | target_model_temp=1,
167 | replacement: bool = False,
168 | speculative_sampling: bool = True,
169 | ) -> None:
170 | super().__init__(
171 | draft_model,
172 | target_model,
173 | k_config,
174 | draft_model_temp,
175 | target_model_temp,
176 | replacement,
177 | speculative_sampling,
178 | )
179 |
180 | reversed_prod_size = [1]
181 | for i in range(1, self.max_draft_len):
182 | reversed_prod_size.insert(0, reversed_prod_size[0] * k_config[-i])
183 |
184 | self.reversed_prod_size = reversed_prod_size
185 |
186 | def generate_draft(
187 | self,
188 | input_ids: torch.LongTensor,
189 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
190 | ) -> DecoderOnlyDraftOutput:
191 | input_ids = input_ids.to(self.draft_model_device)
192 | cand_probs = []
193 | for step in range(self.max_draft_len):
194 | step_k = self.k_config[step]
195 | if past_key_values is not None:
196 | pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :]
197 | else:
198 | pruned_input_ids = input_ids
199 | outputs: BaseModelOutputWithPast = self.draft_model.model(
200 | input_ids=pruned_input_ids,
201 | use_cache=True,
202 | past_key_values=past_key_values,
203 | return_dict=True,
204 | output_attentions=False,
205 | output_hidden_states=False,
206 | )
207 |
208 | hidden_states = outputs.last_hidden_state
209 |
210 | logits = self.draft_model.lm_head(hidden_states[:, -1])
211 |
212 | past_key_values = list(outputs.past_key_values)
213 |
214 | if self.draft_model_temp == 0:
215 | if not self.replacement:
216 | topk_logit, topk_index = logits.topk(k=step_k, dim=-1) # batch x k
217 | topk_probs = torch.softmax(topk_logit, dim=-1)
218 | step_cand_probs = torch.zeros_like(logits)
219 | step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs)
220 | cand_tokens = topk_index.view(-1, 1)
221 | else:
222 | topk_logit, topk_index = logits.topk(k=1, dim=-1) # batch x k
223 | step_cand_probs = torch.zeros_like(logits)
224 | step_cand_probs.scatter_(dim=1, index=topk_index, value=1)
225 | cand_tokens = topk_index.view(-1, 1)
226 | cand_tokens = torch.repeat_interleave(cand_tokens, step_k, dim=0)
227 | else:
228 | step_cand_probs = torch.softmax(logits / self.draft_model_temp, dim=-1)
229 | cand_tokens = torch.multinomial(
230 | step_cand_probs,
231 | step_k,
232 | replacement=self.replacement,
233 | ).view(-1, 1)
234 |
235 | cand_probs.append(step_cand_probs)
236 |
237 | input_ids = input_ids.repeat_interleave(step_k, dim=0)
238 | input_ids = torch.cat(
239 | (
240 | input_ids,
241 | cand_tokens,
242 | ),
243 | dim=1,
244 | )
245 | if step + 1 != self.max_draft_len:
246 | for i in range(len(past_key_values)):
247 | past_key_values[i] = (
248 | past_key_values[i][0].repeat_interleave(step_k, dim=0),
249 | past_key_values[i][1].repeat_interleave(step_k, dim=0),
250 | )
251 |
252 | return DecoderOnlyDraftOutput(
253 | sequences=input_ids,
254 | past_key_values=past_key_values,
255 | cand_probs=tuple(cand_probs),
256 | )
257 |
258 | def verify(
259 | self,
260 | input_ids: torch.LongTensor,
261 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
262 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
263 | cand_probs: Optional[Tuple[torch.FloatTensor]],
264 | ) -> DecoderOnlyVerificationOutput:
265 | input_ids = input_ids.to(self.target_model_device)
266 | batch_size, input_len = input_ids.size()
267 | if target_model_past_key_values is not None:
268 | pruned_input_ids = input_ids[
269 | :, target_model_past_key_values[0][0].size(2) :
270 | ]
271 | for i in range(len(target_model_past_key_values)):
272 | target_model_past_key_values[i] = (
273 | target_model_past_key_values[i][0].repeat_interleave(
274 | batch_size, dim=0
275 | ),
276 | target_model_past_key_values[i][1].repeat_interleave(
277 | batch_size, dim=0
278 | ),
279 | )
280 | else:
281 | pruned_input_ids = input_ids
282 |
283 | outputs: BaseModelOutputWithPast = self.target_model.model(
284 | input_ids=pruned_input_ids,
285 | use_cache=True,
286 | past_key_values=target_model_past_key_values,
287 | return_dict=True,
288 | output_attentions=False,
289 | output_hidden_states=False,
290 | )
291 | hidden_states = outputs.last_hidden_state
292 | target_model_past_key_values = list(outputs.past_key_values)
293 |
294 | logits = self.target_model.lm_head(hidden_states[:, -self.max_draft_len - 1 :])
295 |
296 | if self.target_model_temp == 0:
297 | _, topk_index = logits.topk(k=1, dim=-1) # seq_len x 1
298 | ground_probs = torch.zeros_like(logits)
299 | ground_probs.scatter_(dim=2, index=topk_index, value=1)
300 | else:
301 | ground_probs = torch.softmax(logits / self.target_model_temp, dim=-1)
302 |
303 | unverified_input_ids = input_ids[:, -self.max_draft_len :]
304 |
305 | assert ground_probs.size(1) == unverified_input_ids.size(1) + 1
306 |
307 | cand_probs_idx = 0
308 | alive_group_id = 0
309 |
310 | for depth in range(self.max_draft_len):
311 | verify_batch_ids = [
312 | alive_group_id + group_offset * self.reversed_prod_size[depth]
313 | for group_offset in range(self.k_config[depth])
314 | ]
315 | accept_idx_bias = self.acceptance_check(
316 | ground_probs[alive_group_id, depth],
317 | cand_probs[depth][cand_probs_idx],
318 | unverified_input_ids[verify_batch_ids, depth],
319 | )
320 | if accept_idx_bias is not None:
321 | alive_group_id = verify_batch_ids[accept_idx_bias]
322 | cand_probs_idx = accept_idx_bias + cand_probs_idx * self.k_config[depth]
323 | if depth == self.max_draft_len - 1:
324 | depth = self.max_draft_len
325 | else:
326 | break
327 | input_ids = input_ids[alive_group_id, : input_len - self.max_draft_len + depth]
328 | endpoint_token = torch.multinomial(
329 | ground_probs[alive_group_id, depth], num_samples=1
330 | ).to(device=input_ids.device)
331 |
332 | input_ids = torch.cat((input_ids, endpoint_token))
333 |
334 | input_ids.unsqueeze_(0)
335 |
336 | for i in range(len(target_model_past_key_values)):
337 | target_model_past_key_values[i] = (
338 | target_model_past_key_values[i][0][
339 | None, alive_group_id, :, : input_len - self.max_draft_len + depth
340 | ],
341 | target_model_past_key_values[i][1][
342 | None, alive_group_id, :, : input_len - self.max_draft_len + depth
343 | ],
344 | )
345 | for i in range(len(draft_model_past_key_values)):
346 | draft_model_past_key_values[i] = (
347 | draft_model_past_key_values[i][0][
348 | None,
349 | alive_group_id // self.k_config[-1],
350 | :,
351 | : input_len - self.max_draft_len + depth,
352 | ],
353 | draft_model_past_key_values[i][1][
354 | None,
355 | alive_group_id // self.k_config[-1],
356 | :,
357 | : input_len - self.max_draft_len + depth,
358 | ],
359 | )
360 | return DecoderOnlyVerificationOutput(
361 | sequences=input_ids,
362 | target_model_past_key_values=target_model_past_key_values,
363 | draft_model_past_key_values=draft_model_past_key_values,
364 | acceptance_count=depth,
365 | )
366 |
367 |
368 | def get_tree_attn_self_mask(k_config: Tuple[int]):
369 | k_config = torch.tensor(k_config, dtype=torch.int)
370 | prod_size = torch.cumprod(k_config, dim=0)
371 | mask_size = prod_size.sum().item()
372 | attn_mask = torch.zeros((mask_size, mask_size), dtype=torch.bool)
373 | attn_mask = attn_mask.diagonal_scatter(torch.ones(mask_size))
374 | # run BFS
375 | idx_queue = [
376 | (0, None, idx) for idx in list(range(k_config[0]))
377 | ] # each node: (depth, parent, idx)
378 | while len(idx_queue) != 0:
379 | depth, parent, idx = idx_queue.pop(0)
380 | if parent is not None:
381 | attn_mask[idx, : parent + 1] = attn_mask[parent, : parent + 1]
382 |
383 | if depth != len(k_config) - 1:
384 | idx_base = prod_size[:depth].sum().item()
385 | child_idx_base = prod_size[: depth + 1].sum().item()
386 | for child_idx_bias in range(k_config[depth + 1]):
387 | real_child_idx = (
388 | (idx - idx_base) * k_config[depth + 1]
389 | + child_idx_base
390 | + child_idx_bias
391 | )
392 | idx_queue.append((depth + 1, idx, real_child_idx))
393 | return attn_mask
394 |
395 |
396 | class TreeStrategy(Strategy):
397 | def __init__(
398 | self,
399 | draft_model,
400 | target_model,
401 | k_config: Tuple[int],
402 | draft_model_temp: float = 1,
403 | target_model_temp: float = 1,
404 | replacement: bool = False,
405 | speculative_sampling: bool = True,
406 | ) -> None:
407 | super().__init__(
408 | draft_model,
409 | target_model,
410 | k_config,
411 | draft_model_temp,
412 | target_model_temp,
413 | replacement,
414 | speculative_sampling,
415 | )
416 |
417 | prod_size = torch.cumprod(torch.tensor(k_config, dtype=torch.int), dim=0)
418 | prod_size = torch.cat((torch.zeros(1).to(prod_size), prod_size)).tolist()
419 | self.prod_size = prod_size
420 | self.cumulative_prod_size = torch.cumsum(
421 | torch.tensor(prod_size), dim=0
422 | ).tolist()
423 |
424 | self.tree_attn_self_mask = get_tree_attn_self_mask(k_config).to(
425 | device=self.draft_model_device
426 | )
427 |
428 | def generate_draft(
429 | self,
430 | input_ids: torch.LongTensor,
431 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
432 | ) -> DecoderOnlyDraftOutput:
433 | input_ids = input_ids.to(self.draft_model_device)
434 | cand_probs = []
435 | step_tree_attn_mask = None
436 | position_ids = None
437 | init_input_length = input_ids.size(1)
438 | if past_key_values is not None:
439 | pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :]
440 | else:
441 | pruned_input_ids = input_ids
442 | for step in range(self.max_draft_len):
443 | step_k = self.k_config[step]
444 |
445 | # prepare attn mask
446 | if step != 0:
447 | step_tree_attn_self_mask = self.tree_attn_self_mask[
448 | self.cumulative_prod_size[step - 1] : self.cumulative_prod_size[
449 | step
450 | ],
451 | : self.cumulative_prod_size[step],
452 | ]
453 | position_ids = torch.full(
454 | (1, self.prod_size[step]),
455 | init_input_length + step - 1,
456 | dtype=torch.long,
457 | device=self.draft_model_device,
458 | )
459 | context_attn_mask = torch.ones(
460 | (self.prod_size[step], init_input_length), dtype=torch.bool
461 | ).to(self.tree_attn_self_mask)
462 | step_tree_attn_mask = torch.cat(
463 | (context_attn_mask, step_tree_attn_self_mask), dim=1
464 | )
465 |
466 | outputs: BaseModelOutputWithPast = self.draft_model.model(
467 | input_ids=pruned_input_ids,
468 | use_cache=True,
469 | past_key_values=past_key_values,
470 | return_dict=True,
471 | output_attentions=False,
472 | output_hidden_states=False,
473 | tree_attn_mask=step_tree_attn_mask,
474 | position_ids=position_ids,
475 | )
476 |
477 | hidden_states = outputs.last_hidden_state
478 |
479 | if step == 0:
480 | hidden_states = hidden_states[0, -1:]
481 | else:
482 | hidden_states = hidden_states[0]
483 | logits = self.draft_model.lm_head(hidden_states) # seq_len x hidden_dim
484 |
485 | past_key_values = list(outputs.past_key_values)
486 |
487 | if self.draft_model_temp == 0:
488 | if not self.replacement:
489 | topk_logit, topk_index = logits.topk(
490 | k=step_k, dim=-1
491 | ) # seq_len x k
492 | topk_probs = torch.softmax(topk_logit, dim=-1)
493 | step_cand_probs = torch.zeros_like(logits)
494 | step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs)
495 | cand_tokens = topk_index.view(1, -1)
496 | else:
497 | topk_logit, topk_index = logits.topk(k=1, dim=-1) # seq_len x k
498 | step_cand_probs = torch.zeros_like(logits)
499 | step_cand_probs.scatter_(dim=1, index=topk_index, value=1)
500 | cand_tokens = topk_index.view(1, -1)
501 | cand_tokens = torch.repeat_interleave(cand_tokens, step_k, dim=1)
502 | else:
503 | step_cand_probs = torch.softmax(logits / self.draft_model_temp, dim=-1)
504 | cand_tokens = torch.multinomial(
505 | step_cand_probs, step_k, replacement=self.replacement
506 | ).view(1, -1)
507 | cand_probs.append(step_cand_probs)
508 |
509 | pruned_input_ids = cand_tokens
510 |
511 | input_ids = torch.cat((input_ids, pruned_input_ids), dim=1)
512 |
513 | return DecoderOnlyDraftOutput(
514 | sequences=input_ids,
515 | past_key_values=past_key_values,
516 | cand_probs=tuple(cand_probs),
517 | )
518 |
519 | def _forward_target_model(
520 | self,
521 | input_ids: torch.LongTensor,
522 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
523 | ):
524 | input_ids = input_ids.to(self.target_model_device)
525 | tree_attn_len = self.tree_attn_self_mask.size(0)
526 | init_input_length = input_ids.size(1) - tree_attn_len
527 | init_forward = False
528 |
529 | if past_key_values is not None:
530 | pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :]
531 | else:
532 | pruned_input_ids = input_ids
533 | init_forward = True
534 |
535 | if init_forward:
536 | tree_attn_mask = torch.zeros(
537 | (input_ids.size(1), input_ids.size(1)),
538 | dtype=torch.bool,
539 | device=self.target_model_device,
540 | )
541 | mask_cond = torch.arange(
542 | tree_attn_mask.size(-1), device=self.target_model_device
543 | )
544 | tree_attn_mask.masked_fill_(
545 | mask_cond < (mask_cond + 1).view(tree_attn_mask.size(-1), 1), 1
546 | )
547 | tree_attn_mask[-tree_attn_len:, -tree_attn_len:] = self.tree_attn_self_mask
548 | position_ids = tree_attn_mask.sum(dim=1) - 1
549 |
550 | else:
551 | tree_attn_mask = torch.ones(
552 | (
553 | tree_attn_len + 1,
554 | input_ids.size(1),
555 | ), # there is one token not stored in the kv values
556 | dtype=torch.bool,
557 | device=self.target_model_device,
558 | )
559 |
560 | tree_attn_mask[1:, init_input_length:] = self.tree_attn_self_mask
561 | tree_attn_mask[0, init_input_length:] = 0
562 | position_ids = tree_attn_mask.sum(dim=1) - 1
563 |
564 | outputs: BaseModelOutputWithPast = self.target_model.model(
565 | input_ids=pruned_input_ids,
566 | use_cache=True,
567 | past_key_values=past_key_values,
568 | return_dict=True,
569 | output_attentions=False,
570 | output_hidden_states=False,
571 | tree_attn_mask=tree_attn_mask,
572 | position_ids=position_ids,
573 | )
574 | hidden_states = outputs.last_hidden_state
575 | past_key_values = list(outputs.past_key_values)
576 |
577 | logits = self.target_model.lm_head(
578 | hidden_states[:, -tree_attn_len - 1 :]
579 | ) # 1 x seq_len x hidden_dim
580 | return logits, past_key_values
581 |
582 | def verify(
583 | self,
584 | input_ids: torch.LongTensor,
585 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
586 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]],
587 | cand_probs: Optional[Tuple[torch.FloatTensor]],
588 | ) -> DecoderOnlyVerificationOutput:
589 | input_ids = input_ids.to(self.target_model_device)
590 | logits, target_model_past_key_values = self._forward_target_model(
591 | input_ids, target_model_past_key_values
592 | )
593 | logits = logits[0] # seq_len x hidden_dim
594 | tree_attn_len = self.tree_attn_self_mask.size(0)
595 | unverified_tokens = input_ids[0, -tree_attn_len:]
596 | init_input_length = input_ids.size(1) - tree_attn_len
597 |
598 | if self.target_model_temp == 0:
599 | _, topk_index = logits.topk(k=1, dim=-1) # seq_len x 1
600 | ground_probs = torch.zeros_like(logits)
601 | ground_probs.scatter_(dim=1, index=topk_index, value=1)
602 | else:
603 | ground_probs = torch.softmax(logits / self.target_model_temp, dim=-1)
604 | current_ground_prob = ground_probs[0]
605 | ground_probs = ground_probs[1:]
606 |
607 | keep_indices = list(range(init_input_length))
608 | to_drop_len = 0
609 | idx_group_bias = 0
610 | cand_probs_idx = 0
611 |
612 | for depth in range(self.max_draft_len):
613 | idx_base = self.cumulative_prod_size[depth] + idx_group_bias
614 | accept_idx_bias = self.acceptance_check(
615 | current_ground_prob,
616 | cand_probs[depth][cand_probs_idx],
617 | unverified_tokens[idx_base : idx_base + self.k_config[depth]],
618 | )
619 | if accept_idx_bias is not None:
620 | global_idx = idx_base + accept_idx_bias
621 | current_ground_prob = ground_probs[global_idx]
622 | keep_indices.append(init_input_length + global_idx)
623 | if depth == self.max_draft_len - 1:
624 | to_drop_len += 1
625 | depth = self.max_draft_len
626 | else:
627 | cand_probs_idx = idx_group_bias + accept_idx_bias
628 | idx_group_bias = cand_probs_idx * self.k_config[depth + 1]
629 | else:
630 | break
631 |
632 | keep_indices = torch.tensor(
633 | keep_indices, dtype=torch.long, device=self.target_model_device
634 | )
635 | if to_drop_len != 0:
636 | draft_keep_indices = keep_indices[: len(keep_indices) - to_drop_len]
637 | else:
638 | draft_keep_indices = keep_indices
639 |
640 | tail_ground_token = torch.multinomial(current_ground_prob, num_samples=1).to(
641 | device=input_ids.device
642 | )
643 |
644 | input_ids = input_ids.index_select(dim=1, index=keep_indices)
645 | input_ids = torch.cat((input_ids, tail_ground_token[None]), dim=1)
646 |
647 | for i in range(len(target_model_past_key_values)):
648 | keep_indices = keep_indices.to(
649 | device=target_model_past_key_values[i][0].device
650 | )
651 | target_model_past_key_values[i] = (
652 | target_model_past_key_values[i][0].index_select(
653 | dim=2, index=keep_indices
654 | ),
655 | target_model_past_key_values[i][1].index_select(
656 | dim=2, index=keep_indices
657 | ),
658 | )
659 | for i in range(len(draft_model_past_key_values)):
660 | draft_model_past_key_values[i] = (
661 | draft_model_past_key_values[i][0].index_select(
662 | dim=2, index=draft_keep_indices
663 | ),
664 | draft_model_past_key_values[i][1].index_select(
665 | dim=2, index=draft_keep_indices
666 | ),
667 | )
668 |
669 | return DecoderOnlyVerificationOutput(
670 | sequences=input_ids,
671 | target_model_past_key_values=target_model_past_key_values,
672 | draft_model_past_key_values=draft_model_past_key_values,
673 | acceptance_count=depth,
674 | )
675 |
--------------------------------------------------------------------------------
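As a quick illustration (not part of the original file) of what `get_tree_attn_self_mask` returns, here is the mask for a tiny `k_config = (2, 2)` tree: each of the 2 + 4 = 6 draft positions may attend only to itself and its ancestors in the candidate tree.

```python
# Assumes this is run from the MCSD directory so that `inference` is importable.
from inference.strategies import get_tree_attn_self_mask

mask = get_tree_attn_self_mask((2, 2)).int()
print(mask)
# tensor([[1, 0, 0, 0, 0, 0],   # depth-1 candidate 0
#         [0, 1, 0, 0, 0, 0],   # depth-1 candidate 1
#         [1, 0, 1, 0, 0, 0],   # children of candidate 0
#         [1, 0, 0, 1, 0, 0],
#         [0, 1, 0, 0, 1, 0],   # children of candidate 1
#         [0, 1, 0, 0, 0, 1]], dtype=torch.int32)
```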
/MCSD/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NJUNLP/MCSD/8aadd6501a9e987ba5fca6cc8f9ad5949e480ec7/MCSD/model/__init__.py
--------------------------------------------------------------------------------
/MCSD/model/llama_tree_attn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import (
17 | OptionalDependencyNotAvailable,
18 | _LazyModule,
19 | is_sentencepiece_available,
20 | is_tokenizers_available,
21 | is_torch_available,
22 | )
23 |
24 |
25 | _import_structure = {
26 | "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"],
27 | }
28 |
29 | try:
30 | if not is_sentencepiece_available():
31 | raise OptionalDependencyNotAvailable()
32 | except OptionalDependencyNotAvailable:
33 | pass
34 | else:
35 | _import_structure["tokenization_llama"] = ["LlamaTokenizer"]
36 |
37 | try:
38 | if not is_tokenizers_available():
39 | raise OptionalDependencyNotAvailable()
40 | except OptionalDependencyNotAvailable:
41 | pass
42 | else:
43 | _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
44 |
45 | try:
46 | if not is_torch_available():
47 | raise OptionalDependencyNotAvailable()
48 | except OptionalDependencyNotAvailable:
49 | pass
50 | else:
51 | _import_structure["modeling_llama"] = [
52 | "LlamaForCausalLM",
53 | "LlamaModel",
54 | "LlamaPreTrainedModel",
55 | "LlamaForSequenceClassification",
56 | ]
57 |
58 |
59 | if TYPE_CHECKING:
60 | from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
61 |
62 | try:
63 | if not is_sentencepiece_available():
64 | raise OptionalDependencyNotAvailable()
65 | except OptionalDependencyNotAvailable:
66 | pass
67 | else:
68 | from .tokenization_llama import LlamaTokenizer
69 |
70 | try:
71 | if not is_tokenizers_available():
72 | raise OptionalDependencyNotAvailable()
73 | except OptionalDependencyNotAvailable:
74 | pass
75 | else:
76 | from .tokenization_llama_fast import LlamaTokenizerFast
77 |
78 | try:
79 | if not is_torch_available():
80 | raise OptionalDependencyNotAvailable()
81 | except OptionalDependencyNotAvailable:
82 | pass
83 | else:
84 | from .modeling_llama import (
85 | LlamaForCausalLM,
86 | LlamaForSequenceClassification,
87 | LlamaModel,
88 | LlamaPreTrainedModel,
89 | )
90 |
91 |
92 | else:
93 | import sys
94 |
95 | sys.modules[__name__] = _LazyModule(
96 | __name__, globals()["__file__"], _import_structure, module_spec=__spec__
97 | )
98 |
--------------------------------------------------------------------------------
/MCSD/model/llama_tree_attn/configuration_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ LLaMA model configuration"""
21 |
22 | from transformers.configuration_utils import PretrainedConfig
23 | from transformers.utils import logging
24 |
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29 |
30 |
31 | class LlamaConfig(PretrainedConfig):
32 | r"""
33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35 | defaults will yield a similar configuration to that of the LLaMA-7B.
36 |
37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38 | documentation from [`PretrainedConfig`] for more information.
39 |
40 |
41 | Args:
42 | vocab_size (`int`, *optional*, defaults to 32000):
43 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
44 | `inputs_ids` passed when calling [`LlamaModel`]
45 | hidden_size (`int`, *optional*, defaults to 4096):
46 | Dimension of the hidden representations.
47 | intermediate_size (`int`, *optional*, defaults to 11008):
48 | Dimension of the MLP representations.
49 | num_hidden_layers (`int`, *optional*, defaults to 32):
50 | Number of hidden layers in the Transformer encoder.
51 | num_attention_heads (`int`, *optional*, defaults to 32):
52 | Number of attention heads for each attention layer in the Transformer encoder.
53 | num_key_value_heads (`int`, *optional*):
54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56 | `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58 | by meanpooling all the original heads within that group. For more details checkout [this
59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60 | `num_attention_heads`.
61 | pretraining_tp (`int`, *optional*, defaults to `1`):
62 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
63 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
64 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
65 | issue](https://github.com/pytorch/pytorch/issues/76232).
66 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
67 | The non-linear activation function (function or string) in the decoder.
68 | max_position_embeddings (`int`, *optional*, defaults to 2048):
69 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
70 | Llama 2 up to 4096, CodeLlama up to 16384.
71 | initializer_range (`float`, *optional*, defaults to 0.02):
72 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73 | rms_norm_eps (`float`, *optional*, defaults to 1e-6):
74 | The epsilon used by the rms normalization layers.
75 | use_cache (`bool`, *optional*, defaults to `True`):
76 | Whether or not the model should return the last key/values attentions (not used by all models). Only
77 | relevant if `config.is_decoder=True`.
78 | tie_word_embeddings(`bool`, *optional*, defaults to `False`):
79 | Whether to tie weight embeddings
80 | rope_theta (`float`, *optional*, defaults to 10000.0):
81 | The base period of the RoPE embeddings.
82 | rope_scaling (`Dict`, *optional*):
83 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
84 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
85 | is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
86 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
87 | these scaling strategies behave:
88 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
89 | experimental feature, subject to breaking API changes in future versions.
90 | attention_bias (`bool`, defaults to `False`):
91 | Whether to use a bias in the query, key, value and output projection layers during self-attention.
92 |
93 | Example:
94 |
95 | ```python
96 | >>> from transformers import LlamaModel, LlamaConfig
97 |
98 | >>> # Initializing a LLaMA llama-7b style configuration
99 | >>> configuration = LlamaConfig()
100 |
101 | >>> # Initializing a model from the llama-7b style configuration
102 | >>> model = LlamaModel(configuration)
103 |
104 | >>> # Accessing the model configuration
105 | >>> configuration = model.config
106 | ```"""
107 | model_type = "llama"
108 | keys_to_ignore_at_inference = ["past_key_values"]
109 |
110 | def __init__(
111 | self,
112 | vocab_size=32000,
113 | hidden_size=4096,
114 | intermediate_size=11008,
115 | num_hidden_layers=32,
116 | num_attention_heads=32,
117 | num_key_value_heads=None,
118 | hidden_act="silu",
119 | max_position_embeddings=2048,
120 | initializer_range=0.02,
121 | rms_norm_eps=1e-6,
122 | use_cache=True,
123 | pad_token_id=None,
124 | bos_token_id=1,
125 | eos_token_id=2,
126 | pretraining_tp=1,
127 | tie_word_embeddings=False,
128 | rope_theta=10000.0,
129 | rope_scaling=None,
130 | attention_bias=False,
131 | **kwargs,
132 | ):
133 | self.vocab_size = vocab_size
134 | self.max_position_embeddings = max_position_embeddings
135 | self.hidden_size = hidden_size
136 | self.intermediate_size = intermediate_size
137 | self.num_hidden_layers = num_hidden_layers
138 | self.num_attention_heads = num_attention_heads
139 |
140 | # for backward compatibility
141 | if num_key_value_heads is None:
142 | num_key_value_heads = num_attention_heads
143 |
144 | self.num_key_value_heads = num_key_value_heads
145 | self.hidden_act = hidden_act
146 | self.initializer_range = initializer_range
147 | self.rms_norm_eps = rms_norm_eps
148 | self.pretraining_tp = pretraining_tp
149 | self.use_cache = use_cache
150 | self.rope_theta = rope_theta
151 | self.rope_scaling = rope_scaling
152 | self._rope_scaling_validation()
153 | self.attention_bias = attention_bias
154 |
155 | super().__init__(
156 | pad_token_id=pad_token_id,
157 | bos_token_id=bos_token_id,
158 | eos_token_id=eos_token_id,
159 | tie_word_embeddings=tie_word_embeddings,
160 | **kwargs,
161 | )
162 |
163 | def _rope_scaling_validation(self):
164 | """
165 | Validate the `rope_scaling` configuration.
166 | """
167 | if self.rope_scaling is None:
168 | return
169 |
170 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
171 | raise ValueError(
172 | "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
173 | f"got {self.rope_scaling}"
174 | )
175 | rope_scaling_type = self.rope_scaling.get("type", None)
176 | rope_scaling_factor = self.rope_scaling.get("factor", None)
177 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
178 | raise ValueError(
179 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
180 | )
181 | if (
182 | rope_scaling_factor is None
183 | or not isinstance(rope_scaling_factor, float)
184 | or rope_scaling_factor <= 1.0
185 | ):
186 | raise ValueError(
187 | f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
188 | )
189 |
--------------------------------------------------------------------------------
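The `LlamaConfig` above enforces the `rope_scaling` contract described in its docstring: a two-field dict whose `type` is `"linear"` or `"dynamic"` and whose `factor` is a float strictly greater than 1, checked by `_rope_scaling_validation`. Below is a minimal usage sketch of building a grouped-query-attention configuration with RoPE scaling; the import path and the specific hyperparameter values are illustrative assumptions, not values prescribed by this repository.

```python
# Hypothetical usage sketch; assumes the package is importable from the repo root.
from model.llama_tree_attn import LlamaConfig

# 32 query heads grouped onto 8 key/value heads (GQA), with linear RoPE scaling.
config = LlamaConfig(
    num_attention_heads=32,
    num_key_value_heads=8,
    rope_scaling={"type": "linear", "factor": 2.0},  # passes _rope_scaling_validation
)

# An integer factor (or any factor <= 1.0) is rejected by the validation above.
try:
    LlamaConfig(rope_scaling={"type": "linear", "factor": 1})
except ValueError as err:
    print(err)
```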
/MCSD/model/llama_tree_attn/convert_llama_weights_to_hf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import gc
16 | import json
17 | import os
18 | import shutil
19 | import warnings
20 |
21 | import torch
22 |
23 | from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
24 |
25 |
26 | try:
27 | from transformers import LlamaTokenizerFast
28 | except ImportError as e:
29 | warnings.warn(e)
30 | warnings.warn(
31 | "The converted tokenizer will be the `slow` tokenizer. To use the fast tokenizer, update your `tokenizers` library and re-run the tokenizer conversion"
32 | )
33 | LlamaTokenizerFast = None
34 |
35 | """
36 | Sample usage:
37 |
38 | ```
39 | python src/transformers/models/llama/convert_llama_weights_to_hf.py \
40 | --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
41 | ```
42 |
43 | Thereafter, models can be loaded via:
44 |
45 | ```py
46 | from transformers import LlamaForCausalLM, LlamaTokenizer
47 |
48 | model = LlamaForCausalLM.from_pretrained("/output/path")
49 | tokenizer = LlamaTokenizer.from_pretrained("/output/path")
50 | ```
51 |
52 | Important note: you need to be able to host the whole model in RAM to execute this script (even though the biggest
53 | versions come in several checkpoints, each checkpoint contains a part of every weight of the model, so all of them must be loaded in RAM).
54 | """
55 |
56 | NUM_SHARDS = {
57 | "7B": 1,
58 | "7Bf": 1,
59 | "13B": 2,
60 | "13Bf": 2,
61 | "34B": 4,
62 | "30B": 4,
63 | "65B": 8,
64 | "70B": 8,
65 | "70Bf": 8,
66 | }
67 |
68 |
69 | def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
70 | return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
71 |
72 |
73 | def read_json(path):
74 | with open(path, "r") as f:
75 | return json.load(f)
76 |
77 |
78 | def write_json(text, path):
79 | with open(path, "w") as f:
80 | json.dump(text, f)
81 |
82 |
83 | def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
84 | # for backward compatibility: previously the repo needed to be called `my_repo/model_size`
85 | if not os.path.isfile(os.path.join(input_base_path, "params.json")):
86 | input_base_path = os.path.join(input_base_path, model_size)
87 |
88 | os.makedirs(model_path, exist_ok=True)
89 | tmp_model_path = os.path.join(model_path, "tmp")
90 | os.makedirs(tmp_model_path, exist_ok=True)
91 |
92 | params = read_json(os.path.join(input_base_path, "params.json"))
93 | num_shards = NUM_SHARDS[model_size]
94 | n_layers = params["n_layers"]
95 | n_heads = params["n_heads"]
96 | n_heads_per_shard = n_heads // num_shards
97 | dim = params["dim"]
98 | dims_per_head = dim // n_heads
99 | base = params.get("rope_theta", 10000.0)
100 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
101 | if base > 10000.0:
102 | max_position_embeddings = 16384
103 | else:
104 | max_position_embeddings = 2048
105 |
106 | tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
107 | if tokenizer_path is not None:
108 | tokenizer = tokenizer_class(tokenizer_path)
109 | tokenizer.save_pretrained(model_path)
110 | vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
111 |
112 | if "n_kv_heads" in params:
113 | num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
114 | num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
115 | key_value_dim = dim // num_key_value_heads
116 | else: # compatibility with other checkpoints
117 | num_key_value_heads = n_heads
118 | num_local_key_value_heads = n_heads_per_shard
119 | key_value_dim = dim
120 |
121 | # permute for sliced rotary
122 | def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
123 | return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
124 |
125 | print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
126 | # Load weights
127 | if model_size == "7B":
128 | # Not sharded
129 | # (The sharded implementation would also work, but this is simpler.)
130 | loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
131 | else:
132 | # Sharded
133 | loaded = [
134 | torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
135 | for i in range(num_shards)
136 | ]
137 | param_count = 0
138 | index_dict = {"weight_map": {}}
139 | for layer_i in range(n_layers):
140 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
141 | if model_size == "7B":
142 | # Unsharded
143 | state_dict = {
144 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
145 | loaded[f"layers.{layer_i}.attention.wq.weight"]
146 | ),
147 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
148 | loaded[f"layers.{layer_i}.attention.wk.weight"]
149 | ),
150 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
151 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
152 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
153 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
154 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
155 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
156 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
157 | }
158 | else:
159 | # Sharded
160 | # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
161 | # the same storage object, so saving attention_norm and ffn_norm would save the other weights too, which is
162 | # redundant as those other weights will be stitched together from multiple shards. To avoid that, they are cloned.
163 |
164 | state_dict = {
165 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
166 | f"layers.{layer_i}.attention_norm.weight"
167 | ].clone(),
168 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
169 | f"layers.{layer_i}.ffn_norm.weight"
170 | ].clone(),
171 | }
172 | state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
173 | torch.cat(
174 | [
175 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
176 | for i in range(num_shards)
177 | ],
178 | dim=0,
179 | ).reshape(dim, dim)
180 | )
181 | state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
182 | torch.cat(
183 | [
184 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
185 | num_local_key_value_heads, dims_per_head, dim
186 | )
187 | for i in range(num_shards)
188 | ],
189 | dim=0,
190 | ).reshape(key_value_dim, dim),
191 | num_key_value_heads,
192 | key_value_dim,
193 | dim,
194 | )
195 | state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
196 | [
197 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
198 | num_local_key_value_heads, dims_per_head, dim
199 | )
200 | for i in range(num_shards)
201 | ],
202 | dim=0,
203 | ).reshape(key_value_dim, dim)
204 |
205 | state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
206 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
207 | )
208 | state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
209 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
210 | )
211 | state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
212 | [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
213 | )
214 | state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
215 | [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
216 | )
217 |
218 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
219 | for k, v in state_dict.items():
220 | index_dict["weight_map"][k] = filename
221 | param_count += v.numel()
222 | torch.save(state_dict, os.path.join(tmp_model_path, filename))
223 |
224 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
225 | if model_size == "7B":
226 | # Unsharded
227 | state_dict = {
228 | "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
229 | "model.norm.weight": loaded["norm.weight"],
230 | "lm_head.weight": loaded["output.weight"],
231 | }
232 | else:
233 | state_dict = {
234 | "model.norm.weight": loaded[0]["norm.weight"],
235 | "model.embed_tokens.weight": torch.cat(
236 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
237 | ),
238 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
239 | }
240 |
241 | for k, v in state_dict.items():
242 | index_dict["weight_map"][k] = filename
243 | param_count += v.numel()
244 | torch.save(state_dict, os.path.join(tmp_model_path, filename))
245 |
246 | # Write configs
247 | index_dict["metadata"] = {"total_size": param_count * 2}
248 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
249 | ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
250 | multiple_of = params["multiple_of"] if "multiple_of" in params else 256
251 | config = LlamaConfig(
252 | hidden_size=dim,
253 | intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
254 | num_attention_heads=params["n_heads"],
255 | num_hidden_layers=params["n_layers"],
256 | rms_norm_eps=params["norm_eps"],
257 | num_key_value_heads=num_key_value_heads,
258 | vocab_size=vocab_size,
259 | rope_theta=base,
260 | max_position_embeddings=max_position_embeddings,
261 | )
262 | config.save_pretrained(tmp_model_path)
263 |
264 | # Make space so we can load the model properly now.
265 | del state_dict
266 | del loaded
267 | gc.collect()
268 |
269 | print("Loading the checkpoint in a Llama model.")
270 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
271 | # Avoid saving this as part of the config.
272 | del model.config._name_or_path
273 | model.config.torch_dtype = torch.float16
274 | print("Saving in the Transformers format.")
275 | model.save_pretrained(model_path, safe_serialization=safe_serialization)
276 | shutil.rmtree(tmp_model_path)
277 |
278 |
279 | def write_tokenizer(tokenizer_path, input_tokenizer_path):
280 | # Initialize the tokenizer based on the `spm` model
281 | tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
282 | print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
283 | tokenizer = tokenizer_class(input_tokenizer_path)
284 | tokenizer.save_pretrained(tokenizer_path)
285 |
286 |
287 | def main():
288 | parser = argparse.ArgumentParser()
289 | parser.add_argument(
290 | "--input_dir",
291 | help="Location of LLaMA weights, which contains tokenizer.model and model folders",
292 | )
293 | parser.add_argument(
294 | "--model_size",
295 | choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
296 | help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, check out the original repo: https://huggingface.co/meta-llama",
297 | )
298 | parser.add_argument(
299 | "--output_dir",
300 | help="Location to write HF model and tokenizer",
301 | )
302 | parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
303 | args = parser.parse_args()
304 | spm_path = os.path.join(args.input_dir, "tokenizer.model")
305 | if args.model_size != "tokenizer_only":
306 | write_model(
307 | model_path=args.output_dir,
308 | input_base_path=args.input_dir,
309 | model_size=args.model_size,
310 | safe_serialization=args.safe_serialization,
311 | tokenizer_path=spm_path,
312 | )
313 | else:
314 | write_tokenizer(args.output_dir, spm_path)
315 |
316 |
317 | if __name__ == "__main__":
318 | main()
319 |
--------------------------------------------------------------------------------
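As a quick sanity check on `compute_intermediate_size` above: with the default `ffn_dim_multiplier=1` and `multiple_of=256`, the 7B hidden size of 4096 reproduces the `intermediate_size=11008` default seen in `configuration_llama.py`. The sketch below restates the helper and checks two cases; the 70B values (`ffn_dim_multiplier=1.3`, `multiple_of=4096`) are quoted from memory of Meta's `params.json` and should be treated as assumptions.

```python
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    # Round int(ffn_dim_multiplier * 8n/3) up to the nearest multiple of `multiple_of`.
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)

# 7B: int(8 * 4096 / 3) = 10922, rounded up to a multiple of 256 -> 11008.
assert compute_intermediate_size(4096) == 11008

# 70B (assumed params): dim=8192, ffn_dim_multiplier=1.3, multiple_of=4096 -> 28672.
assert compute_intermediate_size(8192, ffn_dim_multiplier=1.3, multiple_of=4096) == 28672
```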
/MCSD/model/llama_tree_attn/modeling_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ PyTorch LLaMA model."""
21 | import math
22 | from typing import List, Optional, Tuple, Union
23 |
24 | import torch
25 | import torch.nn.functional as F
26 | import torch.utils.checkpoint
27 | from torch import nn
28 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29 |
30 | from transformers.activations import ACT2FN
31 | from transformers.modeling_outputs import (
32 | BaseModelOutputWithPast,
33 | CausalLMOutputWithPast,
34 | SequenceClassifierOutputWithPast,
35 | )
36 | from transformers.modeling_utils import PreTrainedModel
37 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
38 | from transformers.utils import (
39 | add_start_docstrings,
40 | add_start_docstrings_to_model_forward,
41 | is_flash_attn_available,
42 | logging,
43 | replace_return_docstrings,
44 | )
45 | from .configuration_llama import LlamaConfig
46 |
47 |
48 | if is_flash_attn_available():
49 | from flash_attn import flash_attn_func, flash_attn_varlen_func
50 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51 |
52 |
53 | logger = logging.get_logger(__name__)
54 |
55 | _CONFIG_FOR_DOC = "LlamaConfig"
56 |
57 |
58 | def _get_unpad_data(padding_mask):
59 | seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
60 | indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
61 | max_seqlen_in_batch = seqlens_in_batch.max().item()
62 | cu_seqlens = F.pad(
63 | torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
64 | )
65 | return (
66 | indices,
67 | cu_seqlens,
68 | max_seqlen_in_batch,
69 | )
70 |
71 |
72 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask
73 | def _make_causal_mask(
74 | input_ids_shape: torch.Size,
75 | dtype: torch.dtype,
76 | device: torch.device,
77 | past_key_values_length: int = 0,
78 | tree_attn_mask: Optional[torch.Tensor] = None,
79 | ):
80 | """
81 | Make the causal mask used for decoder self-attention; if `tree_attn_mask` is provided, it is used in place of the standard lower-triangular mask.
82 | """
83 | bsz, tgt_len = input_ids_shape
84 |
85 | if tree_attn_mask is not None:
86 | mask = torch.full_like(
87 | tree_attn_mask,
88 | torch.finfo(dtype).min,
89 | dtype=dtype,
90 | device=device,
91 | )
92 | mask.masked_fill_(tree_attn_mask, 0)
93 | else:
94 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
95 | mask_cond = torch.arange(mask.size(-1), device=device)
96 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
97 | mask = mask.to(dtype)
98 |
99 | if past_key_values_length > 0:
100 | mask = torch.cat(
101 | [
102 | torch.zeros(
103 | tgt_len, past_key_values_length, dtype=dtype, device=device
104 | ),
105 | mask,
106 | ],
107 | dim=-1,
108 | )
109 | return mask[None, None, :, :].expand(
110 | bsz, 1, tgt_len, tgt_len + past_key_values_length
111 | )
112 |
113 |
114 | # Copied from transformers.models.bart.modeling_bart._expand_mask
115 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
116 | """
117 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
118 | """
119 | bsz, src_len = mask.size()
120 | tgt_len = tgt_len if tgt_len is not None else src_len
121 |
122 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
123 |
124 | inverted_mask = 1.0 - expanded_mask
125 |
126 | return inverted_mask.masked_fill(
127 | inverted_mask.to(torch.bool), torch.finfo(dtype).min
128 | )
129 |
130 |
131 | class LlamaRMSNorm(nn.Module):
132 | def __init__(self, hidden_size, eps=1e-6):
133 | """
134 | LlamaRMSNorm is equivalent to T5LayerNorm
135 | """
136 | super().__init__()
137 | self.weight = nn.Parameter(torch.ones(hidden_size))
138 | self.variance_epsilon = eps
139 |
140 | def forward(self, hidden_states):
141 | input_dtype = hidden_states.dtype
142 | hidden_states = hidden_states.to(torch.float32)
143 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
144 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
145 | return self.weight * hidden_states.to(input_dtype)
146 |
147 |
148 | ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
149 |
150 |
151 | class LlamaRotaryEmbedding(nn.Module):
152 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
153 | super().__init__()
154 |
155 | self.dim = dim
156 | self.max_position_embeddings = max_position_embeddings
157 | self.base = base
158 | inv_freq = 1.0 / (
159 | self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
160 | )
161 | self.register_buffer("inv_freq", inv_freq, persistent=False)
162 |
163 | # Build here to make `torch.jit.trace` work.
164 | self._set_cos_sin_cache(
165 | seq_len=max_position_embeddings,
166 | device=self.inv_freq.device,
167 | dtype=torch.get_default_dtype(),
168 | )
169 |
170 | def _set_cos_sin_cache(self, seq_len, device, dtype):
171 | self.max_seq_len_cached = seq_len
172 | t = torch.arange(
173 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
174 | )
175 |
176 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
177 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
178 | emb = torch.cat((freqs, freqs), dim=-1)
179 | self.register_buffer(
180 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
181 | )
182 | self.register_buffer(
183 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
184 | )
185 |
186 | def forward(self, x, seq_len=None):
187 | # x: [bs, num_attention_heads, seq_len, head_size]
188 | if seq_len > self.max_seq_len_cached:
189 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
190 |
191 | return (
192 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
193 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
194 | )
195 |
196 |
197 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
198 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
199 |
200 | def __init__(
201 | self,
202 | dim,
203 | max_position_embeddings=2048,
204 | base=10000,
205 | device=None,
206 | scaling_factor=1.0,
207 | ):
208 | self.scaling_factor = scaling_factor
209 | super().__init__(dim, max_position_embeddings, base, device)
210 |
211 | def _set_cos_sin_cache(self, seq_len, device, dtype):
212 | self.max_seq_len_cached = seq_len
213 | t = torch.arange(
214 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
215 | )
216 | t = t / self.scaling_factor
217 |
218 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
219 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
220 | emb = torch.cat((freqs, freqs), dim=-1)
221 | self.register_buffer(
222 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
223 | )
224 | self.register_buffer(
225 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
226 | )
227 |
228 |
229 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
230 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
231 |
232 | def __init__(
233 | self,
234 | dim,
235 | max_position_embeddings=2048,
236 | base=10000,
237 | device=None,
238 | scaling_factor=1.0,
239 | ):
240 | self.scaling_factor = scaling_factor
241 | super().__init__(dim, max_position_embeddings, base, device)
242 |
243 | def _set_cos_sin_cache(self, seq_len, device, dtype):
244 | self.max_seq_len_cached = seq_len
245 |
246 | if seq_len > self.max_position_embeddings:
247 | base = self.base * (
248 | (self.scaling_factor * seq_len / self.max_position_embeddings)
249 | - (self.scaling_factor - 1)
250 | ) ** (self.dim / (self.dim - 2))
251 | inv_freq = 1.0 / (
252 | base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
253 | )
254 | self.register_buffer("inv_freq", inv_freq, persistent=False)
255 |
256 | t = torch.arange(
257 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
258 | )
259 |
260 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
261 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
262 | emb = torch.cat((freqs, freqs), dim=-1)
263 | self.register_buffer(
264 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
265 | )
266 | self.register_buffer(
267 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
268 | )
269 |
270 |
271 | def rotate_half(x):
272 | """Rotates half the hidden dims of the input."""
273 | x1 = x[..., : x.shape[-1] // 2]
274 | x2 = x[..., x.shape[-1] // 2 :]
275 | return torch.cat((-x2, x1), dim=-1)
276 |
277 |
278 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
279 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
280 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
281 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
282 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
283 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
284 | q_embed = (q * cos) + (rotate_half(q) * sin)
285 | k_embed = (k * cos) + (rotate_half(k) * sin)
286 | return q_embed, k_embed
287 |
288 |
289 | class LlamaMLP(nn.Module):
290 | def __init__(self, config):
291 | super().__init__()
292 | self.config = config
293 | self.hidden_size = config.hidden_size
294 | self.intermediate_size = config.intermediate_size
295 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
296 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
297 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
298 | self.act_fn = ACT2FN[config.hidden_act]
299 |
300 | def forward(self, x):
301 | if self.config.pretraining_tp > 1:
302 | slice = self.intermediate_size // self.config.pretraining_tp
303 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
304 | up_proj_slices = self.up_proj.weight.split(slice, dim=0)
305 | down_proj_slices = self.down_proj.weight.split(slice, dim=1)
306 |
307 | gate_proj = torch.cat(
308 | [
309 | F.linear(x, gate_proj_slices[i])
310 | for i in range(self.config.pretraining_tp)
311 | ],
312 | dim=-1,
313 | )
314 | up_proj = torch.cat(
315 | [
316 | F.linear(x, up_proj_slices[i])
317 | for i in range(self.config.pretraining_tp)
318 | ],
319 | dim=-1,
320 | )
321 |
322 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
323 | down_proj = [
324 | F.linear(intermediate_states[i], down_proj_slices[i])
325 | for i in range(self.config.pretraining_tp)
326 | ]
327 | down_proj = sum(down_proj)
328 | else:
329 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
330 |
331 | return down_proj
332 |
333 |
334 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
335 | """
336 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
337 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
338 | """
339 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
340 | if n_rep == 1:
341 | return hidden_states
342 | hidden_states = hidden_states[:, :, None, :, :].expand(
343 | batch, num_key_value_heads, n_rep, slen, head_dim
344 | )
345 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
346 |
347 |
348 | class LlamaAttention(nn.Module):
349 | """Multi-headed attention from 'Attention Is All You Need' paper"""
350 |
351 | def __init__(self, config: LlamaConfig):
352 | super().__init__()
353 | self.config = config
354 | self.hidden_size = config.hidden_size
355 | self.num_heads = config.num_attention_heads
356 | self.head_dim = self.hidden_size // self.num_heads
357 | self.num_key_value_heads = config.num_key_value_heads
358 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
359 | self.max_position_embeddings = config.max_position_embeddings
360 | self.rope_theta = config.rope_theta
361 |
362 | if (self.head_dim * self.num_heads) != self.hidden_size:
363 | raise ValueError(
364 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
365 | f" and `num_heads`: {self.num_heads})."
366 | )
367 | self.q_proj = nn.Linear(
368 | self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
369 | )
370 | self.k_proj = nn.Linear(
371 | self.hidden_size,
372 | self.num_key_value_heads * self.head_dim,
373 | bias=config.attention_bias,
374 | )
375 | self.v_proj = nn.Linear(
376 | self.hidden_size,
377 | self.num_key_value_heads * self.head_dim,
378 | bias=config.attention_bias,
379 | )
380 | self.o_proj = nn.Linear(
381 | self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
382 | )
383 | self._init_rope()
384 |
385 | def _init_rope(self):
386 | if self.config.rope_scaling is None:
387 | self.rotary_emb = LlamaRotaryEmbedding(
388 | self.head_dim,
389 | max_position_embeddings=self.max_position_embeddings,
390 | base=self.rope_theta,
391 | )
392 | else:
393 | scaling_type = self.config.rope_scaling["type"]
394 | scaling_factor = self.config.rope_scaling["factor"]
395 | if scaling_type == "linear":
396 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
397 | self.head_dim,
398 | max_position_embeddings=self.max_position_embeddings,
399 | scaling_factor=scaling_factor,
400 | base=self.rope_theta,
401 | )
402 | elif scaling_type == "dynamic":
403 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
404 | self.head_dim,
405 | max_position_embeddings=self.max_position_embeddings,
406 | scaling_factor=scaling_factor,
407 | base=self.rope_theta,
408 | )
409 | else:
410 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
411 |
412 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
413 | return (
414 | tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
415 | .transpose(1, 2)
416 | .contiguous()
417 | )
418 |
419 | def forward(
420 | self,
421 | hidden_states: torch.Tensor,
422 | attention_mask: Optional[torch.Tensor] = None,
423 | position_ids: Optional[torch.LongTensor] = None,
424 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
425 | output_attentions: bool = False,
426 | use_cache: bool = False,
427 | padding_mask: Optional[torch.LongTensor] = None,
428 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
429 | bsz, q_len, _ = hidden_states.size()
430 |
431 | if self.config.pretraining_tp > 1:
432 | key_value_slicing = (
433 | self.num_key_value_heads * self.head_dim
434 | ) // self.config.pretraining_tp
435 | query_slices = self.q_proj.weight.split(
436 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
437 | )
438 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
439 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
440 |
441 | query_states = [
442 | F.linear(hidden_states, query_slices[i])
443 | for i in range(self.config.pretraining_tp)
444 | ]
445 | query_states = torch.cat(query_states, dim=-1)
446 |
447 | key_states = [
448 | F.linear(hidden_states, key_slices[i])
449 | for i in range(self.config.pretraining_tp)
450 | ]
451 | key_states = torch.cat(key_states, dim=-1)
452 |
453 | value_states = [
454 | F.linear(hidden_states, value_slices[i])
455 | for i in range(self.config.pretraining_tp)
456 | ]
457 | value_states = torch.cat(value_states, dim=-1)
458 |
459 | else:
460 | query_states = self.q_proj(hidden_states)
461 | key_states = self.k_proj(hidden_states)
462 | value_states = self.v_proj(hidden_states)
463 |
464 | query_states = query_states.view(
465 | bsz, q_len, self.num_heads, self.head_dim
466 | ).transpose(1, 2)
467 | key_states = key_states.view(
468 | bsz, q_len, self.num_key_value_heads, self.head_dim
469 | ).transpose(1, 2)
470 | value_states = value_states.view(
471 | bsz, q_len, self.num_key_value_heads, self.head_dim
472 | ).transpose(1, 2)
473 |
474 | kv_seq_len = key_states.shape[-2]
475 | if past_key_value is not None:
476 | kv_seq_len += past_key_value[0].shape[-2]
477 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
478 | query_states, key_states = apply_rotary_pos_emb(
479 | query_states, key_states, cos, sin, position_ids
480 | )
481 |
482 | if past_key_value is not None:
483 | # reuse k, v, self_attention
484 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
485 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
486 |
487 | past_key_value = (key_states, value_states) if use_cache else None
488 |
489 | key_states = repeat_kv(key_states, self.num_key_value_groups)
490 | value_states = repeat_kv(value_states, self.num_key_value_groups)
491 |
492 | attn_weights = torch.matmul(
493 | query_states, key_states.transpose(2, 3)
494 | ) / math.sqrt(self.head_dim)
495 |
496 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
497 | raise ValueError(
498 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
499 | f" {attn_weights.size()}"
500 | )
501 |
502 | if attention_mask is not None:
503 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
504 | raise ValueError(
505 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
506 | )
507 | attn_weights = attn_weights + attention_mask
508 |
509 | # upcast attention to fp32
510 | attn_weights = nn.functional.softmax(
511 | attn_weights, dim=-1, dtype=torch.float32
512 | ).to(query_states.dtype)
513 | attn_output = torch.matmul(attn_weights, value_states)
514 |
515 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
516 | raise ValueError(
517 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
518 | f" {attn_output.size()}"
519 | )
520 |
521 | attn_output = attn_output.transpose(1, 2).contiguous()
522 |
523 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
524 |
525 | if self.config.pretraining_tp > 1:
526 | attn_output = attn_output.split(
527 | self.hidden_size // self.config.pretraining_tp, dim=2
528 | )
529 | o_proj_slices = self.o_proj.weight.split(
530 | self.hidden_size // self.config.pretraining_tp, dim=1
531 | )
532 | attn_output = sum(
533 | [
534 | F.linear(attn_output[i], o_proj_slices[i])
535 | for i in range(self.config.pretraining_tp)
536 | ]
537 | )
538 | else:
539 | attn_output = self.o_proj(attn_output)
540 |
541 | if not output_attentions:
542 | attn_weights = None
543 |
544 | return attn_output, attn_weights, past_key_value
545 |
546 |
547 | class LlamaFlashAttention2(LlamaAttention):
548 | """
549 | Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
550 | untouched. The only required change is in the forward pass, which needs to correctly call the public API of
551 | flash attention and deal with padding tokens in case the input contains any of them.
552 | """
553 |
554 | def forward(
555 | self,
556 | hidden_states: torch.Tensor,
557 | attention_mask: Optional[torch.Tensor] = None,
558 | position_ids: Optional[torch.LongTensor] = None,
559 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
560 | output_attentions: bool = False,
561 | use_cache: bool = False,
562 | padding_mask: Optional[torch.LongTensor] = None,
563 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
564 | # LlamaFlashAttention2 attention does not support output_attentions
565 | output_attentions = False
566 |
567 | bsz, q_len, _ = hidden_states.size()
568 |
569 | query_states = self.q_proj(hidden_states)
570 | key_states = self.k_proj(hidden_states)
571 | value_states = self.v_proj(hidden_states)
572 |
573 | # Flash attention requires the input to have the shape
574 | # batch_size x seq_length x num_heads x head_dim
575 | # therefore we just need to keep the original shape
576 | query_states = query_states.view(
577 | bsz, q_len, self.num_heads, self.head_dim
578 | ).transpose(1, 2)
579 | key_states = key_states.view(
580 | bsz, q_len, self.num_key_value_heads, self.head_dim
581 | ).transpose(1, 2)
582 | value_states = value_states.view(
583 | bsz, q_len, self.num_key_value_heads, self.head_dim
584 | ).transpose(1, 2)
585 |
586 | kv_seq_len = key_states.shape[-2]
587 | if past_key_value is not None:
588 | kv_seq_len += past_key_value[0].shape[-2]
589 |
590 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
591 |
592 | query_states, key_states = apply_rotary_pos_emb(
593 | query_states, key_states, cos, sin, position_ids
594 | )
595 |
596 | if past_key_value is not None:
597 | # reuse k, v, self_attention
598 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
599 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
600 |
601 | past_key_value = (key_states, value_states) if use_cache else None
602 |
603 | query_states = query_states.transpose(1, 2)
604 | key_states = key_states.transpose(1, 2)
605 | value_states = value_states.transpose(1, 2)
606 |
607 | # TODO: llama does not have dropout in the config??
608 | # It is recommended to use dropout with FA according to the docs
609 | # when training.
610 | dropout_rate = 0.0 # if not self.training else self.attn_dropout
611 |
612 | # In PEFT, the layer norms are usually cast to float32 for training stability,
613 | # so the input hidden states get silently cast to float32. Hence, we need to
614 | # cast them back to float16 to be sure everything works as expected.
615 | # This might slow down training & inference, so it is recommended not to cast the LayerNorms
616 | # to fp32. (LlamaRMSNorm handles it correctly.)
617 | input_dtype = query_states.dtype
618 | if input_dtype == torch.float32:
619 | logger.warning_once(
620 | "The input hidden states seem to have been silently cast to float32; this might be related to"
621 | " the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
622 | " float16."
623 | )
624 |
625 | query_states = query_states.to(torch.float16)
626 | key_states = key_states.to(torch.float16)
627 | value_states = value_states.to(torch.float16)
628 |
629 | attn_output = self._flash_attention_forward(
630 | query_states,
631 | key_states,
632 | value_states,
633 | padding_mask,
634 | q_len,
635 | dropout=dropout_rate,
636 | )
637 |
638 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
639 | attn_output = self.o_proj(attn_output)
640 |
641 | if not output_attentions:
642 | attn_weights = None
643 |
644 | return attn_output, attn_weights, past_key_value
645 |
646 | def _flash_attention_forward(
647 | self,
648 | query_states,
649 | key_states,
650 | value_states,
651 | padding_mask,
652 | query_length,
653 | dropout=0.0,
654 | softmax_scale=None,
655 | ):
656 | """
657 | Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
658 | it first unpads the input, then computes the attention scores, and finally pads the attention output back.
659 |
660 | Args:
661 | query_states (`torch.Tensor`):
662 | Input query states to be passed to Flash Attention API
663 | key_states (`torch.Tensor`):
664 | Input key states to be passed to Flash Attention API
665 | value_states (`torch.Tensor`):
666 | Input value states to be passed to Flash Attention API
667 | padding_mask (`torch.Tensor`):
668 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
669 | position of padding tokens and 1 for the position of non-padding tokens.
670 | dropout (`float`, *optional*):
671 | Attention dropout probability
672 | softmax_scale (`float`, *optional*):
673 | The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
674 | """
675 | # Contains at least one padding token in the sequence
676 | if padding_mask is not None:
677 | batch_size = query_states.shape[0]
678 | (
679 | query_states,
680 | key_states,
681 | value_states,
682 | indices_q,
683 | cu_seq_lens,
684 | max_seq_lens,
685 | ) = self._upad_input(
686 | query_states, key_states, value_states, padding_mask, query_length
687 | )
688 |
689 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
690 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
691 |
692 | attn_output_unpad = flash_attn_varlen_func(
693 | query_states,
694 | key_states,
695 | value_states,
696 | cu_seqlens_q=cu_seqlens_q,
697 | cu_seqlens_k=cu_seqlens_k,
698 | max_seqlen_q=max_seqlen_in_batch_q,
699 | max_seqlen_k=max_seqlen_in_batch_k,
700 | dropout_p=dropout,
701 | softmax_scale=softmax_scale,
702 | causal=True,
703 | )
704 |
705 | attn_output = pad_input(
706 | attn_output_unpad, indices_q, batch_size, query_length
707 | )
708 | else:
709 | attn_output = flash_attn_func(
710 | query_states,
711 | key_states,
712 | value_states,
713 | dropout,
714 | softmax_scale=softmax_scale,
715 | causal=True,
716 | )
717 |
718 | return attn_output
719 |
720 | def _upad_input(
721 | self, query_layer, key_layer, value_layer, padding_mask, query_length
722 | ):
723 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
724 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
725 |
726 | key_layer = index_first_axis(
727 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
728 | indices_k,
729 | )
730 | value_layer = index_first_axis(
731 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
732 | indices_k,
733 | )
734 | if query_length == kv_seq_len:
735 | query_layer = index_first_axis(
736 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
737 | indices_k,
738 | )
739 | cu_seqlens_q = cu_seqlens_k
740 | max_seqlen_in_batch_q = max_seqlen_in_batch_k
741 | indices_q = indices_k
742 | elif query_length == 1:
743 | max_seqlen_in_batch_q = 1
744 | cu_seqlens_q = torch.arange(
745 | batch_size + 1, dtype=torch.int32, device=query_layer.device
746 | ) # There is a memcpy here, that is very bad.
747 | indices_q = cu_seqlens_q[:-1]
748 | query_layer = query_layer.squeeze(1)
749 | else:
750 | # The -q_len: slice assumes left padding.
751 | padding_mask = padding_mask[:, -query_length:]
752 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
753 | query_layer, padding_mask
754 | )
755 |
756 | return (
757 | query_layer,
758 | key_layer,
759 | value_layer,
760 | indices_q,
761 | (cu_seqlens_q, cu_seqlens_k),
762 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
763 | )
764 |
765 |
766 | class LlamaDecoderLayer(nn.Module):
767 | def __init__(self, config: LlamaConfig):
768 | super().__init__()
769 | self.hidden_size = config.hidden_size
770 | self.self_attn = (
771 | LlamaAttention(config=config)
772 | if not getattr(config, "_flash_attn_2_enabled", False)
773 | else LlamaFlashAttention2(config=config)
774 | )
775 | self.mlp = LlamaMLP(config)
776 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
777 | self.post_attention_layernorm = LlamaRMSNorm(
778 | config.hidden_size, eps=config.rms_norm_eps
779 | )
780 |
781 | def forward(
782 | self,
783 | hidden_states: torch.Tensor,
784 | attention_mask: Optional[torch.Tensor] = None,
785 | position_ids: Optional[torch.LongTensor] = None,
786 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
787 | output_attentions: Optional[bool] = False,
788 | use_cache: Optional[bool] = False,
789 | padding_mask: Optional[torch.LongTensor] = None,
790 | ) -> Tuple[
791 | torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
792 | ]:
793 | """
794 | Args:
795 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
796 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
797 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
798 | output_attentions (`bool`, *optional*):
799 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
800 | returned tensors for more detail.
801 | use_cache (`bool`, *optional*):
802 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
803 | (see `past_key_values`).
804 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
805 | """
806 |
807 | residual = hidden_states
808 |
809 | hidden_states = self.input_layernorm(hidden_states)
810 |
811 | # Self Attention
812 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
813 | hidden_states=hidden_states,
814 | attention_mask=attention_mask,
815 | position_ids=position_ids,
816 | past_key_value=past_key_value,
817 | output_attentions=output_attentions,
818 | use_cache=use_cache,
819 | padding_mask=padding_mask,
820 | )
821 | hidden_states = residual + hidden_states
822 |
823 | # Fully Connected
824 | residual = hidden_states
825 | hidden_states = self.post_attention_layernorm(hidden_states)
826 | hidden_states = self.mlp(hidden_states)
827 | hidden_states = residual + hidden_states
828 |
829 | outputs = (hidden_states,)
830 |
831 | if output_attentions:
832 | outputs += (self_attn_weights,)
833 |
834 | if use_cache:
835 | outputs += (present_key_value,)
836 |
837 | return outputs
838 |
839 |
840 | LLAMA_START_DOCSTRING = r"""
841 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
842 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
843 | etc.).
844 |
845 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
846 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
847 | and behavior.
848 |
849 | Parameters:
850 | config ([`LlamaConfig`]):
851 | Model configuration class with all the parameters of the model. Initializing with a config file does not
852 | load the weights associated with the model, only the configuration. Check out the
853 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
854 | """
855 |
856 |
857 | @add_start_docstrings(
858 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
859 | LLAMA_START_DOCSTRING,
860 | )
861 | class LlamaPreTrainedModel(PreTrainedModel):
862 | config_class = LlamaConfig
863 | base_model_prefix = "model"
864 | supports_gradient_checkpointing = True
865 | _no_split_modules = ["LlamaDecoderLayer"]
866 | _skip_keys_device_placement = "past_key_values"
867 | _supports_flash_attn_2 = True
868 |
869 | def _init_weights(self, module):
870 | std = self.config.initializer_range
871 | if isinstance(module, nn.Linear):
872 | module.weight.data.normal_(mean=0.0, std=std)
873 | if module.bias is not None:
874 | module.bias.data.zero_()
875 | elif isinstance(module, nn.Embedding):
876 | module.weight.data.normal_(mean=0.0, std=std)
877 | if module.padding_idx is not None:
878 | module.weight.data[module.padding_idx].zero_()
879 |
880 | def _set_gradient_checkpointing(self, module, value=False):
881 | if isinstance(module, LlamaModel):
882 | module.gradient_checkpointing = value
883 |
884 |
885 | LLAMA_INPUTS_DOCSTRING = r"""
886 | Args:
887 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
888 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
889 | it.
890 |
891 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
892 | [`PreTrainedTokenizer.__call__`] for details.
893 |
894 | [What are input IDs?](../glossary#input-ids)
895 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
896 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
897 |
898 | - 1 for tokens that are **not masked**,
899 | - 0 for tokens that are **masked**.
900 |
901 | [What are attention masks?](../glossary#attention-mask)
902 |
903 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
904 | [`PreTrainedTokenizer.__call__`] for details.
905 |
906 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
907 | `past_key_values`).
908 |
909 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
910 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
911 | information on the default strategy.
912 |
913 | - 1 indicates the head is **not masked**,
914 | - 0 indicates the head is **masked**.
915 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
916 | Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
917 | config.n_positions - 1]`.
918 |
919 | [What are position IDs?](../glossary#position-ids)
920 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
921 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
922 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
923 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
924 |
925 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
926 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
927 |
928 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
929 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
930 | of shape `(batch_size, sequence_length)`.
931 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
932 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
933 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
934 | model's internal embedding lookup matrix.
935 | use_cache (`bool`, *optional*):
936 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
937 | `past_key_values`).
938 | output_attentions (`bool`, *optional*):
939 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
940 | tensors for more detail.
941 | output_hidden_states (`bool`, *optional*):
942 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
943 | more detail.
944 | return_dict (`bool`, *optional*):
945 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
946 | """
947 |
948 |
949 | @add_start_docstrings(
950 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
951 | LLAMA_START_DOCSTRING,
952 | )
953 | class LlamaModel(LlamaPreTrainedModel):
954 | """
955 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
956 |
957 | Args:
958 | config: LlamaConfig
959 | """
960 |
961 | def __init__(self, config: LlamaConfig):
962 | super().__init__(config)
963 | self.padding_idx = config.pad_token_id
964 | self.vocab_size = config.vocab_size
965 |
966 | self.embed_tokens = nn.Embedding(
967 | config.vocab_size, config.hidden_size, self.padding_idx
968 | )
969 | self.layers = nn.ModuleList(
970 | [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
971 | )
972 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
973 |
974 | self.gradient_checkpointing = False
975 | # Initialize weights and apply final processing
976 | self.post_init()
977 |
978 | def get_input_embeddings(self):
979 | return self.embed_tokens
980 |
981 | def set_input_embeddings(self, value):
982 | self.embed_tokens = value
983 |
984 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
985 | def _prepare_decoder_attention_mask(
986 | self,
987 | attention_mask,
988 | input_shape,
989 | inputs_embeds,
990 | past_key_values_length,
991 | tree_attn_mask=None,
992 | ):
993 | # create causal mask
994 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
995 | combined_attention_mask = None
996 | if input_shape[-1] > 1:
997 | combined_attention_mask = _make_causal_mask(
998 | input_shape,
999 | inputs_embeds.dtype,
1000 | device=inputs_embeds.device,
1001 | past_key_values_length=past_key_values_length,
1002 | tree_attn_mask=tree_attn_mask,
1003 | )
1004 |
1005 | if attention_mask is not None:
1006 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1007 | expanded_attn_mask = _expand_mask(
1008 | attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1009 | ).to(inputs_embeds.device)
1010 | combined_attention_mask = (
1011 | expanded_attn_mask
1012 | if combined_attention_mask is None
1013 | else expanded_attn_mask + combined_attention_mask
1014 | )
1015 |
1016 | return combined_attention_mask
1017 |
1018 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1019 | def forward(
1020 | self,
1021 | input_ids: torch.LongTensor = None,
1022 | attention_mask: Optional[torch.Tensor] = None,
1023 | position_ids: Optional[torch.LongTensor] = None,
1024 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1025 | inputs_embeds: Optional[torch.FloatTensor] = None,
1026 | use_cache: Optional[bool] = None,
1027 | output_attentions: Optional[bool] = None,
1028 | output_hidden_states: Optional[bool] = None,
1029 | return_dict: Optional[bool] = None,
1030 | tree_attn_mask: Optional[torch.Tensor] = None,
1031 | ) -> Union[Tuple, BaseModelOutputWithPast]:
1032 | output_attentions = (
1033 | output_attentions
1034 | if output_attentions is not None
1035 | else self.config.output_attentions
1036 | )
1037 | output_hidden_states = (
1038 | output_hidden_states
1039 | if output_hidden_states is not None
1040 | else self.config.output_hidden_states
1041 | )
1042 | use_cache = use_cache if use_cache is not None else self.config.use_cache
1043 |
1044 | return_dict = (
1045 | return_dict if return_dict is not None else self.config.use_return_dict
1046 | )
1047 |
1048 | # retrieve input_ids and inputs_embeds
1049 | if input_ids is not None and inputs_embeds is not None:
1050 | raise ValueError(
1051 | "You cannot specify both input_ids and inputs_embeds at the same time"
1052 | )
1053 | elif input_ids is not None:
1054 | batch_size, seq_length = input_ids.shape
1055 | elif inputs_embeds is not None:
1056 | batch_size, seq_length, _ = inputs_embeds.shape
1057 | else:
1058 | raise ValueError("You have to specify either input_ids or inputs_embeds")
1059 |
1060 | seq_length_with_past = seq_length
1061 | past_key_values_length = 0
1062 |
1063 | if past_key_values is not None:
1064 | past_key_values_length = past_key_values[0][0].shape[2]
1065 | seq_length_with_past = seq_length_with_past + past_key_values_length
1066 |
1067 | if position_ids is None:
1068 | device = input_ids.device if input_ids is not None else inputs_embeds.device
1069 | position_ids = torch.arange(
1070 | past_key_values_length,
1071 | seq_length + past_key_values_length,
1072 | dtype=torch.long,
1073 | device=device,
1074 | )
1075 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1076 | else:
1077 | position_ids = position_ids.view(-1, seq_length).long()
1078 |
1079 | if inputs_embeds is None:
1080 | inputs_embeds = self.embed_tokens(input_ids)
1081 | # embed positions
1082 | if attention_mask is None:
1083 | attention_mask = torch.ones(
1084 | (batch_size, seq_length_with_past),
1085 | dtype=torch.bool,
1086 | device=inputs_embeds.device,
1087 | )
1088 | padding_mask = None
1089 | else:
1090 | if 0 in attention_mask:
1091 | padding_mask = attention_mask
1092 | else:
1093 | padding_mask = None
1094 |
1095 | attention_mask = self._prepare_decoder_attention_mask(
1096 | attention_mask,
1097 | (batch_size, seq_length),
1098 | inputs_embeds,
1099 | past_key_values_length,
1100 | tree_attn_mask=tree_attn_mask,
1101 | )
1102 |
1103 | hidden_states = inputs_embeds
1104 |
1105 | if self.gradient_checkpointing and self.training:
1106 | if use_cache:
1107 | logger.warning_once(
1108 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1109 | )
1110 | use_cache = False
1111 |
1112 | # decoder layers
1113 | all_hidden_states = () if output_hidden_states else None
1114 | all_self_attns = () if output_attentions else None
1115 | next_decoder_cache = () if use_cache else None
1116 |
1117 | for idx, decoder_layer in enumerate(self.layers):
1118 | if output_hidden_states:
1119 | all_hidden_states += (hidden_states,)
1120 |
1121 | past_key_value = (
1122 | past_key_values[idx] if past_key_values is not None else None
1123 | )
1124 |
1125 | if self.gradient_checkpointing and self.training:
1126 |
1127 | def create_custom_forward(module):
1128 | def custom_forward(*inputs):
1129 | # None for past_key_value
1130 | return module(
1131 | *inputs,
1132 | past_key_value,
1133 | output_attentions,
1134 | padding_mask=padding_mask,
1135 | )
1136 |
1137 | return custom_forward
1138 |
1139 | layer_outputs = torch.utils.checkpoint.checkpoint(
1140 | create_custom_forward(decoder_layer),
1141 | hidden_states,
1142 | attention_mask,
1143 | position_ids,
1144 | )
1145 | else:
1146 | layer_outputs = decoder_layer(
1147 | hidden_states,
1148 | attention_mask=attention_mask,
1149 | position_ids=position_ids,
1150 | past_key_value=past_key_value,
1151 | output_attentions=output_attentions,
1152 | use_cache=use_cache,
1153 | padding_mask=padding_mask,
1154 | )
1155 |
1156 | hidden_states = layer_outputs[0]
1157 |
1158 | if use_cache:
1159 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1160 |
1161 | if output_attentions:
1162 | all_self_attns += (layer_outputs[1],)
1163 |
1164 | hidden_states = self.norm(hidden_states)
1165 |
1166 | # add hidden states from the last decoder layer
1167 | if output_hidden_states:
1168 | all_hidden_states += (hidden_states,)
1169 |
1170 | next_cache = next_decoder_cache if use_cache else None
1171 | if not return_dict:
1172 | return tuple(
1173 | v
1174 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1175 | if v is not None
1176 | )
1177 | return BaseModelOutputWithPast(
1178 | last_hidden_state=hidden_states,
1179 | past_key_values=next_cache,
1180 | hidden_states=all_hidden_states,
1181 | attentions=all_self_attns,
1182 | )
1183 |
1184 |
1185 | class LlamaForCausalLM(LlamaPreTrainedModel):
1186 | _tied_weights_keys = ["lm_head.weight"]
1187 |
1188 | def __init__(self, config):
1189 | super().__init__(config)
1190 | self.model = LlamaModel(config)
1191 | self.vocab_size = config.vocab_size
1192 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1193 |
1194 | # Initialize weights and apply final processing
1195 | self.post_init()
1196 |
1197 | def get_input_embeddings(self):
1198 | return self.model.embed_tokens
1199 |
1200 | def set_input_embeddings(self, value):
1201 | self.model.embed_tokens = value
1202 |
1203 | def get_output_embeddings(self):
1204 | return self.lm_head
1205 |
1206 | def set_output_embeddings(self, new_embeddings):
1207 | self.lm_head = new_embeddings
1208 |
1209 | def set_decoder(self, decoder):
1210 | self.model = decoder
1211 |
1212 | def get_decoder(self):
1213 | return self.model
1214 |
1215 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1216 | @replace_return_docstrings(
1217 | output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1218 | )
1219 | def forward(
1220 | self,
1221 | input_ids: torch.LongTensor = None,
1222 | attention_mask: Optional[torch.Tensor] = None,
1223 | position_ids: Optional[torch.LongTensor] = None,
1224 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1225 | inputs_embeds: Optional[torch.FloatTensor] = None,
1226 | labels: Optional[torch.LongTensor] = None,
1227 | use_cache: Optional[bool] = None,
1228 | output_attentions: Optional[bool] = None,
1229 | output_hidden_states: Optional[bool] = None,
1230 | return_dict: Optional[bool] = None,
1231 | ) -> Union[Tuple, CausalLMOutputWithPast]:
1232 | r"""
1233 | Args:
1234 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1235 |                 Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1236 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1237 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1238 |
1239 | Returns:
1240 |
1241 | Example:
1242 |
1243 | ```python
1244 | >>> from transformers import AutoTokenizer, LlamaForCausalLM
1245 |
1246 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1247 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1248 |
1249 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
1250 | >>> inputs = tokenizer(prompt, return_tensors="pt")
1251 |
1252 | >>> # Generate
1253 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1254 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1255 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1256 | ```"""
1257 |
1258 | output_attentions = (
1259 | output_attentions
1260 | if output_attentions is not None
1261 | else self.config.output_attentions
1262 | )
1263 | output_hidden_states = (
1264 | output_hidden_states
1265 | if output_hidden_states is not None
1266 | else self.config.output_hidden_states
1267 | )
1268 | return_dict = (
1269 | return_dict if return_dict is not None else self.config.use_return_dict
1270 | )
1271 |
1272 |         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1273 | outputs = self.model(
1274 | input_ids=input_ids,
1275 | attention_mask=attention_mask,
1276 | position_ids=position_ids,
1277 | past_key_values=past_key_values,
1278 | inputs_embeds=inputs_embeds,
1279 | use_cache=use_cache,
1280 | output_attentions=output_attentions,
1281 | output_hidden_states=output_hidden_states,
1282 | return_dict=return_dict,
1283 | )
1284 |
1285 | hidden_states = outputs[0]
1286 | if self.config.pretraining_tp > 1:
1287 | lm_head_slices = self.lm_head.weight.split(
1288 | self.vocab_size // self.config.pretraining_tp, dim=0
1289 | )
1290 | logits = [
1291 | F.linear(hidden_states, lm_head_slices[i])
1292 | for i in range(self.config.pretraining_tp)
1293 | ]
1294 | logits = torch.cat(logits, dim=-1)
1295 | else:
1296 | logits = self.lm_head(hidden_states)
1297 | logits = logits.float()
1298 |
1299 | loss = None
1300 | if labels is not None:
1301 | # Shift so that tokens < n predict n
1302 | shift_logits = logits[..., :-1, :].contiguous()
1303 | shift_labels = labels[..., 1:].contiguous()
1304 | # Flatten the tokens
1305 | loss_fct = CrossEntropyLoss()
1306 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
1307 | shift_labels = shift_labels.view(-1)
1308 | # Enable model parallelism
1309 | shift_labels = shift_labels.to(shift_logits.device)
1310 | loss = loss_fct(shift_logits, shift_labels)
1311 |
1312 | if not return_dict:
1313 | output = (logits,) + outputs[1:]
1314 | return (loss,) + output if loss is not None else output
1315 |
1316 | return CausalLMOutputWithPast(
1317 | loss=loss,
1318 | logits=logits,
1319 | past_key_values=outputs.past_key_values,
1320 | hidden_states=outputs.hidden_states,
1321 | attentions=outputs.attentions,
1322 | )
1323 |
1324 | def prepare_inputs_for_generation(
1325 | self,
1326 | input_ids,
1327 | past_key_values=None,
1328 | attention_mask=None,
1329 | inputs_embeds=None,
1330 | **kwargs,
1331 | ):
1332 | if past_key_values:
1333 | input_ids = input_ids[:, -1:]
1334 |
1335 | position_ids = kwargs.get("position_ids", None)
1336 | if attention_mask is not None and position_ids is None:
1337 | # create position_ids on the fly for batch generation
1338 | position_ids = attention_mask.long().cumsum(-1) - 1
1339 | position_ids.masked_fill_(attention_mask == 0, 1)
1340 | if past_key_values:
1341 | position_ids = position_ids[:, -1].unsqueeze(-1)
1342 |
1343 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1344 | if inputs_embeds is not None and past_key_values is None:
1345 | model_inputs = {"inputs_embeds": inputs_embeds}
1346 | else:
1347 | model_inputs = {"input_ids": input_ids}
1348 |
1349 | model_inputs.update(
1350 | {
1351 | "position_ids": position_ids,
1352 | "past_key_values": past_key_values,
1353 | "use_cache": kwargs.get("use_cache"),
1354 | "attention_mask": attention_mask,
1355 | }
1356 | )
1357 | return model_inputs
1358 |
1359 | @staticmethod
1360 | def _reorder_cache(past_key_values, beam_idx):
1361 | reordered_past = ()
1362 | for layer_past in past_key_values:
1363 | reordered_past += (
1364 | tuple(
1365 | past_state.index_select(0, beam_idx.to(past_state.device))
1366 | for past_state in layer_past
1367 | ),
1368 | )
1369 | return reordered_past
1370 |
1371 |
1372 | @add_start_docstrings(
1373 | """
1374 | The LLaMa Model transformer with a sequence classification head on top (linear layer).
1375 |
1376 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1377 | (e.g. GPT-2) do.
1378 |
1379 |     Since it does classification on the last token, it needs to know the position of the last token. If a
1380 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1381 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1382 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1383 | each row of the batch).
1384 | """,
1385 | LLAMA_START_DOCSTRING,
1386 | )
1387 | class LlamaForSequenceClassification(LlamaPreTrainedModel):
1388 | def __init__(self, config):
1389 | super().__init__(config)
1390 | self.num_labels = config.num_labels
1391 | self.model = LlamaModel(config)
1392 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1393 |
1394 | # Initialize weights and apply final processing
1395 | self.post_init()
1396 |
1397 | def get_input_embeddings(self):
1398 | return self.model.embed_tokens
1399 |
1400 | def set_input_embeddings(self, value):
1401 | self.model.embed_tokens = value
1402 |
1403 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1404 | def forward(
1405 | self,
1406 | input_ids: torch.LongTensor = None,
1407 | attention_mask: Optional[torch.Tensor] = None,
1408 | position_ids: Optional[torch.LongTensor] = None,
1409 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1410 | inputs_embeds: Optional[torch.FloatTensor] = None,
1411 | labels: Optional[torch.LongTensor] = None,
1412 | use_cache: Optional[bool] = None,
1413 | output_attentions: Optional[bool] = None,
1414 | output_hidden_states: Optional[bool] = None,
1415 | return_dict: Optional[bool] = None,
1416 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1417 | r"""
1418 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1419 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1420 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1421 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1422 | """
1423 | return_dict = (
1424 | return_dict if return_dict is not None else self.config.use_return_dict
1425 | )
1426 |
1427 | transformer_outputs = self.model(
1428 | input_ids,
1429 | attention_mask=attention_mask,
1430 | position_ids=position_ids,
1431 | past_key_values=past_key_values,
1432 | inputs_embeds=inputs_embeds,
1433 | use_cache=use_cache,
1434 | output_attentions=output_attentions,
1435 | output_hidden_states=output_hidden_states,
1436 | return_dict=return_dict,
1437 | )
1438 | hidden_states = transformer_outputs[0]
1439 | logits = self.score(hidden_states)
1440 |
1441 | if input_ids is not None:
1442 | batch_size = input_ids.shape[0]
1443 | else:
1444 | batch_size = inputs_embeds.shape[0]
1445 |
1446 | if self.config.pad_token_id is None and batch_size != 1:
1447 | raise ValueError(
1448 | "Cannot handle batch sizes > 1 if no padding token is defined."
1449 | )
1450 | if self.config.pad_token_id is None:
1451 | sequence_lengths = -1
1452 | else:
1453 | if input_ids is not None:
1454 | sequence_lengths = (
1455 | torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
1456 | ).to(logits.device)
1457 | else:
1458 | sequence_lengths = -1
1459 |
1460 | pooled_logits = logits[
1461 | torch.arange(batch_size, device=logits.device), sequence_lengths
1462 | ]
1463 |
1464 | loss = None
1465 | if labels is not None:
1466 | labels = labels.to(logits.device)
1467 | if self.config.problem_type is None:
1468 | if self.num_labels == 1:
1469 | self.config.problem_type = "regression"
1470 | elif self.num_labels > 1 and (
1471 | labels.dtype == torch.long or labels.dtype == torch.int
1472 | ):
1473 | self.config.problem_type = "single_label_classification"
1474 | else:
1475 | self.config.problem_type = "multi_label_classification"
1476 |
1477 | if self.config.problem_type == "regression":
1478 | loss_fct = MSELoss()
1479 | if self.num_labels == 1:
1480 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1481 | else:
1482 | loss = loss_fct(pooled_logits, labels)
1483 | elif self.config.problem_type == "single_label_classification":
1484 | loss_fct = CrossEntropyLoss()
1485 | loss = loss_fct(
1486 | pooled_logits.view(-1, self.num_labels), labels.view(-1)
1487 | )
1488 | elif self.config.problem_type == "multi_label_classification":
1489 | loss_fct = BCEWithLogitsLoss()
1490 | loss = loss_fct(pooled_logits, labels)
1491 | if not return_dict:
1492 | output = (pooled_logits,) + transformer_outputs[1:]
1493 | return ((loss,) + output) if loss is not None else output
1494 |
1495 | return SequenceClassifierOutputWithPast(
1496 | loss=loss,
1497 | logits=pooled_logits,
1498 | past_key_values=transformer_outputs.past_key_values,
1499 | hidden_states=transformer_outputs.hidden_states,
1500 | attentions=transformer_outputs.attentions,
1501 | )
1502 |
--------------------------------------------------------------------------------
/MCSD/model/llama_tree_attn/tokenization_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 |
21 | """Tokenization classes for LLaMA."""
22 | import os
23 | from shutil import copyfile
24 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
25 |
26 | import sentencepiece as spm
27 |
28 | from transformers.convert_slow_tokenizer import import_protobuf
29 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
30 | from transformers.utils import logging
31 |
32 |
33 | if TYPE_CHECKING:
34 | from transformers.tokenization_utils_base import TextInput
35 |
36 | logger = logging.get_logger(__name__)
37 |
38 | VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
39 |
40 | PRETRAINED_VOCAB_FILES_MAP = {
41 | "vocab_file": {
42 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
43 | },
44 | "tokenizer_file": {
45 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
46 | },
47 | }
48 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
49 | "hf-internal-testing/llama-tokenizer": 2048,
50 | }
51 | SPIECE_UNDERLINE = "▁"
52 |
53 | B_INST, E_INST = "[INST]", "[/INST]"
54 | B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
55 |
56 | # fmt: off
57 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
58 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
59 | that your responses are socially unbiased and positive in nature.
60 |
61 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
62 | correct. If you don't know the answer to a question, please don't share false information."""
63 | # fmt: on
64 |
65 |
66 | class LlamaTokenizer(PreTrainedTokenizer):
67 | """
68 | Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
69 | no padding token in the original model.
70 |
71 | Args:
72 | vocab_file (`str`):
73 | Path to the vocabulary file.
74 | legacy (`bool`, *optional*):
75 | Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
76 | and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
77 | example:
78 |
79 | - `legacy=True`:
80 | ```python
81 | >>> from transformers import T5Tokenizer
82 |
83 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
84 |             >>> tokenizer.encode("Hello <extra_id_0>.")
85 | [8774, 32099, 3, 5, 1]
86 | ```
87 | - `legacy=False`:
88 | ```python
89 | >>> from transformers import T5Tokenizer
90 |
91 | >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
92 |             >>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
93 | [8774, 32099, 5, 1]
94 | ```
95 |         Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
96 |
97 | """
98 |
99 | vocab_files_names = VOCAB_FILES_NAMES
100 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
101 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
102 | model_input_names = ["input_ids", "attention_mask"]
103 |
104 | def __init__(
105 | self,
106 | vocab_file,
107 |         unk_token="<unk>",
108 |         bos_token="<s>",
109 |         eos_token="</s>",
110 | pad_token=None,
111 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
112 | add_bos_token=True,
113 | add_eos_token=False,
114 | clean_up_tokenization_spaces=False,
115 | use_default_system_prompt=True,
116 | spaces_between_special_tokens=False,
117 | legacy=None,
118 | **kwargs,
119 | ):
120 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
121 | bos_token = (
122 | AddedToken(bos_token, lstrip=False, rstrip=False)
123 | if isinstance(bos_token, str)
124 | else bos_token
125 | )
126 | eos_token = (
127 | AddedToken(eos_token, lstrip=False, rstrip=False)
128 | if isinstance(eos_token, str)
129 | else eos_token
130 | )
131 | unk_token = (
132 | AddedToken(unk_token, lstrip=False, rstrip=False)
133 | if isinstance(unk_token, str)
134 | else unk_token
135 | )
136 | pad_token = (
137 | AddedToken(pad_token, lstrip=False, rstrip=False)
138 | if isinstance(pad_token, str)
139 | else pad_token
140 | )
141 |
142 | if legacy is None:
143 | logger.warning_once(
144 | f"You are using the default legacy behaviour of the {self.__class__}. This is"
145 | " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
146 | " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
147 |                 " means, and thoroughly read the reason why this was added as explained in"
148 | " https://github.com/huggingface/transformers/pull/24565"
149 | )
150 | legacy = True
151 |
152 | self.legacy = legacy
153 | self.vocab_file = vocab_file
154 | self.add_bos_token = add_bos_token
155 | self.add_eos_token = add_eos_token
156 | self.use_default_system_prompt = use_default_system_prompt
157 | self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
158 |
159 | super().__init__(
160 | bos_token=bos_token,
161 | eos_token=eos_token,
162 | unk_token=unk_token,
163 | pad_token=pad_token,
164 | add_bos_token=add_bos_token,
165 | add_eos_token=add_eos_token,
166 | sp_model_kwargs=self.sp_model_kwargs,
167 | clean_up_tokenization_spaces=clean_up_tokenization_spaces,
168 | use_default_system_prompt=use_default_system_prompt,
169 | spaces_between_special_tokens=spaces_between_special_tokens,
170 | legacy=legacy,
171 | **kwargs,
172 | )
173 |
174 | @property
175 | def unk_token_length(self):
176 | return len(self.sp_model.encode(str(self.unk_token)))
177 |
178 | # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
179 | def get_spm_processor(self, from_slow=False):
180 | tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
181 | if self.legacy or from_slow: # no dependency on protobuf
182 | tokenizer.Load(self.vocab_file)
183 | return tokenizer
184 |
185 | with open(self.vocab_file, "rb") as f:
186 | sp_model = f.read()
187 | model_pb2 = import_protobuf(
188 | f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)"
189 | )
190 | model = model_pb2.ModelProto.FromString(sp_model)
191 | normalizer_spec = model_pb2.NormalizerSpec()
192 | normalizer_spec.add_dummy_prefix = False
193 | model.normalizer_spec.MergeFrom(normalizer_spec)
194 | sp_model = model.SerializeToString()
195 | tokenizer.LoadFromSerializedProto(sp_model)
196 | return tokenizer
197 |
198 | def __getstate__(self):
199 | state = self.__dict__.copy()
200 | state["sp_model"] = None
201 | state["sp_model_proto"] = self.sp_model.serialized_model_proto()
202 | return state
203 |
204 | def __setstate__(self, d):
205 | self.__dict__ = d
206 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
207 | self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
208 |
209 | @property
210 | def vocab_size(self):
211 | """Returns vocab size"""
212 | return self.sp_model.get_piece_size()
213 |
214 | def get_vocab(self):
215 | """Returns vocab as a dict"""
216 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
217 | vocab.update(self.added_tokens_encoder)
218 | return vocab
219 |
220 | # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
221 | def tokenize(
222 | self, text: "TextInput", add_special_tokens=False, **kwargs
223 | ) -> List[str]:
224 | """
225 | Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
226 | first token is special.
227 | """
228 | if self.legacy or len(text) == 0:
229 | return super().tokenize(text, **kwargs)
230 |
231 | tokens = super().tokenize(
232 | SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs
233 | )
234 |
235 | if (
236 | len(tokens) > 1
237 | and tokens[0] == SPIECE_UNDERLINE
238 | and tokens[1] in self.all_special_tokens
239 | ):
240 | tokens = tokens[1:]
241 | return tokens
242 |
243 | # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
244 | def _tokenize(self, text, **kwargs):
245 | """
246 | Returns a tokenized string.
247 |
248 | We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
249 | SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
250 | `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
251 |         `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
252 |         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
253 | """
254 | tokens = self.sp_model.encode(text, out_type=str)
255 | if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
256 | return tokens
257 |
258 | # 1. Encode string + prefix ex: " Hey"
259 | tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
260 | # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
261 | return (
262 | tokens[self.unk_token_length :]
263 | if len(tokens) >= self.unk_token_length
264 | else tokens
265 | )
266 |
267 | def _convert_token_to_id(self, token):
268 | """Converts a token (str) in an id using the vocab."""
269 | return self.sp_model.piece_to_id(token)
270 |
271 | def _convert_id_to_token(self, index):
272 | """Converts an index (integer) in a token (str) using the vocab."""
273 | token = self.sp_model.IdToPiece(index)
274 | return token
275 |
276 | def convert_tokens_to_string(self, tokens):
277 | """Converts a sequence of tokens (string) in a single string."""
278 | # since we manually add the prefix space, we have to remove it when decoding
279 | if tokens[0].startswith(SPIECE_UNDERLINE):
280 | tokens[0] = tokens[0][1:]
281 |
282 | current_sub_tokens = []
283 | out_string = ""
284 | prev_is_special = False
285 | for i, token in enumerate(tokens):
286 | # make sure that special tokens are not decoded using sentencepiece model
287 | if token in self.all_special_tokens:
288 | if not prev_is_special and i != 0 and self.legacy:
289 | out_string += " "
290 | out_string += self.sp_model.decode(current_sub_tokens) + token
291 | prev_is_special = True
292 | current_sub_tokens = []
293 | else:
294 | current_sub_tokens.append(token)
295 | prev_is_special = False
296 | out_string += self.sp_model.decode(current_sub_tokens)
297 | return out_string
298 |
299 | def save_vocabulary(
300 | self, save_directory, filename_prefix: Optional[str] = None
301 | ) -> Tuple[str]:
302 | """
303 | Save the vocabulary and special tokens file to a directory.
304 |
305 | Args:
306 | save_directory (`str`):
307 | The directory in which to save the vocabulary.
308 |
309 | Returns:
310 | `Tuple(str)`: Paths to the files saved.
311 | """
312 | if not os.path.isdir(save_directory):
313 | logger.error(f"Vocabulary path ({save_directory}) should be a directory")
314 | return
315 | out_vocab_file = os.path.join(
316 | save_directory,
317 | (filename_prefix + "-" if filename_prefix else "")
318 | + VOCAB_FILES_NAMES["vocab_file"],
319 | )
320 |
321 | if os.path.abspath(self.vocab_file) != os.path.abspath(
322 | out_vocab_file
323 | ) and os.path.isfile(self.vocab_file):
324 | copyfile(self.vocab_file, out_vocab_file)
325 | elif not os.path.isfile(self.vocab_file):
326 | with open(out_vocab_file, "wb") as fi:
327 | content_spiece_model = self.sp_model.serialized_model_proto()
328 | fi.write(content_spiece_model)
329 |
330 | return (out_vocab_file,)
331 |
332 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
333 | bos_token_id = [self.bos_token_id] if self.add_bos_token else []
334 | eos_token_id = [self.eos_token_id] if self.add_eos_token else []
335 |
336 | output = bos_token_id + token_ids_0 + eos_token_id
337 |
338 | if token_ids_1 is not None:
339 | output = output + bos_token_id + token_ids_1 + eos_token_id
340 |
341 | return output
342 |
343 | def get_special_tokens_mask(
344 | self,
345 | token_ids_0: List[int],
346 | token_ids_1: Optional[List[int]] = None,
347 | already_has_special_tokens: bool = False,
348 | ) -> List[int]:
349 | """
350 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
351 | special tokens using the tokenizer `prepare_for_model` method.
352 |
353 | Args:
354 | token_ids_0 (`List[int]`):
355 | List of IDs.
356 | token_ids_1 (`List[int]`, *optional*):
357 | Optional second list of IDs for sequence pairs.
358 | already_has_special_tokens (`bool`, *optional*, defaults to `False`):
359 | Whether or not the token list is already formatted with special tokens for the model.
360 |
361 | Returns:
362 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
363 | """
364 | if already_has_special_tokens:
365 | return super().get_special_tokens_mask(
366 | token_ids_0=token_ids_0,
367 | token_ids_1=token_ids_1,
368 | already_has_special_tokens=True,
369 | )
370 |
371 | bos_token_id = [1] if self.add_bos_token else []
372 | eos_token_id = [1] if self.add_eos_token else []
373 |
374 | if token_ids_1 is None:
375 | return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
376 | return (
377 | bos_token_id
378 | + ([0] * len(token_ids_0))
379 | + eos_token_id
380 | + bos_token_id
381 | + ([0] * len(token_ids_1))
382 | + eos_token_id
383 | )
384 |
385 | def create_token_type_ids_from_sequences(
386 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
387 | ) -> List[int]:
388 | """
389 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
390 | sequence pair mask has the following format:
391 |
392 | ```
393 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
394 | | first sequence | second sequence |
395 | ```
396 |
397 | if token_ids_1 is None, only returns the first portion of the mask (0s).
398 |
399 | Args:
400 | token_ids_0 (`List[int]`):
401 | List of ids.
402 | token_ids_1 (`List[int]`, *optional*):
403 | Optional second list of IDs for sequence pairs.
404 |
405 | Returns:
406 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
407 | """
408 | bos_token_id = [self.bos_token_id] if self.add_bos_token else []
409 | eos_token_id = [self.eos_token_id] if self.add_eos_token else []
410 |
411 | output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
412 |
413 | if token_ids_1 is not None:
414 | output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
415 |
416 | return output
417 |
418 | @property
419 | def default_chat_template(self):
420 | """
421 |         LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
422 | Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
423 | user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
424 | rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
425 | results in an unusual token ordering when it is present. This template should definitely be changed if you wish
426 | to fine-tune a model with more flexible role ordering!
427 |
428 | The output should look something like:
429 |
430 |         <s>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer </s><s>[INST] Prompt [/INST] Answer </s>
431 |         <s>[INST] Prompt [/INST]
432 | """
433 |
434 | template = (
435 | "{% if messages[0]['role'] == 'system' %}"
436 | "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
437 | "{% set system_message = messages[0]['content'] %}"
438 |             "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
439 | "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
440 | "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
441 | "{% else %}"
442 | "{% set loop_messages = messages %}"
443 | "{% set system_message = false %}"
444 | "{% endif %}"
445 | "{% for message in loop_messages %}" # Loop over all non-system messages
446 | "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
447 | "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
448 | "{% endif %}"
449 | "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
450 |             "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
451 | "{% else %}"
452 | "{% set content = message['content'] %}"
453 | "{% endif %}"
454 | "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
455 | "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
456 | "{% elif message['role'] == 'system' %}"
457 |             "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
458 | "{% elif message['role'] == 'assistant' %}"
459 | "{{ ' ' + content.strip() + ' ' + eos_token }}"
460 | "{% endif %}"
461 | "{% endfor %}"
462 | )
463 | template = template.replace(
464 | "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false"
465 | )
466 | default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
467 | template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
468 |
469 | return template
470 |
--------------------------------------------------------------------------------
/MCSD/model/llama_tree_attn/tokenization_llama_fast.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os
16 | from shutil import copyfile
17 | from typing import Optional, Tuple
18 |
19 | from tokenizers import processors
20 |
21 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
22 | from transformers.utils import is_sentencepiece_available, logging
23 | from transformers.utils.versions import require_version
24 |
25 |
26 | require_version("tokenizers>=0.13.3")
27 |
28 | if is_sentencepiece_available():
29 | from .tokenization_llama import LlamaTokenizer
30 | else:
31 | LlamaTokenizer = None
32 |
33 | logger = logging.get_logger(__name__)
34 | VOCAB_FILES_NAMES = {
35 | "vocab_file": "tokenizer.model",
36 | "tokenizer_file": "tokenizer.json",
37 | }
38 |
39 | PRETRAINED_VOCAB_FILES_MAP = {
40 | "vocab_file": {
41 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
42 | },
43 | "tokenizer_file": {
44 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
45 | },
46 | }
47 | B_INST, E_INST = "[INST]", "[/INST]"
48 | B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
49 |
50 | # fmt: off
51 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
52 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
53 | that your responses are socially unbiased and positive in nature.
54 |
55 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
56 | correct. If you don't know the answer to a question, please don't share false information."""
57 | # fmt: on
58 |
59 |
60 | class LlamaTokenizerFast(PreTrainedTokenizerFast):
61 | """
62 | Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
63 |
64 | This uses notably ByteFallback and no normalization.
65 |
66 | ```
67 | from transformers import LlamaTokenizerFast
68 |
69 | tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
70 | tokenizer.encode("Hello this is a test")
71 | >>> [1, 15043, 445, 338, 263, 1243]
72 | ```
73 |
74 | If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
75 | call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
76 |     values of the first token and final token of an encoded sequence will not be correct). For more details, check out
77 |     the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
78 |
79 |
80 | This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
81 | refer to this superclass for more information regarding those methods.
82 |
83 | Args:
84 | vocab_file (`str`):
85 | [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
86 | contains the vocabulary necessary to instantiate a tokenizer.
87 | tokenizer_file (`str`):
88 | [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
89 | contains everything needed to load the tokenizer.
90 |
91 |         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
92 |             Whether to clean up spaces after decoding; cleanup consists in removing potential artifacts like extra
93 | spaces.
94 |
95 |     bos_token (`str`, *optional*, defaults to `"<s>"`):
96 | The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
97 |
98 |     eos_token (`str`, *optional*, defaults to `"</s>"`):
99 | The end of sequence token.
100 |
101 |     unk_token (`str`, *optional*, defaults to `"<unk>"`):
102 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
103 | token instead.
104 | """
105 |
106 | vocab_files_names = VOCAB_FILES_NAMES
107 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
108 | slow_tokenizer_class = LlamaTokenizer
109 | padding_side = "left"
110 | model_input_names = ["input_ids", "attention_mask"]
111 |
112 | def __init__(
113 | self,
114 | vocab_file=None,
115 | tokenizer_file=None,
116 | clean_up_tokenization_spaces=False,
117 |         unk_token="<unk>",
118 |         bos_token="<s>",
119 |         eos_token="</s>",
120 | add_bos_token=True,
121 | add_eos_token=False,
122 | use_default_system_prompt=True,
123 | **kwargs,
124 | ):
125 | super().__init__(
126 | vocab_file=vocab_file,
127 | tokenizer_file=tokenizer_file,
128 | clean_up_tokenization_spaces=clean_up_tokenization_spaces,
129 | unk_token=unk_token,
130 | bos_token=bos_token,
131 | eos_token=eos_token,
132 | use_default_system_prompt=use_default_system_prompt,
133 | **kwargs,
134 | )
135 | self._add_bos_token = add_bos_token
136 | self._add_eos_token = add_eos_token
137 | self.update_post_processor()
138 | self.use_default_system_prompt = use_default_system_prompt
139 | self.vocab_file = vocab_file
140 |
141 | @property
142 | def can_save_slow_tokenizer(self) -> bool:
143 | return os.path.isfile(self.vocab_file) if self.vocab_file else False
144 |
145 | def update_post_processor(self):
146 | """
147 | Updates the underlying post processor with the current `bos_token` and `eos_token`.
148 | """
149 | bos = self.bos_token
150 | bos_token_id = self.bos_token_id
151 |
152 | eos = self.eos_token
153 | eos_token_id = self.eos_token_id
154 |
155 | single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
156 | pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
157 |
158 | special_tokens = []
159 | if self.add_bos_token:
160 | special_tokens.append((bos, bos_token_id))
161 | if self.add_eos_token:
162 | special_tokens.append((eos, eos_token_id))
163 | self._tokenizer.post_processor = processors.TemplateProcessing(
164 | single=single, pair=pair, special_tokens=special_tokens
165 | )
166 |
167 | @property
168 | def add_eos_token(self):
169 | return self._add_eos_token
170 |
171 | @property
172 | def add_bos_token(self):
173 | return self._add_bos_token
174 |
175 | @add_eos_token.setter
176 | def add_eos_token(self, value):
177 | self._add_eos_token = value
178 | self.update_post_processor()
179 |
180 | @add_bos_token.setter
181 | def add_bos_token(self, value):
182 | self._add_bos_token = value
183 | self.update_post_processor()
184 |
185 | def save_vocabulary(
186 | self, save_directory: str, filename_prefix: Optional[str] = None
187 | ) -> Tuple[str]:
188 | if not self.can_save_slow_tokenizer:
189 | raise ValueError(
190 | "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
191 | "tokenizer."
192 | )
193 |
194 | if not os.path.isdir(save_directory):
195 | logger.error(f"Vocabulary path ({save_directory}) should be a directory")
196 | return
197 | out_vocab_file = os.path.join(
198 | save_directory,
199 | (filename_prefix + "-" if filename_prefix else "")
200 | + VOCAB_FILES_NAMES["vocab_file"],
201 | )
202 |
203 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
204 | copyfile(self.vocab_file, out_vocab_file)
205 |
206 | return (out_vocab_file,)
207 |
208 | @property
209 | # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
210 | def default_chat_template(self):
211 | """
212 |         LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
213 | Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
214 | user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
215 | rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
216 | results in an unusual token ordering when it is present. This template should definitely be changed if you wish
217 | to fine-tune a model with more flexible role ordering!
218 |
219 | The output should look something like:
220 |
221 |         <s>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer </s><s>[INST] Prompt [/INST] Answer </s>
222 |         <s>[INST] Prompt [/INST]
223 | """
224 |
225 | template = (
226 | "{% if messages[0]['role'] == 'system' %}"
227 | "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
228 | "{% set system_message = messages[0]['content'] %}"
229 |             "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
230 | "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
231 | "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
232 | "{% else %}"
233 | "{% set loop_messages = messages %}"
234 | "{% set system_message = false %}"
235 | "{% endif %}"
236 | "{% for message in loop_messages %}" # Loop over all non-system messages
237 | "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
238 | "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
239 | "{% endif %}"
240 | "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
241 |             "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
242 | "{% else %}"
243 | "{% set content = message['content'] %}"
244 | "{% endif %}"
245 | "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
246 | "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
247 | "{% elif message['role'] == 'system' %}"
248 |             "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
249 | "{% elif message['role'] == 'assistant' %}"
250 | "{{ ' ' + content.strip() + ' ' + eos_token }}"
251 | "{% endif %}"
252 | "{% endfor %}"
253 | )
254 | template = template.replace(
255 | "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false"
256 | )
257 | default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
258 | template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
259 |
260 | return template
261 |
262 | # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
263 | # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
264 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
265 | bos_token_id = [self.bos_token_id] if self.add_bos_token else []
266 | eos_token_id = [self.eos_token_id] if self.add_eos_token else []
267 |
268 | output = bos_token_id + token_ids_0 + eos_token_id
269 |
270 | if token_ids_1 is not None:
271 | output = output + bos_token_id + token_ids_1 + eos_token_id
272 |
273 | return output
274 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Candidate Speculative Decoding
2 |
3 | ## Code Release
4 | See [here](./MCSD/).
5 |
6 | ## Data Release
7 | For the [Alpaca dataset](https://github.com/flexflow/FlexFlow/tree/inference?tab=readme-ov-file#prompt-datasets), we use exactly the same source as [SpecInfer](https://arxiv.org/pdf/2305.09781.pdf).
8 |
9 | For the [WMT dataset](/dataset/wmt_ende.json), we follow the process of SpecInfer: we randomly sample 1000 examples from the test set and wrap each source sentence with the following template:
10 | ```
11 | Translate the input English sentence into German.
12 | Input: {source sentence}
13 | Output:
14 | ```
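For reference, here is a minimal sketch (plain Python, with an illustrative sentence rather than one taken from the dataset) of how a source sentence is wrapped with this template:

```python
# Sketch only: wrap a single English source sentence with the prompt template above.
# The example sentence is illustrative and not drawn from the WMT test set.
TEMPLATE = (
    "Translate the input English sentence into German.\n"
    "Input: {source_sentence}\n"
    "Output:"
)

source_sentence = "The weather is nice today."
prompt = TEMPLATE.format(source_sentence=source_sentence)
print(prompt)
```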
15 |
16 | ## Model Release
17 | We release our fine-tuned draft models on Hugging Face; see [Vicuna-68M](https://huggingface.co/double7/vicuna-68m) and [Vicuna-160M](https://huggingface.co/double7/vicuna-160m). They are fine-tuned from [LLaMA-68M](https://huggingface.co/JackFram/llama-68m) and [LLaMA-160M](https://huggingface.co/JackFram/llama-160m), respectively, on ShareGPT data. The training setup follows [FastChat](https://github.com/lm-sys/FastChat).
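As a loading sketch, the released draft models can be pulled with the standard `transformers` API; the snippet below only runs plain decoding on an assumed prompt and is not the MCSD pipeline itself (see [MCSD](./MCSD/) for that):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the released 68M draft model from the Hugging Face Hub;
# swap in "double7/vicuna-160m" for the larger draft.
draft_name = "double7/vicuna-68m"
tokenizer = AutoTokenizer.from_pretrained(draft_name)
draft_model = AutoModelForCausalLM.from_pretrained(draft_name)

# Illustrative prompt in the WMT template format described above.
prompt = "Translate the input English sentence into German.\nInput: Hello, world!\nOutput:"
inputs = tokenizer(prompt, return_tensors="pt")
output_ids = draft_model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

In the speculative decoding setup these small models act as the draft; pair them with a larger target model of your choice when running the released code.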
18 |
--------------------------------------------------------------------------------