├── README.md ├── config.json ├── sft.json ├── config.py ├── training_config.py ├── sampling.py ├── data └── fineweb_edu │ └── fineweb_edu.py ├── tokenizer.py ├── sft.py ├── dataloader.py ├── datacollator.py ├── deepseek.py ├── trainer.py ├── moe.py ├── pre_train.py ├── train.py └── mla.py /README.md: -------------------------------------------------------------------------------- 1 | # DeepSeek From Scratch -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_model": 1024, 3 | "nheads": 16, 4 | "max_position_embeddings": 4096, 5 | "dropout": 0.4, 6 | "device": "cuda", 7 | "use_kv_cache": true, 8 | "q_lora_rank": null, 9 | "kv_lora_rank": 512, 10 | "rope_head_dim": 64, 11 | "nope_head_dim": 128, 12 | "rope_base": 10000, 13 | "rope_scaling": { 14 | "type": "yarn", 15 | "scaling_factor": 4.0, 16 | "alpha": 1, 17 | "beta": 32, 18 | "attn_factor": 1.0, 19 | "training_context_length": 128 20 | }, 21 | "v_head_dim": 128, 22 | "num_shared_experts": 2, 23 | "num_routed_experts": 64, 24 | "moe_hidden_dimension": 1024, 25 | "mlp_hidden_dimension": 4096, 26 | "topk": 6, 27 | "topk_norm_epsilon": 1e-9, 28 | "rms_norm_eps": 1e-6, 29 | "normalized_moe_gates": true, 30 | "expert_load_balance_factor": 0.01, 31 | "num_layers": 4, 32 | "vocab_size": 101024, 33 | "init_weight_std": 0.006, 34 | "first_k_dense_replace": 1 35 | } -------------------------------------------------------------------------------- /sft.json: -------------------------------------------------------------------------------- 1 | { 2 | "device": "cuda", 3 | "learning_rate": 5e-4, 4 | "min_learning_rate": 5e-5, 5 | "eval_iters": 5, 6 | "eval_interval": 10, 7 | "dtype": "bfloat16", 8 | "measure_throughput_interval": 100, 9 | "estimate_throughput": false, 10 | "wandb_log": true, 11 | "wandb_project": "deepseek sft", 12 | "wandb_run_name": "sft_8_bit_optimizer_16_batch_size", 13 | "adamw_use_fused": true, 14 | "max_train_steps": 30000, 15 | "batch_size": 1, 16 | "gradient_accumulation_steps": 16, 17 | "warmup_iters": 500, 18 | "lr_decay_iters": 1000, 19 | "decay_lr": true, 20 | "out_dir": "sft_output", 21 | "resume": false, 22 | "checkpoint_path": "sft_ckpt.pt", 23 | "adamw_beta1": 0.9, 24 | "adamw_beta2": 0.95, 25 | "adamw_weight_decay": 0.1, 26 | "use_eight_bit_optimizer": true, 27 | "grad_clip": 1.0, 28 | "model_config_path": "config.json", 29 | "dataset_name": "HuggingFaceH4/ultrachat_200k", 30 | "train_split": "train_sft", 31 | "eval_split": "test_sft", 32 | "tokenizer_type": "cl100k_base" 33 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | """ 4 | From deepseek v2 paper, 5 | 6 | d_model: 5120, hidden dimension 7 | nheads: 128, n_h 8 | block_size: 4096 then extend to 128k 9 | q_lora_rank: 1536, d_c_prime 10 | kv_lora_rank: 512, d_c 11 | rope_head_dim: 64, d_h_R 12 | nope_head_dim: 128, d_h 13 | """ 14 | 15 | @dataclass 16 | class DeepSeekConfig: 17 | d_model: int 18 | nheads: int 19 | max_position_embeddings: int 20 | dropout: float 21 | device: str 22 | use_kv_cache: bool 23 | # MLA parameters 24 | q_lora_rank: int 25 | kv_lora_rank: int 26 | rope_head_dim: int 27 | nope_head_dim: int 28 | v_head_dim: int 29 | rope_base: int 30 | rope_scaling: dict 31 | 32 | # MoE parameters 33 | num_shared_experts: int 34 | num_routed_experts: 
int 35 | topk: int 36 | moe_hidden_dimension: int 37 | mlp_hidden_dimension: int 38 | topk_norm_epsilon: float 39 | normalized_moe_gates: bool 40 | expert_load_balance_factor: float # alpha1 41 | rms_norm_eps: float 42 | first_k_dense_replace: int 43 | # DeepSeek model 44 | num_layers: int 45 | vocab_size: int 46 | init_weight_std: float 47 | -------------------------------------------------------------------------------- /training_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class TrainingConfig: 5 | device: str = "cuda" 6 | learning_rate: float = 5e-4 7 | min_learning_rate: float = 5e-5 8 | eval_iters: int = 5 9 | eval_interval: int = 10 10 | dtype: str = "bfloat16" 11 | measure_throughput_interval: int = 100 12 | estimate_throughput: bool = False 13 | wandb_log: bool = True 14 | wandb_project: str = "deepseek training" 15 | wandb_run_name: str = "8_bit_optimizer" 16 | adamw_use_fused: bool = True 17 | max_train_steps: int = 30000 18 | batch_size: int = 8 19 | gradient_accumulation_steps: int = 8 20 | warmup_iters: int = 500 21 | lr_decay_iters: int = 1000 22 | decay_lr: bool = True 23 | out_dir: str = "output" 24 | resume: bool = False 25 | checkpoint_path: str = "8_bit_optimizer_ckpt.pt" 26 | adamw_beta1: float = 0.9 27 | adamw_beta2: float = 0.95 28 | adamw_weight_decay: float = 0.1 29 | use_eight_bit_optimizer: bool = True 30 | grad_clip: float = 1.0 31 | model_config_path: str = "config.json" 32 | dataset_name: str = "HuggingFaceH4/ultrachat_200k" 33 | train_split: str = "train_sft" 34 | eval_split: str = "test_sft" 35 | tokenizer_type: str = "cl100k_base" 36 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def sample_top_p(logits: torch.Tensor, temperature: float, top_p: float): 4 | ''' 5 | logits: [batch_size, vocab_size] 6 | temperature: float > 0 7 | top_p: float between 0 and 1 8 | ''' 9 | logits = logits / temperature 10 | probs = torch.softmax(logits, dim=-1) 11 | sorted_probs, sorted_prob_indices = torch.sort(probs, dim=-1, descending=True) 12 | cum_probs = torch.cumsum(sorted_probs, dim=-1) 13 | sorted_removed_indices = cum_probs > top_p 14 | sorted_removed_indices[:, 1:] = sorted_removed_indices[:, :-1].clone() 15 | sorted_removed_indices[:, 0] = False # always keep the top first token 16 | # map the removed indices to the original logits 17 | removed_indices = sorted_removed_indices.scatter(dim=1, index=sorted_prob_indices, src=sorted_removed_indices) 18 | logits[removed_indices] = float('-inf') 19 | probs = torch.softmax(logits, dim=-1) 20 | # sample from the distribution 21 | return torch.multinomial(probs, num_samples=1) 22 | 23 | 24 | def sample_top_k(logits: torch.Tensor, temperature: float, k: int): 25 | ''' 26 | logits: [batch_size, vocab_size] 27 | temperature: float > 0 28 | k: int > 0 29 | ''' 30 | logits = logits / temperature 31 | top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1) 32 | masked_logits = torch.full_like(logits, float('-inf')) 33 | masked_logits.scatter_(dim=1, index=top_k_indices, src=top_k_logits) 34 | probs = torch.softmax(masked_logits, dim=-1) 35 | # sample from the distribution 36 | return torch.multinomial(probs, num_samples=1) 37 | 38 | 39 | if __name__ == '__main__': 40 | logits = torch.tensor([[1.0, 2.0, 3.0, 2.5], [1.0, 2.0, 3.0, 2.5]]) 41 | # print(sample_top_p(logits, 
1.0, 0.8)) 42 | print(sample_top_k(logits, 0.1, 2)) 43 | -------------------------------------------------------------------------------- /data/fineweb_edu/fineweb_edu.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import os 3 | import tiktoken 4 | import numpy as np 5 | import multiprocessing as mp 6 | from tqdm import tqdm 7 | dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train") 8 | 9 | local_dir = "edu_fineweb10B" 10 | remote_name = "sample-10BT" 11 | shard_size = int(1e8) # 100M tokens per shard, total of 100 shards 12 | fineweb_edu = "fineweb_edu" 13 | 14 | # create the cache the local directory if it doesn't exist yet 15 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir) 16 | os.makedirs(DATA_CACHE_DIR, exist_ok=True) 17 | 18 | # encode with tiktoken gpt4 bpe 19 | enc = tiktoken.get_encoding("cl100k_base") 20 | eot = enc._special_tokens['<|endoftext|>'] # end of text token 21 | 22 | def tokenize(doc): 23 | # tokenizes a single document and returns a numpy array of int32 tokens 24 | tokens = [eot] # the special <|endoftext|> token delimits all documents 25 | tokens.extend(enc.encode_ordinary(doc["text"])) 26 | return np.array(tokens, dtype=np.int32) 27 | 28 | def save_shard(filename, tokens): 29 | np.save(filename, tokens) 30 | 31 | nprocs = max(1, os.cpu_count()//2) 32 | 33 | with mp.Pool(nprocs) as pool: 34 | shard_idx = 0 35 | # pre-allocate the buffer for shard 36 | shard_buffer = np.empty(shard_size, dtype=np.int32) 37 | token_cnt = 0 38 | progress_bar = None 39 | # each process will handle a chunk size of 16 documents to tokenize them in parallel 40 | # tokens in the for loop are the tokenized results for 16 * nprocs documents 41 | for tokens in pool.imap(tokenize, dataset, chunksize=16): 42 | if token_cnt + len(tokens) < shard_size: 43 | shard_buffer[token_cnt:token_cnt+len(tokens)] = tokens 44 | token_cnt += len(tokens) 45 | if progress_bar is None: 46 | progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"shard: {shard_idx:06d}") 47 | progress_bar.update(len(tokens)) 48 | else: 49 | split = "val" if shard_idx == 0 else "train" 50 | reminder = shard_size - token_cnt 51 | progress_bar.update(reminder) 52 | shard_buffer[token_cnt:token_cnt+reminder] = tokens[:reminder] 53 | file_path = os.path.join(DATA_CACHE_DIR, f"{fineweb_edu}_{split}_{shard_idx:06d}") 54 | # save the full shard 55 | save_shard(file_path, shard_buffer) 56 | shard_idx += 1 57 | token_cnt = len(tokens) - reminder 58 | shard_buffer[:token_cnt] = tokens[reminder:] 59 | progress_bar = None 60 | if token_cnt > 0: 61 | split = "val" if shard_idx == 0 else "train" 62 | file_path = os.path.join(DATA_CACHE_DIR, f"{fineweb_edu}_{split}_{shard_idx:06d}") 63 | save_shard(file_path, shard_buffer[:token_cnt]) 64 | 65 | 66 | -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | import tiktoken 2 | 3 | # reference: https://github.com/openai/tiktoken 4 | class Tokenizer: 5 | def __init__(self, tokenizer_type: str): 6 | self.tokenizer_type = tokenizer_type 7 | self.tokenizer = tiktoken.get_encoding(tokenizer_type) 8 | self.allowed_special = set() 9 | self.eos_token_id = None 10 | 11 | def get_token_id(self, token: str): 12 | ''' 13 | get the token id 14 | ''' 15 | return self.tokenizer.encode(token) 16 | 17 | 18 | def get_token_str(self, token_id: int): 19 | ''' 20 
| get the token string 21 | ''' 22 | return self.tokenizer.decode([token_id]) 23 | 24 | def add_special_tokens(self, tokens: list[str]): 25 | ''' 26 | add the special tokens to the tokens 27 | ''' 28 | vocab_size = self.tokenizer.n_vocab 29 | 30 | special_tokens = self.tokenizer._special_tokens 31 | for i, token in enumerate(tokens): 32 | if token not in special_tokens: 33 | special_tokens[token] = vocab_size + i 34 | 35 | enc = tiktoken.Encoding( 36 | # If you're changing the set of special tokens, make sure to use a different name 37 | # It should be clear from the name what behaviour to expect. 38 | name=f"{self.tokenizer_type}_im", 39 | pat_str=self.tokenizer._pat_str, 40 | mergeable_ranks=self.tokenizer._mergeable_ranks, 41 | special_tokens=special_tokens 42 | ) 43 | self.tokenizer = enc 44 | self.allowed_special = set(special_tokens.keys()) 45 | self.eos_token_id = special_tokens["<|im_end|>"] 46 | 47 | def print_special_tokens(self): 48 | ''' 49 | print the special tokens 50 | ''' 51 | print(self.tokenizer._special_tokens) 52 | 53 | def get_vocab_size(self): 54 | ''' 55 | get the vocab size 56 | ''' 57 | return self.tokenizer.n_vocab 58 | 59 | def encode(self, text: str): 60 | ''' 61 | encode the text 62 | ''' 63 | return self.tokenizer.encode(text, allowed_special=self.allowed_special) 64 | 65 | def decode(self, token_ids: list[int]): 66 | ''' 67 | decode the token ids 68 | ''' 69 | return self.tokenizer.decode(token_ids) 70 | 71 | if __name__ == "__main__": 72 | tokenizer = Tokenizer(tokenizer_type="cl100k_base") 73 | text = "Hello, how are you?" 74 | token_ids = tokenizer.encode(text) 75 | print(f"token_ids: {token_ids}") 76 | 77 | token_strs = tokenizer.decode(token_ids) 78 | print(f"token_strs: {token_strs}") 79 | 80 | tokenizer.add_special_tokens(["<|im_start|>"]) 81 | text = "Hello, how are you? 
<|im_start|>" 82 | token_ids = tokenizer.encode(text) 83 | print(f"token_ids: {token_ids}") 84 | 85 | token_strs = tokenizer.decode(token_ids) 86 | print(f"token_strs: {token_strs}") 87 | 88 | -------------------------------------------------------------------------------- /sft.py: -------------------------------------------------------------------------------- 1 | # to format: black --line-length 88 sft.py 2 | 3 | from dataclasses import dataclass 4 | from tokenizer import Tokenizer 5 | import datasets 6 | from torch.utils.data import DataLoader, Dataset 7 | import json 8 | from training_config import TrainingConfig 9 | from trainer import Trainer 10 | from datacollator import DataCollatorForChatMl, ChatMlSpecialTokens 11 | 12 | # Define a custom PyTorch Dataset 13 | class SFTDataset(Dataset): 14 | def __init__(self, dataset): 15 | self.dataset = dataset 16 | 17 | def __len__(self): 18 | return len(self.dataset) 19 | 20 | def __getitem__(self, idx): 21 | return self.dataset[idx]["messages"] 22 | 23 | 24 | if __name__ == "__main__": 25 | sft_training_config_file = "sft.json" 26 | with open(sft_training_config_file, "r") as f: 27 | training_config = json.load(f) 28 | sft_training_config = TrainingConfig(**training_config) 29 | 30 | tokenizer = Tokenizer(sft_training_config.tokenizer_type) 31 | tokenizer.add_special_tokens( 32 | [ChatMlSpecialTokens().bos_token, ChatMlSpecialTokens().eos_token] 33 | ) 34 | # load huggingface dataset 35 | train_dataset = datasets.load_dataset(sft_training_config.dataset_name, split=sft_training_config.train_split) 36 | eval_dataset = datasets.load_dataset(sft_training_config.dataset_name, split=sft_training_config.eval_split) 37 | 38 | sft_train_dataset = SFTDataset(train_dataset) 39 | sft_eval_dataset = SFTDataset(eval_dataset) 40 | 41 | # create the datacollator 42 | data_collator = DataCollatorForChatMl( 43 | tokenizer, 44 | tokenizer.eos_token_id, 45 | # pytorch cross entropy loss will ignore labels with value -100 46 | # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html 47 | -100, 48 | ChatMlSpecialTokens().assistant, 49 | tokenizer.eos_token_id, 50 | ) 51 | 52 | # create dataloader 53 | sft_train_dataloader = DataLoader( 54 | sft_train_dataset, 55 | batch_size=sft_training_config.batch_size, 56 | shuffle=True, 57 | collate_fn=data_collator.process, 58 | ) 59 | sft_eval_dataloader = DataLoader( 60 | sft_eval_dataset, 61 | batch_size=sft_training_config.batch_size, 62 | shuffle=True, 63 | collate_fn=data_collator.process, 64 | ) 65 | 66 | # max_len = 0 67 | # train_batch_size = 0 68 | # eval_batch_size = 0 69 | # for batch in sft_train_dataloader: 70 | # input_ids = batch['input_ids'] 71 | # max_len = max(max_len, input_ids.shape[1]) 72 | # train_batch_size += 1 73 | # if train_batch_size % 100 == 0: 74 | # print(f"train_batch_size: {train_batch_size}") 75 | # for batch in sft_eval_dataloader: 76 | # input_ids = batch['input_ids'] 77 | # max_len = max(max_len, input_ids.shape[1]) 78 | # eval_batch_size += 1 79 | # if eval_batch_size % 100 == 0: 80 | # print(f"eval_batch_size: {eval_batch_size}") 81 | # print(f"max_len: {max_len}") 82 | 83 | sft_trainer = Trainer( 84 | sft_train_dataloader, 85 | sft_eval_dataloader, 86 | sft_training_config, 87 | ) 88 | sft_trainer.train() 89 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | 6 | 
def load_shard(file_path): 7 | ''' 8 | load a shard from the given path and convert to pytorch tensor 9 | ''' 10 | # need to use int64 so that torch cross entropy loss can work 11 | np_tensors = np.load(file_path).astype(np.int64) 12 | return torch.from_numpy(np_tensors) 13 | 14 | class DataLoader: 15 | def __init__(self, data_dir, batch_size, seq_len, split, device="cpu", shuffle=True): 16 | ''' 17 | data_dir: the directory of the data 18 | batch_size: the batch size 19 | seq_len: the sequence length 20 | split: the split of the data, train or val 21 | shuffle: whether to shuffle the shards 22 | ''' 23 | self.data_dir = data_dir 24 | self.batch_size = batch_size 25 | self.seq_len = seq_len 26 | self.split = split 27 | self.shuffle = shuffle 28 | self.device = device 29 | all_shards = os.listdir(data_dir) 30 | self.shards = [shard for shard in all_shards if split in shard] 31 | if shuffle: 32 | random.shuffle(self.shards) 33 | self.current_shard = 0 34 | self.current_pos = 0 35 | self.reset_status() 36 | 37 | 38 | def reset_status(self): 39 | ''' 40 | reset the status of the dataloader 41 | ''' 42 | self.current_shard = 0 43 | self.current_pos = 0 44 | if self.shuffle: 45 | random.shuffle(self.shards) 46 | self.tokens = load_shard(os.path.join(self.data_dir, self.shards[self.current_shard])) 47 | 48 | def next_batch(self): 49 | ''' 50 | return the next batch of data 51 | ''' 52 | batch_tokens = self.tokens[self.current_pos:self.current_pos+self.batch_size * self.seq_len+1] 53 | x = batch_tokens[:-1].view(self.batch_size, self.seq_len) 54 | y = batch_tokens[1:].view(self.batch_size, self.seq_len) 55 | self.current_pos += self.batch_size * self.seq_len 56 | if self.current_pos + self.batch_size * self.seq_len + 1 > len(self.tokens): 57 | self.current_shard = (self.current_shard + 1) % len(self.shards) 58 | self.current_pos = 0 59 | self.tokens = load_shard(os.path.join(self.data_dir, self.shards[self.current_shard])) 60 | return x.to(self.device), y.to(self.device) 61 | 62 | 63 | class FineWebEduDataLoader(DataLoader): 64 | def __init__(self, batch_size, seq_len, split, device="cpu", shuffle=True): 65 | data_dir = os.path.join(os.path.dirname(__file__), "data", "fineweb_edu", "edu_fineweb10B") 66 | print(f"Loading data from {data_dir}") 67 | super().__init__(data_dir, batch_size, seq_len, split, device, shuffle) 68 | 69 | class TinyShakespeareDataLoader(DataLoader): 70 | def __init__(self, batch_size, seq_len, split, device="cpu", shuffle=True): 71 | data_dir = os.path.join(os.path.dirname(__file__), "data", "tinyshakespeare") 72 | print(f"Loading data from {data_dir}") 73 | super().__init__(data_dir, batch_size, seq_len, split, device, shuffle) 74 | 75 | def reset_status(self): 76 | pass 77 | 78 | def next_batch(self): 79 | if self.split == "train": 80 | data = np.memmap(os.path.join(self.data_dir, "train.bin"), dtype=np.uint16, mode="r") 81 | else: 82 | data = np.memmap(os.path.join(self.data_dir, "val.bin"), dtype=np.uint16, mode="r") 83 | ix = torch.randint(len(data) - self.seq_len, (self.batch_size,)) 84 | x = torch.stack( 85 | [ 86 | torch.from_numpy((data[i : i + self.seq_len]).astype(np.int64)) 87 | for i in ix 88 | ] 89 | ) 90 | y = torch.stack( 91 | [ 92 | torch.from_numpy( 93 | (data[i + 1 : i + 1 + self.seq_len]).astype(np.int64) 94 | ) 95 | for i in ix 96 | ] 97 | ) 98 | if self.device == "cuda": 99 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 100 | # make the tensors in the non-pageable area 101 | x, y = 
x.pin_memory().to(self.device, non_blocking=True), y.pin_memory().to( 102 | self.device, non_blocking=True 103 | ) 104 | else: 105 | x, y = x.to(self.device), y.to(self.device) 106 | return x, y 107 | 108 | -------------------------------------------------------------------------------- /datacollator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tokenizer import Tokenizer 3 | import jinja2 4 | from dataclasses import dataclass 5 | # reference: https://github.com/huggingface/trl/blob/main/trl/models/utils.py#L44 6 | @dataclass 7 | class ChatMlSpecialTokens: 8 | """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens.""" 9 | 10 | bos_token: str = "<|im_start|>" 11 | eos_token: str = "<|im_end|>" 12 | pad_token: str = "<|im_end|>" 13 | 14 | @property 15 | def assistant(self): 16 | return f"{self.bos_token}assistant" 17 | 18 | @property 19 | def chat_template(self): 20 | """ 21 | the jinja2 template for the chatml format 22 | """ 23 | return ( 24 | "{% for message in messages %}" 25 | f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" 26 | "{% endfor %}" 27 | "{% if add_generation_prompt %}" 28 | f"{{{{ '{self.assistant}\n' }}}}" 29 | "{% endif %}" 30 | ) 31 | 32 | 33 | def format_input_text(input: list[dict[str, str]], add_generation_prompt: bool = False): 34 | """ 35 | format the input text with chat template to differentiate among different roles 36 | 37 | an exmple of input: 38 | [ 39 | {"role": "user", "content": "Hello, how are you?"}, 40 | {"role": "assistant", "content": "I'm fine, thank you!"}, 41 | {"role": "user", "content": "What is the capital of France?"}, 42 | ] 43 | 44 | when add_generation_prompt is False, the output should be: 45 | <|im_start|>user 46 | Hello, how are you?<|im_end|> 47 | <|im_start|>assistant 48 | I'm fine, thank you!<|im_end|> 49 | <|im_start|>user 50 | What is the capital of France?<|im_end|> 51 | 52 | when add_generation_prompt is True, the output should be: 53 | <|im_start|>user 54 | Hello, how are you?<|im_end|> 55 | <|im_start|>assistant 56 | I'm fine, thank you!<|im_end|> 57 | <|im_start|>assistant 58 | """ 59 | 60 | template = jinja2.Template(ChatMlSpecialTokens().chat_template) 61 | return template.render(messages=input, add_generation_prompt=add_generation_prompt) 62 | 63 | def pad(examples: list[torch.Tensor], pad_value: int): 64 | """ 65 | pad the input text to the max length of the batch 66 | """ 67 | max_length = max(len(example) for example in examples) 68 | padded_examples = [ 69 | torch.cat([example, torch.full((max_length - len(example),), pad_value)]) 70 | for example in examples 71 | ] 72 | return padded_examples 73 | 74 | class DataCollatorForChatMl: 75 | """ 76 | The data collator will primary do three things: 77 | 1. format the input text with chat template 78 | 2. pad the input text to the max length of the batch 79 | 3. 
set the label values for the non-assistant response tokens as ignore_index 80 | """ 81 | 82 | def __init__( 83 | self, 84 | tokenizer: Tokenizer, 85 | pad_token_id: int, 86 | ignore_index: int, 87 | assistant_response_format: str, 88 | end_token_id: int, 89 | ): 90 | self.tokenizer = tokenizer 91 | self.pad_token_id = pad_token_id 92 | self.ignore_index = ignore_index 93 | self.assistant_response_format = assistant_response_format 94 | self.end_token_id = end_token_id 95 | 96 | def process(self, examples: list[list[dict[str, str]]]): 97 | """ 98 | process a batch of examples 99 | """ 100 | formatted_examples = [format_input_text(example) for example in examples] 101 | tokenized_examples = [ 102 | self.tokenizer.encode(example) for example in formatted_examples 103 | ] 104 | input_ids = [torch.tensor(example[:-1]) for example in tokenized_examples] 105 | attention_mask = [torch.ones_like(input_id) for input_id in input_ids] 106 | labels = [torch.tensor(example[1:]) for example in tokenized_examples] 107 | 108 | input_ids = pad(input_ids, self.pad_token_id) 109 | attention_mask = pad(attention_mask, 0) 110 | labels = pad(labels, self.ignore_index) 111 | # mask out the non-assistant response tokens in labels 112 | labels = self.mask_labels(labels) 113 | batch = { 114 | "input_ids": torch.stack(input_ids), 115 | "labels": torch.stack(labels), 116 | "attention_mask": torch.stack(attention_mask), 117 | } 118 | return batch 119 | 120 | def mask_labels(self, labels: list[torch.Tensor]): 121 | """ 122 | mask the labels for the non-assistant response tokens 123 | """ 124 | response_ids = self.tokenizer.encode(self.assistant_response_format) 125 | for label in labels: 126 | start_ind = 0 127 | prev_assistant_response = False 128 | i = 0 129 | while i < len(label): 130 | if i < len(label) - len(response_ids) and torch.equal( 131 | label[i : i + len(response_ids)], torch.tensor(response_ids) 132 | ): 133 | label[start_ind : i + len(response_ids)] = self.ignore_index 134 | i += len(response_ids) 135 | prev_assistant_response = True 136 | elif ( 137 | torch.equal(label[i], torch.tensor(self.end_token_id)) 138 | and prev_assistant_response 139 | ): 140 | start_ind = i + 1 141 | prev_assistant_response = False 142 | i += 1 143 | else: 144 | i += 1 145 | label[start_ind:] = self.ignore_index 146 | return labels 147 | 148 | if __name__ == "__main__": 149 | tokenizer = Tokenizer('cl100k_base') 150 | tokenizer.add_special_tokens( 151 | [ChatMlSpecialTokens().bos_token, ChatMlSpecialTokens().eos_token] 152 | ) 153 | messages = [ 154 | [ 155 | {"role": "user", "content": "Hello, how are you?"}, 156 | {"role": "assistant", "content": "I'm fine, thank you!"}, 157 | {"role": "user", "content": "What is the capital of France?"}, 158 | ], 159 | [ 160 | {"role": "user", "content": "what is up?"}, 161 | {"role": "assistant", "content": "not much, just chilling"}, 162 | {"role": "user", "content": "what is your name?"}, 163 | ], 164 | ] 165 | 166 | formatted_input = format_input_text(messages[0]) 167 | print(f"formatted_input: {formatted_input}") 168 | 169 | token_ids = tokenizer.encode(formatted_input) 170 | print(f"token_ids: {token_ids}") 171 | 172 | token_strs = tokenizer.decode(token_ids) 173 | print(f"token_strs: {token_strs}") 174 | 175 | bos = ChatMlSpecialTokens().bos_token 176 | bos_encoded = tokenizer.encode(bos) 177 | print(f"bos_encoded: {bos_encoded}") 178 | 179 | eos = ChatMlSpecialTokens().eos_token 180 | eos_encoded = tokenizer.encode(eos)[0] 181 | print(f"eos_encoded: {eos_encoded}") 182 | 183 | 
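# Rough sketch of the label masking implemented in mask_labels above (expected behaviour,
# not asserted here): for messages[0], only the assistant turn "I'm fine, thank you!" and
# its closing <|im_end|> keep their token ids as labels; both user turns, the chat-template
# scaffolding and the "<|im_start|>assistant" marker itself are set to the ignore_index
# (-100), so cross entropy is computed only on assistant-response tokens.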
data_collator = DataCollatorForChatMl( 184 | tokenizer, 185 | tokenizer.eos_token_id, 186 | -100, 187 | ChatMlSpecialTokens().assistant, 188 | tokenizer.eos_token_id, 189 | ) 190 | batch = data_collator.process(messages) 191 | print(f"batch: {batch}") -------------------------------------------------------------------------------- /deepseek.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from moe import MoE, FeedForward 4 | from mla import MultiHeadLatentAttention 5 | from config import DeepSeekConfig 6 | import torch.nn.functional as F 7 | from typing import Optional 8 | from mla import KVCache 9 | import json 10 | 11 | 12 | class Block(nn.Module): 13 | def __init__(self, config: DeepSeekConfig, block_idx: int): 14 | super().__init__() 15 | self.self_attn = MultiHeadLatentAttention(config, layer_idx=block_idx) 16 | self.mlp = ( 17 | MoE(config) 18 | if block_idx >= config.first_k_dense_replace 19 | else FeedForward(config.d_model, config.mlp_hidden_dimension) 20 | ) 21 | self.input_layernorm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps) 22 | self.post_attention_layernorm = nn.RMSNorm( 23 | config.d_model, eps=config.rms_norm_eps 24 | ) 25 | 26 | def forward( 27 | self, 28 | x: torch.tensor, 29 | past_key_value: Optional[KVCache] = None, 30 | attention_mask: Optional[torch.tensor] = None, 31 | ): 32 | """ 33 | args: 34 | x: (B, T, d_model) 35 | past_key_value (KVCache, optional): when it is None, KV cache will not be used 36 | attention_mask (torch.tensor, optional): (B, T) 37 | return: 38 | x: (B, T, d_model) 39 | past_key_value (KVCache, optional): None or updated KVCache. 40 | """ 41 | identity = x 42 | x, past_key_value = self.self_attn( 43 | self.input_layernorm(x), past_key_value, attention_mask 44 | ) 45 | if past_key_value is not None: 46 | print(f"past_key_value shape: {past_key_value.key_cache[0].shape}") 47 | x = identity + x 48 | x = x + self.mlp(self.post_attention_layernorm(x)) 49 | return x, past_key_value 50 | 51 | 52 | class DeepSeekModel(nn.Module): 53 | def __init__(self, config: DeepSeekConfig): 54 | super().__init__() 55 | self.layers = nn.ModuleList( 56 | [Block(config, i) for i in range(config.num_layers)] 57 | ) 58 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) 59 | self.dropout = nn.Dropout(config.dropout) 60 | self.max_position_embeddings = config.max_position_embeddings 61 | self.init_weight_std = config.init_weight_std 62 | self.norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps) 63 | 64 | def forward( 65 | self, 66 | x: torch.tensor, 67 | past_key_value: Optional[KVCache] = None, 68 | attention_mask: Optional[torch.tensor] = None, 69 | ) -> torch.tensor: 70 | """ 71 | Args: 72 | x: (B, T) 73 | targets: (B, T) 74 | past_key_value: (KVCache, optional) 75 | attention_mask: (B, T) 76 | return: (B, T, vocab_size) 77 | """ 78 | B, T = x.shape 79 | assert ( 80 | T <= self.max_position_embeddings 81 | ), f"Sequence length {T} cannot exceed block size {self.max_position_embeddings}" 82 | 83 | x = self.embed_tokens(x) 84 | x = self.dropout(x) 85 | for layer in self.layers: 86 | x, past_key_value = layer(x, past_key_value, attention_mask) 87 | return self.norm(x), past_key_value 88 | 89 | 90 | class DeepSeekModelForCausalLM(nn.Module): 91 | def __init__(self, config: DeepSeekConfig): 92 | super().__init__() 93 | self.topk = config.topk 94 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 95 | self.model = DeepSeekModel(config) 
96 | self.config = config 97 | # initialize weights 98 | self.init_weight_std = config.init_weight_std 99 | self.apply(self._init_weights) 100 | 101 | def _init_weights(self, module): 102 | if isinstance(module, nn.Linear): 103 | module.weight.data.normal_(mean=0.0, std=self.init_weight_std) 104 | if module.bias is not None: 105 | module.bias.data.zero_() 106 | elif isinstance(module, nn.Embedding): 107 | module.weight.data.normal_(mean=0.0, std=self.init_weight_std) 108 | 109 | def get_total_parameters(self): 110 | total_params = 0 111 | activated_params = 0 112 | routed_moe_module_name = "mlp.experts" 113 | activated_routed_moe_module_name = [ 114 | f"{routed_moe_module_name}.{i}" for i in range(self.topk) 115 | ] 116 | 117 | def is_activated_routed_moe_module(name: str): 118 | for activated_routed_moe_module in activated_routed_moe_module_name: 119 | if name.find(activated_routed_moe_module) != -1: 120 | return True 121 | return False 122 | 123 | for name, param in self.named_parameters(): 124 | if param.requires_grad: 125 | total_params += param.numel() 126 | if not name.find( 127 | routed_moe_module_name 128 | ) == -1 or is_activated_routed_moe_module(name): 129 | activated_params += param.numel() 130 | return total_params, activated_params 131 | 132 | def forward( 133 | self, 134 | x: torch.tensor, 135 | targets: torch.tensor = None, 136 | past_key_value: Optional[KVCache] = None, 137 | attention_mask: Optional[torch.tensor] = None, 138 | ) -> torch.tensor: 139 | """ 140 | Args: 141 | x: (B, T) 142 | targets: (B, T) 143 | past_key_value: (KVCache, optional) 144 | attention_mask: (B, T) 145 | """ 146 | x, past_key_value = self.model(x, past_key_value, attention_mask) 147 | if targets is not None: 148 | logits = self.lm_head(x) 149 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 150 | else: 151 | # for inference, only the last token logits is used for prediction the next token 152 | logits = self.lm_head(x[:, [-1], :]) 153 | loss = None 154 | return logits, loss, past_key_value 155 | 156 | @torch.no_grad() 157 | def generate(self, input: torch.tensor, max_length: int, temperature: float = 1.0): 158 | x = input 159 | kv_cache = KVCache(self.config.num_layers) 160 | for _ in range(max_length): 161 | # use kv cache 162 | logits, _, kv_cache = self(x, past_key_value=kv_cache) 163 | # [B, vocab_size] 164 | logits = logits[:, -1, :] / temperature 165 | probs = F.softmax(logits, dim=-1) 166 | next_token = torch.multinomial(probs, num_samples=1) 167 | x = next_token 168 | input = torch.cat([input, next_token], dim=1) 169 | return input 170 | 171 | 172 | if __name__ == "__main__": 173 | with open("config.json", "r") as f: 174 | config = json.load(f) 175 | config = DeepSeekConfig(**config) 176 | input = torch.randint(0, config.vocab_size, (2, 2)).to(config.device) 177 | targets = torch.randint(0, config.vocab_size, (2, 2)).to(config.device) 178 | model = DeepSeekModelForCausalLM(config).to(config.device) 179 | output, loss, past_key_value = model(input, targets) 180 | print(f"when targets is not None: output shape: {output.shape}, loss: {loss}") 181 | targets = None 182 | model.eval() 183 | past_key_value = KVCache(config.num_layers) 184 | output, loss, past_key_value = model(input, targets, past_key_value) 185 | print("-" * 100) 186 | print("When targets is None") 187 | print(f"output shape: {output.shape}") 188 | print(f"loss: {loss}") 189 | print(f"KV Cache shape: {past_key_value.key_cache[0].shape}") 190 | print("-" * 100) 191 | print("State dict") 192 | sd = 
model.state_dict() 193 | sd_keys = sd.keys() 194 | for key in sd_keys: 195 | print(f"{key}: {sd[key].shape}") 196 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import nullcontext 3 | import os 4 | from training_config import TrainingConfig 5 | from torch.utils.data import DataLoader 6 | import json 7 | from deepseek import DeepSeekModelForCausalLM, DeepSeekConfig 8 | import bitsandbytes as bnb 9 | import math 10 | import torch.nn as nn 11 | 12 | ptdtype = { 13 | "float32": torch.float32, 14 | "bfloat16": torch.bfloat16, 15 | "float16": torch.float16, 16 | } 17 | 18 | 19 | def get_model_config(model_config_path: str): 20 | with open(model_config_path, "r") as f: 21 | config = json.load(f) 22 | return config 23 | 24 | 25 | def get_wandb_config(training_config: TrainingConfig): 26 | config = { 27 | "batch_size": training_config.batch_size, 28 | "learning_rate": training_config.learning_rate, 29 | "use_fused_adamw": training_config.adamw_use_fused, 30 | } 31 | config.update(get_model_config(training_config.model_config_path)) 32 | return config 33 | 34 | 35 | def get_model(model_config_path: str): 36 | config = get_model_config(model_config_path) 37 | model = DeepSeekModelForCausalLM(DeepSeekConfig(**config)) 38 | total_params, activated_params = model.get_total_parameters() 39 | print(f"Total parameters: {total_params:,}") 40 | print(f"Activated parameters: {activated_params:,}") 41 | print(f"Activated parameters ratio: {activated_params / total_params:.2%}") 42 | return model 43 | 44 | 45 | def configure_optimizers( 46 | model, weight_decay, learning_rate, betas, fused, use_eight_bit_optimizer=False 47 | ): 48 | # start with all of the candidate parameters 49 | param_dict = {pn: p for pn, p in model.named_parameters()} 50 | # filter out those that do not require grad 51 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 52 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwis no. 53 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
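# Concretely for this model: nn.Linear weights, the token embedding table and the 2-D MoE
# gate matrix (experts_weights) fall into the decayed group, while RMSNorm weights and any
# other 1-D parameters do not.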
54 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 55 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 56 | optim_groups = [ 57 | {"params": decay_params, "weight_decay": weight_decay}, 58 | {"params": nodecay_params, "weight_decay": 0.0}, 59 | ] 60 | num_decay_params = sum(p.numel() for p in decay_params) 61 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 62 | print( 63 | f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters" 64 | ) 65 | print( 66 | f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters" 67 | ) 68 | # Create AdamW optimizer and use the fused version if it is available 69 | if use_eight_bit_optimizer: 70 | # fuse is not supported 71 | optimizer = bnb.optim.AdamW8bit(optim_groups, lr=learning_rate, betas=betas) 72 | else: 73 | optimizer = torch.optim.AdamW( 74 | optim_groups, lr=learning_rate, betas=betas, fused=fused 75 | ) 76 | return optimizer 77 | 78 | 79 | # learning rate decay scheduler (cosine with warmup) 80 | def get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr): 81 | # 1) linear warmup for warmup_iters steps 82 | if it < warmup_iters: 83 | return learning_rate * (it + 1) / (warmup_iters + 1) 84 | # 2) if it > lr_decay_iters, return min learning rate 85 | if it > lr_decay_iters: 86 | return min_lr 87 | # 3) in between, use cosine decay down to min learning rate 88 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 89 | assert 0 <= decay_ratio <= 1 90 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 91 | return min_lr + coeff * (learning_rate - min_lr) 92 | 93 | 94 | @torch.no_grad() 95 | def evalaute( 96 | model: nn.Module, 97 | eval_iters: int, 98 | train_dataloader: DataLoader, 99 | eval_dataloader: DataLoader, 100 | cur_train_iter: iter, 101 | cur_eval_iter: iter, 102 | device: str, 103 | ): 104 | model.eval() 105 | loss_iters = torch.zeros(eval_iters) 106 | losses = {} 107 | for i in range(eval_iters): 108 | eval_batch, cur_eval_iter = get_next_batch(eval_dataloader, cur_eval_iter) 109 | _, loss, _ = model( 110 | x = eval_batch["input_ids"].to(device), 111 | targets = eval_batch["labels"].to(device), 112 | past_key_value = None, 113 | attention_mask = eval_batch["attention_mask"].to(device), 114 | ) 115 | loss_iters[i] = loss.item() 116 | losses["eval"] = loss_iters.mean() 117 | for i in range(eval_iters): 118 | train_batch, cur_train_iter = get_next_batch(train_dataloader, cur_train_iter) 119 | _, loss, _ = model( 120 | x = train_batch["input_ids"].to(device), 121 | targets = train_batch["labels"].to(device), 122 | past_key_value = None, 123 | attention_mask = train_batch["attention_mask"].to(device), 124 | ) 125 | loss_iters[i] = loss.item() 126 | losses["train"] = loss_iters.mean() 127 | model.train() 128 | return losses, cur_train_iter, cur_eval_iter 129 | 130 | 131 | def get_next_batch(dataloader, current_iter): 132 | try: 133 | batch = next(current_iter) 134 | except StopIteration: 135 | # Reset the iterator 136 | current_iter = iter(dataloader) 137 | batch = next(current_iter) 138 | 139 | return batch, current_iter 140 | 141 | 142 | class Trainer: 143 | def __init__( 144 | self, 145 | train_dataloader: DataLoader, 146 | eval_dataloader: DataLoader, 147 | training_config: TrainingConfig, 148 | ): 149 | global ptdtype 150 | self.training_config = training_config 151 | self.train_dataloader = train_dataloader 152 | self.eval_dataloader = eval_dataloader 153 | 
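# Note on mixed precision (the configs in this repo use bfloat16): under torch.amp.autocast
# matmuls run in the reduced dtype while parameters stay in fp32; bfloat16 needs no loss
# scaling, whereas float16 would normally be paired with a GradScaler, which this trainer
# does not use.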
self.ctx = ( 154 | nullcontext() 155 | if training_config.device == "cpu" 156 | # mixed precision training 157 | else torch.amp.autocast( 158 | device_type=training_config.device, dtype=ptdtype[training_config.dtype] 159 | ) 160 | ) 161 | 162 | def train(self): 163 | if self.training_config.wandb_log: 164 | import wandb 165 | 166 | wandb.init( 167 | project=self.training_config.wandb_project, 168 | name=self.training_config.wandb_run_name, 169 | config=get_wandb_config(self.training_config), 170 | ) 171 | model = get_model(self.training_config.model_config_path) 172 | model.to(self.training_config.device) 173 | optimizer = configure_optimizers( 174 | model, 175 | self.training_config.adamw_weight_decay, 176 | self.training_config.learning_rate, 177 | (self.training_config.adamw_beta1, self.training_config.adamw_beta2), 178 | self.training_config.adamw_use_fused, 179 | self.training_config.use_eight_bit_optimizer, 180 | ) 181 | optimizer.zero_grad(set_to_none=True) 182 | best_val_loss = 1e9 183 | iter_num = 0 184 | if self.training_config.resume: 185 | checkpoint = torch.load( 186 | os.path.join(self.training_config.out_dir, "ckpt.pt") 187 | ) 188 | config = checkpoint["model_config"] 189 | model = DeepSeekModelForCausalLM(config) 190 | model.to(self.training_config.device) 191 | model.load_state_dict(checkpoint["model"]) 192 | optimizer.load_state_dict(checkpoint["optimizer"]) 193 | best_val_loss = checkpoint["best_val_loss"] 194 | iter_num = checkpoint["iter_num"] 195 | 196 | cur_train_iter = iter(self.train_dataloader) 197 | cur_eval_iter = iter(self.eval_dataloader) 198 | while iter_num < self.training_config.max_train_steps: 199 | # determine and set the learning rate for this iteration 200 | lr = ( 201 | get_lr( 202 | iter_num, 203 | self.training_config.warmup_iters, 204 | self.training_config.lr_decay_iters, 205 | self.training_config.learning_rate, 206 | self.training_config.min_learning_rate, 207 | ) 208 | if self.training_config.decay_lr 209 | else self.training_config.learning_rate 210 | ) 211 | for param_group in optimizer.param_groups: 212 | param_group["lr"] = lr 213 | for _ in range(self.training_config.gradient_accumulation_steps): 214 | batch, cur_train_iter = get_next_batch( 215 | self.train_dataloader, cur_train_iter 216 | ) 217 | with self.ctx: 218 | _, train_loss, _ = model( 219 | x = batch["input_ids"].to(self.training_config.device), 220 | targets = batch["labels"].to(self.training_config.device), 221 | past_key_value = None, 222 | attention_mask = batch["attention_mask"].to(self.training_config.device), 223 | ) 224 | train_loss = ( 225 | train_loss / self.training_config.gradient_accumulation_steps 226 | ) 227 | train_loss.backward() 228 | if self.training_config.grad_clip > 0: 229 | torch.nn.utils.clip_grad_norm_( 230 | model.parameters(), self.training_config.grad_clip 231 | ) 232 | optimizer.step() 233 | optimizer.zero_grad(set_to_none=True) 234 | if (iter_num + 1) % self.training_config.eval_interval == 0: 235 | losses, cur_train_iter, cur_eval_iter = evalaute( 236 | model, 237 | self.training_config.eval_iters, 238 | self.train_dataloader, 239 | self.eval_dataloader, 240 | cur_train_iter, 241 | cur_eval_iter, 242 | self.training_config.device, 243 | ) 244 | if self.training_config.wandb_log: 245 | wandb.log( 246 | { 247 | "Step": iter_num, 248 | "Train Loss": losses["train"], 249 | "Val Loss": losses["eval"], 250 | "Learning Rate": lr, 251 | } 252 | ) 253 | if losses["eval"] < best_val_loss: 254 | best_val_loss = losses["eval"] 255 | if iter_num > 0: 256 | 
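# note: checkpoints are written to out_dir/checkpoint_path (e.g. sft_output/sft_ckpt.pt),
# while the resume branch above loads out_dir/ckpt.pt, so the two file names need to be
# kept in sync for resume=True to pick this checkpoint back up.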
checkpoint = { 257 | "model": model.state_dict(), 258 | "optimizer": optimizer.state_dict(), 259 | "model_config": model.config, 260 | "iter_num": iter_num, 261 | "best_val_loss": best_val_loss, 262 | "training_config": self.training_config, 263 | } 264 | torch.save( 265 | checkpoint, 266 | os.path.join( 267 | self.training_config.out_dir, 268 | self.training_config.checkpoint_path, 269 | ), 270 | ) 271 | print( 272 | f"step {iter_num+1}: train loss: {losses['train']:.4f}, val loss: {losses['eval']:.4f}" 273 | ) 274 | iter_num += 1 275 | -------------------------------------------------------------------------------- /moe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import RMSNorm 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from config import DeepSeekConfig 7 | import json 8 | 9 | class Distributor(object): 10 | def __init__(self, gates: torch.tensor, topk: int): 11 | super().__init__() 12 | # gates is [B*T, num_experts] 13 | self.topk = topk 14 | # [B*T*topk, 2] 15 | batch_and_experts_indices = torch.nonzero(gates) 16 | # sort the batch and experts indices along the first dimension, batch_and_experts_indices is a list 17 | # of tuples, where the first element is the batch index and the second element is the expert index 18 | # by sorting along the first dimension, we will let the same assigned expert index be adjacent 19 | # and then we will use this order to reorder the input tensors 20 | # then finally after splitting the recordered input tensors, the first group is for first expert, 21 | # the second group is for second expert, etc. 22 | # use stable sort to guarantee the relative order of the same element so that we could get the weight for the expert output 23 | # with colume-wise order from left to right 24 | # [B*topk, 2] 25 | sorted_experts, index_sorted_experts = batch_and_experts_indices.sort( 26 | dim=0, stable=True 27 | ) 28 | # get the order indices before sorting 29 | # [B*T*topk] one dimension tensor 30 | old_expert_indices = index_sorted_experts[:, 1] 31 | # find the batch index from the order of sorted experts 32 | # it will be used for the input tensors to make sure the tokens that assigned to the same expert are adjacent 33 | # and then use the _groups to split the input tensors 34 | # [B*T*topk] one dimension tensor 35 | self._batch_indices = batch_and_experts_indices[:, 0][old_expert_indices] 36 | # get the number of tokens assigned for each expert 37 | # [num_experts] one dimension tensor 38 | self._groups = (gates > 0).sum(dim=0).tolist() 39 | # get the weights for each expert output. 
It just get the non zero elements from the gates for each column from left to right 40 | # [B*T*topk, 1] 41 | self._weights = gates.t().reshape(-1)[gates.t().reshape(-1) > 0].view(-1, 1) 42 | 43 | def prepare_inputs_for_experts(self, x: torch.tensor) -> list[torch.tensor]: 44 | expanded_x = x[self._batch_indices] 45 | return expanded_x.split(self._groups) 46 | 47 | def combine(self, expert_outputs: list[torch.tensor]) -> torch.tensor: 48 | # [B*topk, d_model] 49 | combined_output = torch.cat(expert_outputs, dim=0) 50 | # apply the weights to the expert outputs 51 | # [B*topk, d_model] 52 | combined_output = combined_output * self._weights 53 | # use index_add to add results for each token and the index is _batch_indices 54 | # [B, d_model] 55 | output = torch.zeros( 56 | combined_output.shape[0] // self.topk, 57 | combined_output.shape[1], 58 | dtype=combined_output.dtype, 59 | ).to(combined_output.device) 60 | output.index_add_(0, self._batch_indices, combined_output) 61 | return output 62 | 63 | 64 | class FeedForward(nn.Module): 65 | def __init__(self, d_model: int, hidden_dimension: int): 66 | super().__init__() 67 | self.gate_proj = nn.Linear(d_model, hidden_dimension, bias=False) 68 | self.up_proj = nn.Linear(d_model, hidden_dimension, bias=False) 69 | self.down_proj = nn.Linear(hidden_dimension, d_model, bias=False) 70 | self.activation = nn.SiLU() 71 | 72 | def forward(self, x: torch.Tensor): 73 | # this is different from the FeedForward in Transformer paper 74 | # not sure why DeepSeek use the gate_proj instead of activation(up_proj(x)) 75 | # to be able to load the checkpoint, follow their implementation at 76 | # https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L389 77 | x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) 78 | return x 79 | 80 | 81 | class AddAuxiliaryLoss(torch.autograd.Function): 82 | """ 83 | Copied from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py#L500 84 | """ 85 | 86 | @staticmethod 87 | def forward(ctx, x, loss): 88 | assert loss.numel() == 1 89 | ctx.dtype = loss.dtype 90 | ctx.required_aux_loss = loss.requires_grad 91 | return x 92 | 93 | @staticmethod 94 | def backward(ctx, grad_output): 95 | grad_loss = None 96 | if ctx.required_aux_loss: 97 | # when we requuired aux loss, this grad loss is for the gradient of the second input of forward 98 | # which is the auxiliary loss 99 | # effectively since the grad is 1, the aux loss is added to the loss 100 | grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) 101 | return grad_output, grad_loss 102 | 103 | 104 | class MoE(nn.Module): 105 | def __init__(self, config: DeepSeekConfig): 106 | super().__init__() 107 | 108 | self.num_shared_experts = config.num_shared_experts 109 | self.moe_hidden_dimension = config.moe_hidden_dimension 110 | 111 | self.topk = config.topk 112 | self.num_routed_experts = config.num_routed_experts 113 | 114 | # the weights intialization for deepseek is 0.006 115 | # from https://arxiv.org/abs/2401.06066 116 | # the number of potential routed experts is total_num_experts * num_smaller_experts_per_expert - num_shared_experts 117 | self.experts_weights = nn.Parameter( 118 | torch.randn( 119 | self.num_routed_experts, 120 | config.d_model, 121 | ) 122 | * 0.006 123 | ) 124 | 125 | # routed experts 126 | self.experts = nn.ModuleList( 127 | [ 128 | FeedForward(config.d_model, self.moe_hidden_dimension) 129 | for _ in range(self.num_routed_experts) 130 | ] 131 | ) 132 | 133 | 
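# Tiny worked example of the Distributor defined above (illustrative numbers only):
# with topk=2 and gates = [[0.6, 0.0, 0.4],
#                          [0.0, 0.7, 0.3]]
# torch.nonzero(gates) gives [[0,0],[0,2],[1,1],[1,2]]; stable-sorting the expert column
# yields _batch_indices = [0, 1, 0, 1] and _groups = [1, 1, 2], i.e. expert 0 sees token 0,
# expert 1 sees token 1, expert 2 sees both tokens; _weights = [0.6, 0.7, 0.4, 0.3]
# (the non-zero gates read column by column), and combine() index_add_'s the weighted
# expert outputs back into their original token rows.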
self.shared_experts = FeedForward( 134 | config.d_model, self.moe_hidden_dimension * self.num_shared_experts 135 | ) 136 | 137 | self.topk_norm_epsilon = config.topk_norm_epsilon 138 | self.normalized_moe_gates = config.normalized_moe_gates 139 | self.expert_load_balance_factor = config.expert_load_balance_factor 140 | 141 | def forward(self, x: torch.tensor) -> torch.tensor: 142 | """ 143 | x: tensor of shape [B, T, d_model] 144 | """ 145 | return self._forward_optimized(x) 146 | 147 | def _forward_forloop(self, x: torch.tensor): 148 | """ 149 | x: tensor of shape [B, T, d_model] 150 | """ 151 | B, T = x.shape[0], x.shape[1] 152 | # [B* T, d_model] 153 | x = x.view(B * T, -1) 154 | # first get the output for the routed MoE and then added up the results from shared MoE 155 | # [B * T, total_routed_experts] 156 | routed_experts_output = F.linear(x, self.experts_weights) 157 | scores = F.softmax(routed_experts_output, dim=-1) 158 | # apply gate along the expert dimension to get the top k experts for each token 159 | # top_values: B * T, topk, this is the score for each expert 160 | # top_indices: B * T, topk, this is the index to find the corresponding expert 161 | top_values, top_indices = torch.topk(scores, k=self.topk, dim=-1, sorted=False) 162 | routed_experts_output = torch.zeros_like(x, dtype=x.dtype).to(x.device) 163 | for i in range(x.shape[0]): 164 | for j in range(self.topk): 165 | routed_experts_output[i] += ( 166 | self.experts[top_indices[i, j]](x[i]) * top_values[i, j] 167 | ) 168 | 169 | shared_experts_output = self.shared_experts(x) 170 | 171 | # the output is sum of shared expert output and routed expert output plus the residual connection 172 | output = routed_experts_output + shared_experts_output 173 | output = output.view(B, T, -1) # [B, T, d_model] 174 | 175 | return output 176 | 177 | def _forward_optimized(self, x: torch.tensor): 178 | """ 179 | In the for loop implementation, the expert will transform the input tensor one by one. 180 | One optimization is to batch all input tensors for a given expert together and let the expert transform them with matrix multiplication. 181 | 182 | So in this implementation, we will 183 | 1. first batch input tensors for each expert 184 | 2. loop through each expert and transform its input tensors with matrix multiplication 185 | 3. 
for each token, since it might be routed to multiple experts, we need to sum up its resutls with index_add function 186 | 187 | The reference for this implementation is 188 | https://github.com/davidmrau/mixture-of-experts/blob/master/moe.py 189 | 190 | Args: 191 | x: tensor of shape [B, T, d_model] 192 | Returns: 193 | output: tensor of shape [B, T, d_model] 194 | """ 195 | 196 | # combine the batch and time dimension 197 | B, T = x.shape[0], x.shape[1] 198 | # [B* T, d_model] 199 | x = x.view(B * T, -1) 200 | gates = F.linear(x, self.experts_weights) 201 | gates = F.softmax(gates, dim=-1) 202 | top_values, top_indices = torch.topk(gates, k=self.topk, dim=-1, sorted=False) 203 | # [B * T, num_experts] 204 | masked_gates = torch.zeros_like(gates, dtype=gates.dtype).to(gates.device) 205 | masked_gates = torch.scatter(masked_gates, 1, top_indices, top_values) 206 | if self.normalized_moe_gates: 207 | # renormalize the masked gates 208 | masked_gates = masked_gates / ( 209 | masked_gates.sum(dim=-1, keepdim=True) + self.topk_norm_epsilon 210 | ) 211 | distributor = Distributor(masked_gates, self.topk) 212 | routed_expert_inputs = distributor.prepare_inputs_for_experts(x) 213 | routed_expert_outputs = [ 214 | self.experts[i](routed_expert_inputs[i]) 215 | for i in range(self.num_routed_experts) 216 | ] 217 | # [B*T, d_model] 218 | routed_combined_outputs = distributor.combine(routed_expert_outputs) 219 | routed_combined_outputs = routed_combined_outputs.view(B, T, -1) 220 | if self.training: 221 | # get the expert load balance loss. The definition can be found in https://arxiv.org/abs/2401.06066 222 | masked_gates = masked_gates.view(B, T, -1) 223 | gates = gates.view(B, T, -1) 224 | load = (masked_gates > 0).sum(dim=1) 225 | expert_prob_sum = gates.sum(dim=1) 226 | expert_load_balance_loss = self.expert_load_balance_factor * ( 227 | (self.num_routed_experts / (self.topk * T) * load) 228 | * (1.0 / T * expert_prob_sum) 229 | ).sum(dim=1) 230 | expert_load_balance_loss = expert_load_balance_loss.mean() 231 | routed_combined_outputs = AddAuxiliaryLoss.apply( 232 | routed_combined_outputs, expert_load_balance_loss 233 | ) 234 | shared_expert_outputs = self.shared_experts(x).view(B, T, -1) 235 | output = routed_combined_outputs + shared_expert_outputs 236 | return output 237 | 238 | 239 | if __name__ == "__main__": 240 | with open("config.json", "r") as f: 241 | config = json.load(f) 242 | config = DeepSeekConfig(**config) 243 | input = torch.randn(2, 2, 1024).to(config.device) 244 | model = MoE(config) 245 | model = model.to(config.device) 246 | output = model(input) 247 | print(f"MoE output shape: {output.shape}") 248 | sd = model.state_dict() 249 | sd_keys = sd.keys() 250 | for key in sd_keys: 251 | print(f"{key}: {sd[key].shape}") 252 | -------------------------------------------------------------------------------- /pre_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import json 6 | import tiktoken 7 | from config import DeepSeekConfig 8 | from deepseek import DeepSeekModelForCausalLM 9 | import torch.nn as nn 10 | import time 11 | import math 12 | from contextlib import nullcontext 13 | import bitsandbytes as bnb 14 | from enum import Enum 15 | from dataloader import TinyShakespeareDataLoader, FineWebEduDataLoader, DataLoader 16 | 17 | 18 | class DatasetName(Enum): 19 | TINY_SHAKESPEARE = "tinyshakespeare" 20 | FINE_WEB_EDU = "fineweb_edu" 21 | 22 | ptdtype = { 23 | 
"float32": torch.float32, 24 | "bfloat16": torch.bfloat16, 25 | "float16": torch.float16, 26 | } 27 | 28 | 29 | def configure_optimizers( 30 | model, weight_decay, learning_rate, betas, fused, use_eight_bit_optimizer=False 31 | ): 32 | # start with all of the candidate parameters 33 | param_dict = {pn: p for pn, p in model.named_parameters()} 34 | # filter out those that do not require grad 35 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 36 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwis no. 37 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 38 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 39 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 40 | optim_groups = [ 41 | {"params": decay_params, "weight_decay": weight_decay}, 42 | {"params": nodecay_params, "weight_decay": 0.0}, 43 | ] 44 | num_decay_params = sum(p.numel() for p in decay_params) 45 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 46 | print( 47 | f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters" 48 | ) 49 | print( 50 | f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters" 51 | ) 52 | # Create AdamW optimizer and use the fused version if it is available 53 | if use_eight_bit_optimizer: 54 | # fuse is not supported 55 | optimizer = bnb.optim.AdamW8bit( 56 | optim_groups, lr=learning_rate, betas=betas 57 | ) 58 | else: 59 | optimizer = torch.optim.AdamW( 60 | optim_groups, lr=learning_rate, betas=betas, fused=fused 61 | ) 62 | return optimizer 63 | 64 | 65 | # learning rate decay scheduler (cosine with warmup) 66 | def get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr): 67 | # 1) linear warmup for warmup_iters steps 68 | if it < warmup_iters: 69 | return learning_rate * (it + 1) / (warmup_iters + 1) 70 | # 2) if it > lr_decay_iters, return min learning rate 71 | if it > lr_decay_iters: 72 | return min_lr 73 | # 3) in between, use cosine decay down to min learning rate 74 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 75 | assert 0 <= decay_ratio <= 1 76 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 77 | return min_lr + coeff * (learning_rate - min_lr) 78 | 79 | 80 | def get_dataloader(dataset_name, split, batch_size, seq_len, device): 81 | if dataset_name == DatasetName.TINY_SHAKESPEARE.value: 82 | dataloader = TinyShakespeareDataLoader(batch_size, seq_len, split, device) 83 | elif dataset_name == DatasetName.FINE_WEB_EDU.value: 84 | dataloader = FineWebEduDataLoader(batch_size, seq_len, split, device) 85 | else: 86 | raise ValueError(f"Dataset {dataset_name} is not supported") 87 | return dataloader 88 | 89 | 90 | @torch.no_grad() 91 | def evalaute( 92 | model: nn.Module, 93 | eval_iters: int, 94 | eval_dataloader: DataLoader, 95 | ): 96 | model.eval() 97 | losses = torch.zeros(eval_iters) 98 | for i in range(eval_iters): 99 | x, y = eval_dataloader.next_batch() 100 | _, loss, _ = model(x, y) 101 | losses[i] = loss.item() 102 | model.train() 103 | return losses.mean() 104 | 105 | 106 | def get_model_config(args): 107 | with open("config.json", "r") as f: 108 | config = json.load(f) 109 | return config 110 | 111 | 112 | def get_model(args): 113 | config = get_model_config(args) 114 | model = DeepSeekModelForCausalLM(DeepSeekConfig(**config)) 115 | model.to(args.device) 116 | total_params, activated_params = 
model.get_total_parameters() 117 | print(f"Total parameters: {total_params:,}") 118 | print(f"Activated parameters: {activated_params:,}") 119 | print(f"Activated parameters ratio: {activated_params / total_params:.2%}") 120 | return model 121 | 122 | 123 | def forward_and_backward(model, x, y, optimizer, ctx: torch.autocast = nullcontext()): 124 | with ctx: 125 | _, loss, _ = model(x, y) 126 | loss.backward() 127 | optimizer.step() 128 | optimizer.zero_grad(set_to_none=True) 129 | 130 | 131 | def train(args): 132 | global ptdtype 133 | ptdtype = ptdtype[args.dtype] 134 | ctx = ( 135 | nullcontext() 136 | if args.device == "cpu" 137 | else torch.amp.autocast(device_type=args.device, dtype=ptdtype) 138 | ) 139 | if args.wandb_log: 140 | import wandb 141 | 142 | wandb.init( 143 | project=args.wandb_project, 144 | name=args.wandb_run_name, 145 | config=get_wandb_config(args), 146 | ) 147 | train_loader = get_dataloader(args.dataset, "train", args.batch_size, args.max_position_embeddings, args.device) 148 | val_loader = get_dataloader(args.dataset, "val", args.batch_size, args.max_position_embeddings, args.device) 149 | model = get_model(args) 150 | optimizer = configure_optimizers( 151 | model, 152 | args.adamw_weight_decay, 153 | args.learning_rate, 154 | (args.adamw_beta1, args.adamw_beta2), 155 | args.adamw_use_fused, 156 | args.use_eight_bit_optimizer, 157 | ) 158 | optimizer.zero_grad(set_to_none=True) 159 | best_val_loss = 1e9 160 | iter_num = 0 161 | if args.resume: 162 | checkpoint = torch.load(os.path.join(args.out_dir, "ckpt.pt")) 163 | config = checkpoint["model_config"] 164 | model = DeepSeekModelForCausalLM(config) 165 | model.to(args.device) 166 | model.load_state_dict(checkpoint["model"]) 167 | optimizer.load_state_dict(checkpoint["optimizer"]) 168 | best_val_loss = checkpoint["best_val_loss"] 169 | iter_num = checkpoint["iter_num"] 170 | while iter_num < args.max_train_steps: 171 | # determine and set the learning rate for this iteration 172 | lr = ( 173 | get_lr( 174 | iter_num, 175 | args.warmup_iters, 176 | args.lr_decay_iters, 177 | args.learning_rate, 178 | args.min_learning_rate, 179 | ) 180 | if args.decay_lr 181 | else args.learning_rate 182 | ) 183 | for param_group in optimizer.param_groups: 184 | param_group["lr"] = lr 185 | for _ in range(args.gradient_accumulation_steps): 186 | x, y = train_loader.next_batch() 187 | with ctx: 188 | _, train_loss, _ = model(x, y) 189 | train_loss = train_loss / args.gradient_accumulation_steps 190 | train_loss.backward() 191 | if args.grad_clip > 0: 192 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 193 | optimizer.step() 194 | optimizer.zero_grad(set_to_none=True) 195 | if (iter_num + 1) % args.eval_interval == 0: 196 | val_loss = evalaute( 197 | model, 198 | args.eval_iters, 199 | val_loader, 200 | ) 201 | if args.wandb_log: 202 | wandb.log( 203 | { 204 | "Step": iter_num, 205 | "Train Loss": train_loss.item(), 206 | "Val Loss": val_loss, 207 | "Learning Rate": lr, 208 | } 209 | ) 210 | if val_loss < best_val_loss: 211 | best_val_loss = val_loss 212 | if iter_num > 0: 213 | checkpoint = { 214 | "model": model.state_dict(), 215 | "optimizer": optimizer.state_dict(), 216 | "model_config": model.config, 217 | "iter_num": iter_num, 218 | "best_val_loss": best_val_loss, 219 | "training_args": args, 220 | } 221 | torch.save( 222 | checkpoint, os.path.join(args.out_dir, args.checkpoint_path) 223 | ) 224 | print( 225 | f"step {iter_num+1}: train loss: {train_loss.item():.4f}, val loss: {val_loss:.4f}" 226 | ) 227 | 
iter_num += 1 228 | 229 | 230 | def get_wandb_config(args): 231 | config = { 232 | "batch_size": args.batch_size, 233 | "learning_rate": args.learning_rate, 234 | "use_fused_adamw": args.adamw_use_fused, 235 | } 236 | config.update(get_model_config(args)) 237 | return config 238 | 239 | 240 | def estimate_throughput(args): 241 | if args.wandb_log: 242 | import wandb 243 | 244 | wandb.init( 245 | project=args.wandb_project, 246 | name=args.wandb_run_name, 247 | config=get_wandb_config(args), 248 | ) 249 | data_dir = os.path.join("data", args.dataset) 250 | model = get_model(args) 251 | optimizer = torch.optim.AdamW( 252 | model.parameters(), lr=args.learning_rate, fused=args.adamw_use_fused 253 | ) 254 | total_tokens = 0 255 | total_time = 0 256 | torch.cuda.synchronize() 257 | start_time = time.time() 258 | train_loader = get_dataloader(args.dataset, "train", args.batch_size, args.max_position_embeddings, args.device) 259 | for i in range(args.max_train_steps): 260 | x, y = train_loader.next_batch() 261 | forward_and_backward(model, x, y, optimizer) 262 | total_tokens += x.shape[0] * x.shape[1] 263 | torch.cuda.synchronize() 264 | end_time = time.time() 265 | throughput = total_tokens / (end_time - start_time) 266 | if args.wandb_log: 267 | wandb.log({"Training Throughput": throughput}) 268 | else: 269 | print(f"Training throughput: {throughput:.2f} tokens/s") 270 | 271 | 272 | def main(args): 273 | if not args.estimate_throughput: 274 | train(args) 275 | else: 276 | estimate_throughput(args) 277 | 278 | 279 | if __name__ == "__main__": 280 | args = argparse.ArgumentParser() 281 | args.add_argument("--device", type=str, default="cuda") 282 | args.add_argument("--dataset", type=str, default="fineweb_edu") 283 | 284 | args.add_argument("--warmup-steps", type=int, default=20) 285 | args.add_argument("--learning-rate", type=float, default=5e-4) 286 | args.add_argument("--min-learning-rate", type=float, default=5e-5) 287 | args.add_argument("--max-position-embeddings", type=int, default=512) 288 | 289 | args.add_argument("--eval-iters", type=int, default=5) 290 | args.add_argument("--eval-interval", type=int, default=10) 291 | args.add_argument("--dtype", type=str, default="bfloat16") 292 | args.add_argument("--measure-throughput-interval", type=int, default=100) 293 | args.add_argument("--estimate-throughput", type=bool, default=False) 294 | 295 | args.add_argument("--wandb-log", type=bool, default=True) 296 | args.add_argument("--wandb-project", type=str, default="deepseek training") 297 | args.add_argument( 298 | "--wandb-run-name", type=str, default="8_bit_optimizer" 299 | ) 300 | 301 | args.add_argument("--adamw-use-fused", type=bool, default=True) 302 | 303 | args.add_argument("--max-train-steps", type=int, default=30000) 304 | args.add_argument("--batch-size", type=int, default=8) 305 | args.add_argument("--gradient-accumulation-steps", type=int, default=8) 306 | args.add_argument("--warmup-iters", type=int, default=500) 307 | args.add_argument("--lr-decay-iters", type=int, default=1000) 308 | args.add_argument("--decay-lr", type=bool, default=True) 309 | 310 | args.add_argument("--out-dir", type=str, default="output") 311 | args.add_argument("--resume", type=bool, default=False) 312 | args.add_argument( 313 | "--checkpoint-path", 314 | type=str, 315 | default="8_bit_optimizer_ckpt.pt", 316 | ) 317 | 318 | # adamw arguments 319 | args.add_argument("--adamw-beta1", type=float, default=0.9) 320 | args.add_argument("--adamw-beta2", type=float, default=0.95) 321 | 
args.add_argument("--adamw-weight-decay", type=float, default=0.1) 322 | args.add_argument("--use-eight-bit-optimizer", type=bool, default=True) 323 | 324 | args.add_argument("--grad-clip", type=float, default=1.0) 325 | 326 | main(args.parse_args()) 327 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import json 6 | import tiktoken 7 | from config import DeepSeekConfig 8 | from deepseek import DeepSeekModelForCausalLM 9 | import torch.nn as nn 10 | import time 11 | import math 12 | from contextlib import nullcontext 13 | import bitsandbytes as bnb 14 | from enum import Enum 15 | from dataloader import TinyShakespeareDataLoader, FineWebEduDataLoader, DataLoader 16 | 17 | 18 | class DatasetName(Enum): 19 | TINY_SHAKESPEARE = "tinyshakespeare" 20 | FINE_WEB_EDU = "fineweb_edu" 21 | 22 | ptdtype = { 23 | "float32": torch.float32, 24 | "bfloat16": torch.bfloat16, 25 | "float16": torch.float16, 26 | } 27 | 28 | 29 | def configure_optimizers( 30 | model, weight_decay, learning_rate, betas, fused, use_eight_bit_optimizer=False 31 | ): 32 | # start with all of the candidate parameters 33 | param_dict = {pn: p for pn, p in model.named_parameters()} 34 | # filter out those that do not require grad 35 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 36 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwis no. 37 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 38 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 39 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 40 | optim_groups = [ 41 | {"params": decay_params, "weight_decay": weight_decay}, 42 | {"params": nodecay_params, "weight_decay": 0.0}, 43 | ] 44 | num_decay_params = sum(p.numel() for p in decay_params) 45 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 46 | print( 47 | f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters" 48 | ) 49 | print( 50 | f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters" 51 | ) 52 | # Create AdamW optimizer and use the fused version if it is available 53 | if use_eight_bit_optimizer: 54 | # fuse is not supported 55 | optimizer = bnb.optim.AdamW8bit( 56 | optim_groups, lr=learning_rate, betas=betas 57 | ) 58 | else: 59 | optimizer = torch.optim.AdamW( 60 | optim_groups, lr=learning_rate, betas=betas, fused=fused 61 | ) 62 | return optimizer 63 | 64 | 65 | # learning rate decay scheduler (cosine with warmup) 66 | def get_lr(it, warmup_iters, lr_decay_iters, learning_rate, min_lr): 67 | # 1) linear warmup for warmup_iters steps 68 | if it < warmup_iters: 69 | return learning_rate * (it + 1) / (warmup_iters + 1) 70 | # 2) if it > lr_decay_iters, return min learning rate 71 | if it > lr_decay_iters: 72 | return min_lr 73 | # 3) in between, use cosine decay down to min learning rate 74 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 75 | assert 0 <= decay_ratio <= 1 76 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 77 | return min_lr + coeff * (learning_rate - min_lr) 78 | 79 | 80 | def get_dataloader(dataset_name, split, batch_size, seq_len, device): 81 | if dataset_name == DatasetName.TINY_SHAKESPEARE.value: 82 | dataloader = 
TinyShakespeareDataLoader(batch_size, seq_len, split, device) 83 | elif dataset_name == DatasetName.FINE_WEB_EDU.value: 84 | dataloader = FineWebEduDataLoader(batch_size, seq_len, split, device) 85 | else: 86 | raise ValueError(f"Dataset {dataset_name} is not supported") 87 | return dataloader 88 | 89 | 90 | @torch.no_grad() 91 | def evalaute( 92 | model: nn.Module, 93 | eval_iters: int, 94 | eval_dataloader: DataLoader, 95 | ): 96 | model.eval() 97 | losses = torch.zeros(eval_iters) 98 | for i in range(eval_iters): 99 | x, y = eval_dataloader.next_batch() 100 | _, loss, _ = model(x, y) 101 | losses[i] = loss.item() 102 | model.train() 103 | return losses.mean() 104 | 105 | 106 | def get_model_config(args): 107 | with open("config.json", "r") as f: 108 | config = json.load(f) 109 | return config 110 | 111 | 112 | def get_model(args): 113 | config = get_model_config(args) 114 | model = DeepSeekModelForCausalLM(DeepSeekConfig(**config)) 115 | model.to(args.device) 116 | total_params, activated_params = model.get_total_parameters() 117 | print(f"Total parameters: {total_params:,}") 118 | print(f"Activated parameters: {activated_params:,}") 119 | print(f"Activated parameters ratio: {activated_params / total_params:.2%}") 120 | return model 121 | 122 | 123 | def forward_and_backward(model, x, y, optimizer, ctx: torch.autocast = nullcontext()): 124 | with ctx: 125 | _, loss, _ = model(x, y) 126 | loss.backward() 127 | optimizer.step() 128 | optimizer.zero_grad(set_to_none=True) 129 | 130 | 131 | def train(args): 132 | global ptdtype 133 | ptdtype = ptdtype[args.dtype] 134 | ctx = ( 135 | nullcontext() 136 | if args.device == "cpu" 137 | else torch.amp.autocast(device_type=args.device, dtype=ptdtype) 138 | ) 139 | if args.wandb_log: 140 | import wandb 141 | 142 | wandb.init( 143 | project=args.wandb_project, 144 | name=args.wandb_run_name, 145 | config=get_wandb_config(args), 146 | ) 147 | train_loader = get_dataloader(args.dataset, "train", args.batch_size, args.max_position_embeddings, args.device) 148 | val_loader = get_dataloader(args.dataset, "val", args.batch_size, args.max_position_embeddings, args.device) 149 | model = get_model(args) 150 | optimizer = configure_optimizers( 151 | model, 152 | args.adamw_weight_decay, 153 | args.learning_rate, 154 | (args.adamw_beta1, args.adamw_beta2), 155 | args.adamw_use_fused, 156 | args.use_eight_bit_optimizer, 157 | ) 158 | optimizer.zero_grad(set_to_none=True) 159 | best_val_loss = 1e9 160 | iter_num = 0 161 | if args.resume: 162 | checkpoint = torch.load(os.path.join(args.out_dir, "ckpt.pt")) 163 | config = checkpoint["model_config"] 164 | model = DeepSeekModelForCausalLM(config) 165 | model.to(args.device) 166 | model.load_state_dict(checkpoint["model"]) 167 | optimizer.load_state_dict(checkpoint["optimizer"]) 168 | best_val_loss = checkpoint["best_val_loss"] 169 | iter_num = checkpoint["iter_num"] 170 | while iter_num < args.max_train_steps: 171 | # determine and set the learning rate for this iteration 172 | lr = ( 173 | get_lr( 174 | iter_num, 175 | args.warmup_iters, 176 | args.lr_decay_iters, 177 | args.learning_rate, 178 | args.min_learning_rate, 179 | ) 180 | if args.decay_lr 181 | else args.learning_rate 182 | ) 183 | for param_group in optimizer.param_groups: 184 | param_group["lr"] = lr 185 | for _ in range(args.gradient_accumulation_steps): 186 | x, y = train_loader.next_batch() 187 | with ctx: 188 | _, train_loss, _ = model(x, y) 189 | train_loss = train_loss / args.gradient_accumulation_steps 190 | train_loss.backward() 191 | if 
args.grad_clip > 0: 192 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 193 | optimizer.step() 194 | optimizer.zero_grad(set_to_none=True) 195 | if (iter_num + 1) % args.eval_interval == 0: 196 | val_loss = evalaute( 197 | model, 198 | args.eval_iters, 199 | val_loader, 200 | ) 201 | if args.wandb_log: 202 | wandb.log( 203 | { 204 | "Step": iter_num, 205 | "Train Loss": train_loss.item(), 206 | "Val Loss": val_loss, 207 | "Learning Rate": lr, 208 | } 209 | ) 210 | if val_loss < best_val_loss: 211 | best_val_loss = val_loss 212 | if iter_num > 0: 213 | checkpoint = { 214 | "model": model.state_dict(), 215 | "optimizer": optimizer.state_dict(), 216 | "model_config": model.config, 217 | "iter_num": iter_num, 218 | "best_val_loss": best_val_loss, 219 | "training_args": args, 220 | } 221 | torch.save( 222 | checkpoint, os.path.join(args.out_dir, args.checkpoint_path) 223 | ) 224 | print( 225 | f"step {iter_num+1}: train loss: {train_loss.item():.4f}, val loss: {val_loss:.4f}" 226 | ) 227 | iter_num += 1 228 | 229 | 230 | def get_wandb_config(args): 231 | config = { 232 | "batch_size": args.batch_size, 233 | "learning_rate": args.learning_rate, 234 | "use_fused_adamw": args.adamw_use_fused, 235 | } 236 | config.update(get_model_config(args)) 237 | return config 238 | 239 | 240 | def estimate_throughput(args): 241 | if args.wandb_log: 242 | import wandb 243 | 244 | wandb.init( 245 | project=args.wandb_project, 246 | name=args.wandb_run_name, 247 | config=get_wandb_config(args), 248 | ) 249 | data_dir = os.path.join("data", args.dataset) 250 | model = get_model(args) 251 | optimizer = torch.optim.AdamW( 252 | model.parameters(), lr=args.learning_rate, fused=args.adamw_use_fused 253 | ) 254 | total_tokens = 0 255 | total_time = 0 256 | torch.cuda.synchronize() 257 | start_time = time.time() 258 | train_loader = get_dataloader(args.dataset, "train", args.batch_size, args.max_position_embeddings, args.device) 259 | for i in range(args.max_train_steps): 260 | x, y = train_loader.next_batch() 261 | forward_and_backward(model, x, y, optimizer) 262 | total_tokens += x.shape[0] * x.shape[1] 263 | torch.cuda.synchronize() 264 | end_time = time.time() 265 | throughput = total_tokens / (end_time - start_time) 266 | if args.wandb_log: 267 | wandb.log({"Training Throughput": throughput}) 268 | else: 269 | print(f"Training throughput: {throughput:.2f} tokens/s") 270 | 271 | 272 | def main(args): 273 | if not args.estimate_throughput: 274 | train(args) 275 | else: 276 | estimate_throughput(args) 277 | 278 | 279 | if __name__ == "__main__": 280 | args = argparse.ArgumentParser() 281 | args.add_argument("--device", type=str, default="cuda") 282 | args.add_argument("--dataset", type=str, default="fineweb_edu") 283 | 284 | args.add_argument("--warmup-steps", type=int, default=20) 285 | args.add_argument("--learning-rate", type=float, default=5e-4) 286 | args.add_argument("--min-learning-rate", type=float, default=5e-5) 287 | args.add_argument("--max-position-embeddings", type=int, default=512) 288 | 289 | args.add_argument("--eval-iters", type=int, default=5) 290 | args.add_argument("--eval-interval", type=int, default=10) 291 | args.add_argument("--dtype", type=str, default="bfloat16") 292 | args.add_argument("--measure-throughput-interval", type=int, default=100) 293 | args.add_argument("--estimate-throughput", type=bool, default=False) 294 | 295 | args.add_argument("--wandb-log", type=bool, default=True) 296 | args.add_argument("--wandb-project", type=str, default="deepseek training") 297 | 
args.add_argument( 298 | "--wandb-run-name", type=str, default="8_bit_optimizer" 299 | ) 300 | 301 | args.add_argument("--adamw-use-fused", type=bool, default=True) 302 | 303 | args.add_argument("--max-train-steps", type=int, default=30000) 304 | args.add_argument("--batch-size", type=int, default=8) 305 | args.add_argument("--gradient-accumulation-steps", type=int, default=8) 306 | args.add_argument("--warmup-iters", type=int, default=500) 307 | args.add_argument("--lr-decay-iters", type=int, default=1000) 308 | args.add_argument("--decay-lr", type=bool, default=True) 309 | 310 | args.add_argument("--out-dir", type=str, default="output") 311 | args.add_argument("--resume", type=bool, default=False) 312 | args.add_argument( 313 | "--checkpoint-path", 314 | type=str, 315 | default="8_bit_optimizer_ckpt.pt", 316 | ) 317 | 318 | # adamw arguments 319 | args.add_argument("--adamw-beta1", type=float, default=0.9) 320 | args.add_argument("--adamw-beta2", type=float, default=0.95) 321 | args.add_argument("--adamw-weight-decay", type=float, default=0.1) 322 | args.add_argument("--use-eight-bit-optimizer", type=bool, default=True) 323 | 324 | args.add_argument("--grad-clip", type=float, default=1.0) 325 | 326 | main(args.parse_args()) 327 | -------------------------------------------------------------------------------- /mla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import RMSNorm 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from config import DeepSeekConfig 7 | from typing import Optional 8 | import math 9 | import json 10 | 11 | 12 | # https://arxiv.org/abs/2104.09864 13 | # look at https://github.com/kingoflolz/mesh-transformer-jax/blob/f2aa66e0925de6593dcbb70e72399b97b4130482/mesh_transformer/layers.py#L144 14 | # to see if we could use this implementation 15 | def apply_original_rope( 16 | q: torch.Tensor, 17 | k: torch.Tensor, 18 | positions: torch.Tensor, 19 | cos: torch.Tensor, 20 | sin: torch.Tensor, 21 | ) -> torch.Tensor: 22 | """ 23 | Apply rope to the q and k 24 | 25 | Args: 26 | q: the query tensor of shape (batch_size, nheads, seq_len, rope_head_dim) 27 | k: the key tensor of shape (batch_size, nheads, seq_len, rope_head_dim) 28 | positions: the positions of the input tensor 29 | cos: the cosine part of the rope embedding, shape (seq_len, rope_head_dim / 2) 30 | sin: the sine part of the rope embedding, shape (seq_len, rope_head_dim / 2) 31 | Returns: 32 | q and k after applying rope 33 | """ 34 | q_even = q[..., 0::2] 35 | q_odd = q[..., 1::2] 36 | k_even = k[..., 0::2] 37 | k_odd = k[..., 1::2] 38 | cos = cos[positions] 39 | sin = sin[positions] 40 | q_rotated = torch.zeros_like(q) 41 | k_rotated = torch.zeros_like(k) 42 | # in deepseek implementation, x_even * cos_embed - x_odd * sin_embed is first half 43 | # x_odd * cos_embed + x_even * sin_embed is second half 44 | # https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L339 45 | q_rotated[..., 0::2] = q_even * cos - q_odd * sin 46 | q_rotated[..., 1::2] = q_odd * cos + q_even * sin 47 | k_rotated[..., 0::2] = k_even * cos - k_odd * sin 48 | k_rotated[..., 1::2] = k_odd * cos + k_even * sin 49 | return q_rotated, k_rotated 50 | 51 | 52 | def rotate_half(x): 53 | x1 = x[..., : x.shape[-1] // 2] 54 | x2 = x[..., x.shape[-1] // 2 :] 55 | return torch.cat((-x2, x1), dim=-1) 56 | 57 | 58 | # https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L339 59 | 
def apply_deepseek_rope( 60 | q: torch.Tensor, 61 | k: torch.Tensor, 62 | positions: torch.Tensor, 63 | cos: torch.Tensor, 64 | sin: torch.Tensor, 65 | ) -> torch.Tensor: 66 | """ 67 | Apply rope to the q and k. 68 | 69 | The difference between the original rope and the deepseek rope is that the deepseek rope 70 | 71 | Args: 72 | q: the query tensor of shape (batch_size, nheads, seq_len, rope_head_dim) 73 | k: the key tensor of shape (batch_size, nheads, seq_len, rope_head_dim) 74 | positions: the positions of the input tensor 75 | cos: the cosine part of the rope embedding, shape (seq_len, rope_head_dim) 76 | sin: the sine part of the rope embedding, shape (seq_len, rope_head_dim) 77 | Returns: 78 | q and k after applying rope 79 | """ 80 | cos = cos[positions] 81 | sin = sin[positions] 82 | 83 | b, h, s, d = q.shape 84 | # transform q and k so that the even indices elements are the first half 85 | # the odd indices elements are the second half 86 | q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 87 | # the sequence length for q and k might be different due to kv cache 88 | b, h, s, d = k.shape 89 | k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 90 | 91 | q_rotated = (q * cos) + (rotate_half(q) * sin) 92 | k_rotated = (k * cos) + (rotate_half(k) * sin) 93 | return q_rotated, k_rotated 94 | 95 | 96 | """ 97 | Get the cos and sin for the rope embeddings 98 | """ 99 | 100 | 101 | class RotaryEmbedding(nn.Module): 102 | def __init__( 103 | self, 104 | rope_head_dim: int, 105 | max_position_embeddings: int, 106 | base: int = 10000, 107 | device: str = "cuda", 108 | ): 109 | super().__init__() 110 | self.rope_head_dim = rope_head_dim 111 | self.max_position_embeddings = max_position_embeddings 112 | self.base = base 113 | self.device = device 114 | 115 | inv_freqs = 1.0 / ( 116 | base 117 | ** (torch.arange(0, rope_head_dim, 2).float().to(device) / rope_head_dim) 118 | ) 119 | self.register_buffer("inv_freqs", inv_freqs, persistent=False) 120 | self.max_seq_len_cached = None 121 | self._set_cos_sin_cache( 122 | seq_len=max_position_embeddings, 123 | device=self.inv_freqs.device, 124 | dtype=self.inv_freqs.dtype, 125 | ) 126 | 127 | def _set_cos_sin_cache(self, seq_len, device, dtype): 128 | self.max_seq_len_cached = seq_len 129 | t = torch.arange(seq_len, device=device, dtype=dtype) 130 | freqs = torch.outer(t, self.inv_freqs.to(t.device)) 131 | emb = torch.cat((freqs, freqs), dim=-1) 132 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 133 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 134 | 135 | def forward(self, x, seq_len=None): 136 | if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: 137 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 138 | return self.cos_cached[:seq_len], self.sin_cached[:seq_len] 139 | 140 | 141 | # proposed by reddit user /u/kaiokendev 142 | # https://www.reddit.com/r/LocalLLaMA/comments/14fgjqj/a_simple_way_to_extending_context_to_8k/ 143 | # concurrent work from meta: https://arxiv.org/abs/2306.15595 144 | class PositinalInterpolationRotaryEmbedding(RotaryEmbedding): 145 | def __init__( 146 | self, 147 | rope_head_dim: int, 148 | max_position_embeddings: int, 149 | base: int = 10000, 150 | device: str = "cuda", 151 | scaling_factor: float = 1.0, 152 | ): 153 | self.scaling_factor = scaling_factor 154 | super().__init__(rope_head_dim, max_position_embeddings, base, device) 155 | 156 | def _set_cos_sin_cache(self, seq_len, 
device, dtype): 157 | self.max_seq_len_cached = seq_len 158 | t = torch.arange(seq_len, device=device, dtype=self.inv_freqs.dtype) 159 | t = t / self.scaling_factor 160 | freqs = torch.outer(t, self.inv_freqs) 161 | emb = torch.cat((freqs, freqs), dim=-1) 162 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 163 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 164 | 165 | 166 | class NTKAwareRotaryEmbedding(RotaryEmbedding): 167 | def __init__( 168 | self, 169 | rope_head_dim: int, 170 | max_position_embeddings: int, 171 | base: int = 10000, 172 | device: str = "cuda", 173 | alpha: float = 1.0, 174 | ): 175 | base = base * alpha ** (rope_head_dim / (rope_head_dim - 2)) 176 | super().__init__(rope_head_dim, max_position_embeddings, base, device) 177 | 178 | 179 | class DynamicNTKAwareScalingRotaryEmbedding(RotaryEmbedding): 180 | def __init__( 181 | self, 182 | rope_head_dim: int, 183 | max_position_embeddings: int, 184 | base: int = 10000, 185 | device: str = "cuda", 186 | scaling_factor: float = 1.0, 187 | ): 188 | super().__init__(rope_head_dim, max_position_embeddings, base, device) 189 | self.scaling_factor = scaling_factor 190 | 191 | def _set_cos_sin_cache(self, seq_len, device, dtype): 192 | self.max_seq_len_cached = seq_len 193 | if seq_len > self.max_position_embeddings: 194 | # (self.scaling_factor * seq_len / self.max_position_embeddings)- (self.scaling_factor - 1) 195 | # is the same as 196 | # scaling_factor * (seq_len - max_position_embeddings) / max_position_embeddings + 1 197 | # which makes more sense and easier to understand 198 | base = self.base * ( 199 | (self.scaling_factor * seq_len / self.max_position_embeddings) 200 | - (self.scaling_factor - 1) 201 | ) ** (self.rope_head_dim / (self.rope_head_dim - 2)) 202 | inv_freqs = 1.0 / ( 203 | base 204 | ** ( 205 | torch.arange(0, self.rope_head_dim, 2).float().to(device) 206 | / self.rope_head_dim 207 | ) 208 | ) 209 | self.register_buffer("inv_freqs", inv_freqs, persistent=False) 210 | t = torch.arange(seq_len, device=device, dtype=dtype) 211 | freqs = torch.outer(t, self.inv_freqs) 212 | emb = torch.cat((freqs, freqs), dim=-1) 213 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 214 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 215 | 216 | 217 | def find_yarn_dim(ratio, training_context_length, rope_head_dim, base): 218 | # ratio = training_context_length / wave_length 219 | # wave_length = 2 * math.pi * base ** (2 * d / rope_head_dim) 220 | # this is to solve the above equations to find d 221 | return ( 222 | rope_head_dim 223 | * math.log(training_context_length / (2 * math.pi * ratio)) 224 | / (2 * math.log(base)) 225 | ) 226 | 227 | 228 | def find_yarn_cut_dims(alpha, beta, training_context_length, rope_head_dim, base): 229 | low = find_yarn_dim(beta, training_context_length, rope_head_dim, base) 230 | low = math.floor(low) 231 | high = find_yarn_dim(alpha, training_context_length, rope_head_dim, base) 232 | high = math.ceil(high) 233 | return max(low, 0), min(high, rope_head_dim - 1) 234 | 235 | 236 | def yarn_ramp_mask(min, max, dim): 237 | if min == max: 238 | max += 0.001 239 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 240 | ramp_func = torch.clamp(linear_func, 0, 1) 241 | return ramp_func 242 | 243 | 244 | def get_yarn_mscale(scaling_factor, attn_factor): 245 | if scaling_factor <= 1: 246 | return 1.0 247 | return (0.1 * math.log(scaling_factor) + 1.0) * attn_factor 248 | 
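# Illustrative worked example (a sketch only, assuming the rope_scaling values in config.json:
# base=10000, rope_head_dim=64, training_context_length=128, alpha=1, beta=32):
#   find_yarn_dim(beta=32) = 64 * ln(128 / (2*pi*32)) / (2 * ln(10000)) ~= -1.57 -> floor -> -2 -> clamped to low = 0
#   find_yarn_dim(alpha=1) = 64 * ln(128 / (2*pi*1))  / (2 * ln(10000)) ~= 10.47 -> ceil -> high = 11
# with yarn_ramp_mask(0, 11, 32): frequency index 0 keeps the unscaled (extrapolated) frequency,
# indices >= 11 use the fully interpolated frequency, and indices 1..10 blend the two.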
249 | 250 | class YarnRotaryEmbedding(RotaryEmbedding): 251 | def __init__( 252 | self, 253 | rope_head_dim: int, 254 | max_position_embeddings: int, 255 | base: int = 10000, 256 | device: str = "cuda", 257 | scaling_factor: float = 1.0, 258 | training_context_length: int = 1024, 259 | alpha: int = 1, 260 | beta: int = 32, 261 | attn_factor: float = 1.0, 262 | ): 263 | """ 264 | From the paper: https://arxiv.org/pdf/2309.00071 265 | scaling_factor: the ratio between inference context length and training context length 266 | alpha: eq(18) of YaRN paper 267 | beta: eq(18) of YaRN paper 268 | attn_factor: the scaling factor for the temperature 269 | """ 270 | self.scaling_factor = scaling_factor 271 | self.training_context_length = training_context_length 272 | self.alpha = alpha 273 | self.beta = beta 274 | self.attn_factor = attn_factor 275 | super().__init__(rope_head_dim, max_position_embeddings, base, device) 276 | 277 | def _set_cos_sin_cache(self, seq_len, device, dtype): 278 | self.max_seq_len_cached = seq_len 279 | inv_freqs_interpolation = 1.0 / ( 280 | self.scaling_factor 281 | * self.base 282 | ** (torch.arange(0, self.rope_head_dim, 2).float().to(device) / self.rope_head_dim) 283 | ) 284 | inv_freqs_extrapolation = 1.0 / ( 285 | self.base 286 | ** (torch.arange(0, self.rope_head_dim, 2).float().to(device) / self.rope_head_dim) 287 | ) 288 | low, high = find_yarn_cut_dims( 289 | self.alpha, 290 | self.beta, 291 | self.training_context_length, 292 | self.rope_head_dim, 293 | self.base, 294 | ) 295 | # for dimensions lower than low, use extrapolation 296 | # for dimensions higher than high, use interpolation 297 | # for dimensions between low and high, blend extrapolation and interpolation 298 | # where the mask is 1, use extrapolation; where the mask is 0, use interpolation 299 | # in between, use a blend of both 300 | inv_freq_mask = 1.0 - yarn_ramp_mask( 301 | low, high, self.rope_head_dim // 2 302 | ).float().to(device) 303 | inv_freqs = ( 304 | 1.0 - inv_freq_mask 305 | ) * inv_freqs_interpolation + inv_freq_mask * inv_freqs_extrapolation 306 | self.register_buffer("inv_freqs", inv_freqs, persistent=False) 307 | 308 | t = torch.arange(seq_len, device=device, dtype=dtype) 309 | freqs = torch.outer(t, self.inv_freqs) 310 | emb = torch.cat((freqs, freqs), dim=-1) 311 | 312 | _mscale = get_yarn_mscale(self.scaling_factor, self.attn_factor) 313 | 314 | self.register_buffer( 315 | "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False 316 | ) 317 | self.register_buffer( 318 | "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False 319 | ) 320 | 321 | 322 | class KVCache(object): 323 | def __init__(self, num_layers: int): 324 | self.key_cache = [None] * num_layers 325 | self.value_cache = [None] * num_layers 326 | 327 | def update( 328 | self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int 329 | ): 330 | # key_states and value_states are of shape [bsz, num_heads, seq_len, head_dim] 331 | past_keys = self.key_cache[layer_idx] 332 | past_values = self.value_cache[layer_idx] 333 | if past_keys is None: 334 | self.key_cache[layer_idx] = key_states 335 | self.value_cache[layer_idx] = value_states 336 | else: 337 | # concatenate along the sequence length dimension 338 | self.key_cache[layer_idx] = torch.cat((past_keys, key_states), dim=-2) 339 | self.value_cache[layer_idx] = torch.cat((past_values, value_states), dim=-2) 340 | return self.key_cache[layer_idx], self.value_cache[layer_idx] 341 | 342 | def get_cache_length(self, layer_idx: int): 343 | if self.key_cache[layer_idx] is 
None: 344 | return 0 345 | return self.key_cache[layer_idx].shape[-2] 346 | 347 | 348 | def get_attention_mask(attn_mask: torch.tensor): 349 | """ 350 | Args: 351 | attn_mask: (B, T) 352 | Returns: 353 | attn_mask: (B, 1, T, T) 354 | """ 355 | batch_size, seq_length = attn_mask.shape 356 | 357 | # Create the outer product of attention_mask with itself for each batch 358 | attn_mask = attn_mask.to(torch.float32) 359 | mask_matrix = attn_mask.unsqueeze(-1) @ attn_mask.unsqueeze(1) 360 | 361 | # Create a causal mask (lower triangular) 362 | causal_mask = torch.tril( 363 | torch.ones(seq_length, seq_length, device=attn_mask.device) 364 | ) 365 | 366 | # Combine the attention mask with the causal mask 367 | # A position is valid if both: 368 | # 1. Both tokens are real (from the attention mask) 369 | # 2. The position is in the lower triangle (from the causal mask) 370 | mask_matrix = mask_matrix * causal_mask.unsqueeze(0) 371 | 372 | # Convert to boolean 373 | mask_matrix = mask_matrix.bool().unsqueeze(1) 374 | 375 | return mask_matrix 376 | 377 | 378 | class MultiHeadLatentAttention(nn.Module): 379 | def __init__(self, config: DeepSeekConfig, layer_idx: int): 380 | super().__init__() 381 | self.config = config 382 | self.layer_idx = layer_idx 383 | # transformation for Q 384 | self.q_head_dim = config.rope_head_dim + config.nope_head_dim 385 | self.q_lora_rank = config.q_lora_rank 386 | if config.q_lora_rank is not None: 387 | self.q_a_proj = nn.Linear(config.d_model, config.q_lora_rank, bias=False) 388 | self.q_a_layernorm = RMSNorm(config.q_lora_rank) 389 | self.q_b_proj = nn.Linear( 390 | config.q_lora_rank, 391 | config.nheads * self.q_head_dim, 392 | bias=False, 393 | ) 394 | else: 395 | self.q_proj = nn.Linear( 396 | config.d_model, config.nheads * self.q_head_dim, bias=False 397 | ) 398 | 399 | # tranformation for K and V 400 | self.kv_a_proj_with_mqa = nn.Linear( 401 | config.d_model, config.kv_lora_rank + config.rope_head_dim, bias=False 402 | ) 403 | self.kv_a_layernorm = RMSNorm(config.kv_lora_rank) 404 | self.kv_b_proj = nn.Linear( 405 | config.kv_lora_rank, 406 | config.nheads * (config.nope_head_dim + config.v_head_dim), 407 | bias=False, 408 | ) 409 | self.o_proj = nn.Linear( 410 | config.nheads * config.v_head_dim, config.d_model, bias=False 411 | ) 412 | self.attention_dropout_rate = config.dropout 413 | self.residual_dropout = nn.Dropout(config.dropout) 414 | 415 | self.nheads = config.nheads 416 | self.rope_head_dim = config.rope_head_dim 417 | self.nope_head_dim = config.nope_head_dim 418 | self.max_position_embeddings = config.max_position_embeddings 419 | self.rope_base = config.rope_base 420 | self.kv_lora_rank = config.kv_lora_rank 421 | self.q_lora_rank = config.q_lora_rank 422 | self.v_head_dim = config.v_head_dim 423 | self._init_rope() 424 | 425 | def _init_rope(self): 426 | if self.config.rope_scaling is None: 427 | self.rope = RotaryEmbedding( 428 | self.rope_head_dim, 429 | self.max_position_embeddings, 430 | base=self.rope_base, 431 | ) 432 | else: 433 | scaling_type = self.config.rope_scaling["type"] 434 | scaling_factor = self.config.rope_scaling["scaling_factor"] 435 | if scaling_type == "pi": 436 | self.rope = PositinalInterpolationRotaryEmbedding( 437 | self.rope_head_dim, 438 | self.max_position_embeddings, 439 | base=self.rope_base, 440 | scaling_factor=scaling_factor, 441 | ) 442 | elif scaling_type == "dynamic": 443 | self.rope = DynamicNTKAwareScalingRotaryEmbedding( 444 | self.rope_head_dim, 445 | self.max_position_embeddings, 446 | 
base=self.rope_base, 447 | scaling_factor=scaling_factor, 448 | ) 449 | elif scaling_type == "yarn": 450 | alpha = self.config.rope_scaling["alpha"] 451 | beta = self.config.rope_scaling["beta"] 452 | attn_factor = self.config.rope_scaling["attn_factor"] 453 | self.rope = YarnRotaryEmbedding( 454 | self.rope_head_dim, 455 | self.max_position_embeddings, 456 | base=self.rope_base, 457 | scaling_factor=scaling_factor, 458 | alpha=alpha, 459 | beta=beta, 460 | attn_factor=attn_factor, 461 | training_context_length=self.config.rope_scaling[ 462 | "training_context_length" 463 | ], 464 | ) 465 | 466 | def forward( 467 | self, 468 | x: torch.Tensor, 469 | past_key_value: Optional[KVCache] = None, 470 | attn_mask: Optional[torch.tensor] = None, 471 | ): 472 | B, q_len = x.shape[:2] 473 | 474 | if self.q_lora_rank is not None: 475 | q = self.q_a_layernorm(self.q_a_proj(x)) 476 | q = self.q_b_proj(q) 477 | else: 478 | q = self.q_proj(x) 479 | # B, nheads, q_len, rope_head_dim + nope_head_dim 480 | q = q.view(B, q_len, self.nheads, self.q_head_dim).transpose(1, 2) 481 | # q_nope: B, nheads, q_len, nope_head_dim 482 | # q_rope: B, nheads, q_len, rope_head_dim 483 | q_nope, q_rope = torch.split( 484 | q, [self.nope_head_dim, self.rope_head_dim], dim=-1 485 | ) 486 | 487 | # B, q_len, kv_lora_rank + rope_head_dim 488 | kv_compressed = self.kv_a_proj_with_mqa(x) 489 | # kv_compressed: B, q_len, kv_lora_rank 490 | # k_rope: B, q_len, rope_head_dim 491 | kv_compressed, k_rope = kv_compressed.split( 492 | [self.kv_lora_rank, self.rope_head_dim], dim=-1 493 | ) 494 | # add head dimension to k_rope. This k_rope is shared across all heads 495 | k_rope = k_rope.view(B, 1, q_len, self.rope_head_dim) 496 | 497 | # B, nheads, T, (nope_head_dim + v_head_dim) 498 | kv = ( 499 | self.kv_b_proj(self.kv_a_layernorm(kv_compressed)) 500 | .view(B, -1, self.nheads, self.nope_head_dim + self.v_head_dim) 501 | .transpose(1, 2) 502 | ) 503 | k_nope, value_states = torch.split( 504 | kv, [self.nope_head_dim, self.v_head_dim], dim=-1 505 | ) 506 | 507 | # apply rope to k_rope and q_rope 508 | # first get the seq_len including cache before this token 509 | past_seq_len = 0 510 | if past_key_value is not None: 511 | past_seq_len = past_key_value.get_cache_length(self.layer_idx) 512 | 513 | positions = torch.arange( 514 | past_seq_len, q_len + past_seq_len, device=x.device, dtype=torch.long 515 | ) 516 | kv_seq_len = q_len + past_seq_len 517 | cos, sin = self.rope(value_states, seq_len=kv_seq_len) 518 | # q_rope: B, nheads, q_len, rope_head_dim 519 | # k_rope: B, 1, q_len, rope_head_dim 520 | q_rope, k_rope = apply_deepseek_rope(q_rope, k_rope, positions, cos, sin) 521 | 522 | # concatenate q/k_rope and q/k_nope 523 | query_states = q_rope.new_empty(B, self.nheads, q_len, self.q_head_dim) 524 | query_states[..., : self.nope_head_dim] = q_nope 525 | query_states[..., self.nope_head_dim :] = q_rope 526 | 527 | key_states = k_rope.new_empty(B, self.nheads, q_len, self.q_head_dim) 528 | key_states[..., : self.nope_head_dim] = k_nope 529 | key_states[..., self.nope_head_dim :] = k_rope 530 | 531 | # update kv cache 532 | if past_key_value is not None: 533 | key_states, value_states = past_key_value.update( 534 | key_states, value_states, self.layer_idx 535 | ) 536 | 537 | # when the q_len = kv_seq_len, the attention mask is casual mask 538 | # otherwise, the query can attend to all past tokens, so the attention bias is all 0 539 | if q_len == kv_seq_len: 540 | attn_mask = get_attention_mask(attn_mask) 541 | # B, nheads, q_len, 
v_head_dim 542 | output = F.scaled_dot_product_attention( 543 | query_states, 544 | key_states, 545 | value_states, 546 | attn_mask=attn_mask, 547 | dropout_p=self.attention_dropout_rate if self.training else 0.0, 548 | ) 549 | # B, q_len, nheads * v_head_dim 550 | output = ( 551 | output.transpose(1, 2) 552 | .contiguous() 553 | .view(B, -1, self.nheads * self.v_head_dim) 554 | ) 555 | output = self.residual_dropout(self.o_proj(output)) 556 | return output, past_key_value 557 | 558 | 559 | if __name__ == "__main__": 560 | with open("config.json", "r") as f: 561 | config = json.load(f) 562 | config = DeepSeekConfig(**config) 563 | mla = MultiHeadLatentAttention(config, layer_idx=0) 564 | mla = mla.to(config.device) 565 | input = torch.randn(2, 2, 1024).to(config.device) 566 | output, _ = mla(input, attn_mask=torch.ones(2, 2).to(config.device)) 567 | print(f"MLA output shape: {output.shape}") 568 | print("Add another token") 569 | new_token_embedding = torch.randn(2, 1, 1024).to(config.device) 570 | input = torch.cat((input, new_token_embedding), dim=1) 571 | output, _ = mla(input, attn_mask=torch.ones(2, 3).to(config.device)) 572 | print(f"MLA output shape: {output.shape}") 573 | 574 | # use kv cache 575 | mla = MultiHeadLatentAttention(config, layer_idx=0).to(config.device) 576 | kv_cache = KVCache(config.num_layers) 577 | output, kv_cache = mla(input, kv_cache, attn_mask=torch.ones(2, 3).to(config.device)) 578 | print(f"MLA output shape: {output.shape}") 579 | print(f"KV cache shape: {kv_cache.key_cache[0].shape}") 580 | new_token_embedding = torch.randn(2, 1, 1024).to(config.device) 581 | output, kv_cache = mla(new_token_embedding, kv_cache) 582 | print(f"MLA output shape: {output.shape}") 583 | print(f"KV cache shape: {kv_cache.key_cache[0].shape}") 584 | --------------------------------------------------------------------------------
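Example: using the YaRN rotary embedding on its own. This is a minimal sketch that assumes the YarnRotaryEmbedding and apply_deepseek_rope definitions in mla.py above, the rope_scaling values from config.json, and a CUDA device; shapes mirror the MLA module, where the rope part of the key is shared across heads.

import torch
from mla import YarnRotaryEmbedding, apply_deepseek_rope

# build the YaRN rope with the same hyperparameters as config.json's rope_scaling block
rope = YarnRotaryEmbedding(
    rope_head_dim=64,
    max_position_embeddings=4096,
    base=10000,
    device="cuda",
    scaling_factor=4.0,
    training_context_length=128,
    alpha=1,
    beta=32,
    attn_factor=1.0,
)

B, nheads, seq_len, rope_head_dim = 2, 16, 512, 64
q = torch.randn(B, nheads, seq_len, rope_head_dim, device="cuda")
# the rope part of the key has a single head dimension, as in MultiHeadLatentAttention
k = torch.randn(B, 1, seq_len, rope_head_dim, device="cuda")

# cos/sin cover seq_len positions; positions index into that cache
cos, sin = rope(q, seq_len=seq_len)
positions = torch.arange(seq_len, device="cuda")
q_rot, k_rot = apply_deepseek_rope(q, k, positions, cos, sin)
print(q_rot.shape, k_rot.shape)  # [2, 16, 512, 64] and [2, 1, 512, 64]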