├── LICENSE ├── README.md ├── index.html ├── modified_llama.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MatFormer: Nested Transformer for Elastic Inference 2 | 3 | This repository provides a public reproduction and open-source implementation of [MatFormer](https://nips.cc/virtual/2024/poster/94199)'s language modeling experiments (MatLM). It includes the essential building blocks and code required to reproduce the results presented in the paper. 4 | 5 | ## Features 6 | - Simplified implementation of MatFormer for language modeling tasks. 7 | - Open-source release for community use and further research. 8 | - Reproducibility: Includes key components to replicate the experiments. 9 | 10 | ## Running LM Pre-training Jobs 11 | 12 | To run the training script, execute: 13 | 14 | ```bash 15 | python3 train.py 16 | ``` 17 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Redirecting... 8 | 9 | 10 |

If you are not redirected automatically, click here.

11 | 12 | 13 | -------------------------------------------------------------------------------- /modified_llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers import LlamaForCausalLM 4 | from transformers.models.llama.modeling_llama import LlamaMLP 5 | import torch.nn as nn 6 | 7 | class ModifiedLlamaMLP(LlamaMLP): 8 | def __init__(self, config, scale_factors): 9 | super().__init__( 10 | hidden_size=config.hidden_size, 11 | intermediate_size=config.intermediate_size, 12 | hidden_act=config.hidden_act) 13 | self.intermediate_size = config.intermediate_size 14 | self.scale_factors = scale_factors # List of scale factors for 's', 'm', 'l', 'xl' 15 | self.current_subset_hd = None 16 | 17 | def configure_subnetwork(self, flag): 18 | """Configure subnetwork size based on flag.""" 19 | hd = self.intermediate_size 20 | if flag == 's': 21 | scale = self.scale_factors[0] # hd/8 22 | elif flag == 'm': 23 | scale = self.scale_factors[1] # hd/4 24 | elif flag == 'l': 25 | scale = self.scale_factors[2] # hd/2 26 | else: # 'xl' 27 | scale = self.scale_factors[3] # hd 28 | 29 | self.current_subset_hd = int(hd * scale) 30 | 31 | def forward(self, x): 32 | if self.current_subset_hd is None: 33 | raise ValueError("Subnetwork size not configured. 
Call `configure_subnetwork` first.") 34 | gate_proj = self.gate_proj.weight[:self.current_subset_hd] 35 | up_proj = self.up_proj.weight[:self.current_subset_hd] 36 | down_proj = self.down_proj.weight[:, :self.current_subset_hd] 37 | down_proj = F.linear(self.act_fn(F.linear(x, gate_proj) * F.linear(x, up_proj)), down_proj) 38 | self.current_subset_hd = None 39 | 40 | return down_proj 41 | 42 | 43 | class ModifiedLlamaForCausalLM(LlamaForCausalLM): 44 | def __init__(self, config): 45 | super().__init__(config) 46 | scale_factors = [1/8, 1/4, 1/2, 1] # s, m, l, xl 47 | 48 | # Replace FFN in each layer with ModifiedFFN 49 | for layer_idx in range(config.num_hidden_layers): 50 | self.model.layers[layer_idx].mlp = ModifiedLlamaMLP(config, scale_factors) 51 | 52 | def configure_subnetwork(self, flag): 53 | """Configure the subnetwork for all layers based on the flag.""" 54 | for layer_idx in range(len(self.model.layers)): 55 | self.model.layers[layer_idx].mlp.configure_subnetwork(flag) 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from transformers import AutoTokenizer, LlamaConfig 4 | from datasets import load_dataset 5 | from modified_llama import ModifiedLlamaForCausalLM 6 | from transformers import get_scheduler 7 | import functools 8 | 9 | device = torch.device( 'cuda' ) if torch.cuda.is_available() else torch.device( 'cpu' ) 10 | 11 | def preprocess_data(example, tokenizer): 12 | tokenizer.pad_token = tokenizer.eos_token 13 | return tokenizer(example["text"], truncation=True, padding="max_length", max_length=512) 14 | 15 | def collate_fn(batch): 16 | input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch]) 17 | attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch]) 18 | labels = input_ids.clone() 19 | 20 | # Generate a random flag for the entire batch 21 | 
flag = random.choice(['s', 'm', 'l', 'xl']) 22 | return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "flag": flag} 23 | 24 | def evaluate_model(model, eval_dataloader, flags): 25 | """Evaluate the model on the eval dataset for each flag and return losses.""" 26 | model.eval() 27 | eval_losses = {flag: 0.0 for flag in flags} 28 | num_batches = len(eval_dataloader) 29 | 30 | with torch.no_grad(): 31 | for flag in flags: 32 | total_loss = 0.0 33 | for batch in eval_dataloader: 34 | input_ids = batch["input_ids"].to(model.device) 35 | attention_mask = batch["attention_mask"].to(model.device) 36 | labels = batch["labels"].to(model.device) 37 | 38 | # Configure the subnetwork for the flag 39 | model.configure_subnetwork(flag) 40 | 41 | # Forward pass 42 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 43 | total_loss += outputs.loss.item() 44 | 45 | eval_losses[flag] = total_loss / num_batches 46 | 47 | model.train() 48 | return eval_losses 49 | 50 | if __name__ == "__main__": 51 | # Load tokenizer 52 | print("loading tokenizer", flush=True) 53 | tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B") 54 | tokenizer.pad_token = tokenizer.eos_token 55 | 56 | # Load dataset and preprocess 57 | print("loading dataset") 58 | dataset = load_dataset("vilm/RedPajama-v2-small", split="train") 59 | # Shuffle the dataset to ensure randomness 60 | dataset = dataset.shuffle(seed=42) 61 | # Select the first 100,000 examples 62 | dataset = dataset.select(range(10000)) 63 | # map over the dataset and transform. 64 | print("preprocessing dataset") 65 | dataset = dataset.map(functools.partial(preprocess_data, tokenizer=tokenizer), num_proc=32) 66 | 67 | # Initialize Llama configuration and model from scratch 68 | print("loading config", flush=True) 69 | config = LlamaConfig.from_pretrained("NousResearch/Llama-3.2-1B") 70 | print("initializing model. This may take a while... 
", end="", flush=True) 71 | model = ModifiedLlamaForCausalLM(config).to(device) 72 | print("Done!") 73 | 74 | # Split dataset into train and evaluation 75 | batch_size = 8 76 | eval_dataset = dataset.select(range(20 * batch_size)) # 10 batches for evaluation 77 | train_dataset = dataset.select(range(20 * batch_size, len(dataset))) 78 | 79 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn) 80 | eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, collate_fn=collate_fn) 81 | 82 | # Training arguments 83 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) 84 | 85 | # Define the number of training steps 86 | num_training_steps = 10000 # 200 warmup, remaining cosine decay 87 | num_warmup_steps = 200 88 | 89 | # Define the scheduler 90 | scheduler = get_scheduler( 91 | "cosine", 92 | optimizer=optimizer, 93 | num_warmup_steps=num_warmup_steps, 94 | num_training_steps=num_training_steps, 95 | ) 96 | model.train() 97 | 98 | flags = ['s', 'm', 'l', 'xl'] 99 | step = 0 100 | 101 | for batch in train_dataloader: 102 | input_ids = batch["input_ids"].to(model.device) 103 | attention_mask = batch["attention_mask"].to(model.device) 104 | labels = batch["labels"].to(model.device) 105 | flag = batch["flag"] 106 | 107 | # Configure the subnetwork for the entire batch 108 | model.configure_subnetwork(flag) 109 | 110 | # Forward pass 111 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 112 | loss = outputs.loss 113 | 114 | # Backward pass 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | 119 | # Step the scheduler 120 | scheduler.step() 121 | 122 | step += 1 123 | print(f"Step {step}, Loss: {loss.item()}") 124 | 125 | # Evaluate every 100 steps 126 | if step % 100 == 0: 127 | eval_losses = evaluate_model(model, eval_dataloader, flags) 128 | print(f"Step {step}, Eval Losses: {eval_losses}") 129 | 130 | # Stop training after 
the defined number of steps 131 | if step >= num_training_steps: 132 | break 133 | --------------------------------------------------------------------------------