├── requirements.txt ├── .gitignore ├── accelerate_config.yaml ├── CONTRIBUTING.md ├── model ├── utils.py ├── utils_test.py ├── calm_test.py ├── layers.py └── calm.py ├── train.py ├── README.md └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.6.0 2 | transformers==4.48.0 3 | numpy 4 | datasets==2.15.0 5 | accelerate 6 | tensorboard 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | fsdp_config: 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch: BACKWARD_PRE 8 | fsdp_cpu_ram_efficient_loading: true 9 | fsdp_forward_prefetch: false 10 | fsdp_offload_params: true 11 | fsdp_sharding_strategy: FULL_SHARD 12 | fsdp_state_dict_type: SHARDED_STATE_DICT 13 | fsdp_transformer_layer_cls_to_wrap: GemmaDecoderLayer 14 | fsdp_sync_module_states: true 15 | fsdp_use_orig_params: true 16 | machine_rank: 0 17 | main_training_function: train 18 | mixed_precision: bf16 19 | num_machines: 1 20 | num_processes: 4 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false 27 | 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to <https://cla.developers.google.com/> to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utils for CALM."""

import numpy as np


def check_connections(
    connections: list[tuple[int, int]],
    num_anchor_layers: int,
    num_aug_layers: int,
) -> bool:
    """Checks if the connections are valid.

    A connection is valid when its first element is a legal layer index for
    the anchor model and its second element is a legal layer index for the
    augmenting model. Validation stops at the first offending connection and
    a diagnostic is printed for it.

    Args:
      connections: Pairs of (anchor_layer_index, aug_layer_index).
      num_anchor_layers: Number of layers in the anchor model.
      num_aug_layers: Number of layers in the augmenting model.

    Returns:
      True if every index is within range (also for an empty list), False
      otherwise.
    """
    for connection in connections:
        if not 0 <= connection[0] < num_anchor_layers:
            print(
                f"Please verify your connections again. Index {connection[0]} doesn't"
                f" exist as anchor model only has {num_anchor_layers} layers"
            )
            return False
        if not 0 <= connection[1] < num_aug_layers:
            print(
                f"Please verify your connections again. Index {connection[1]} doesn't"
                f" exist as augmenting model only has {num_aug_layers} layers"
            )
            return False
    return True


def get_connections(
    num_connections: int,
    num_anchor_layers: int,
    num_aug_layers: int,
) -> list[tuple[int, int]]:
    """Gets the connections for CALM.

    Layer indices are spread evenly from the first layer (0) to the last
    layer (num_layers - 1) of each model and paired up position by position.

    Args:
      num_connections: Number of (anchor, aug) pairs to create.
      num_anchor_layers: Number of layers in the anchor model.
      num_aug_layers: Number of layers in the augmenting model.

    Returns:
      A list of `num_connections` tuples (anchor_layer_index,
      aug_layer_index) made of built-in Python ints.
    """
    anchor_layers = np.linspace(
        0, num_anchor_layers - 1, num_connections, dtype=int
    )
    aug_layers = np.linspace(0, num_aug_layers - 1, num_connections, dtype=int)

    # tolist() turns the NumPy integer scalars into built-in ints, so the
    # result matches the annotated return type and can be JSON-serialized
    # (NumPy integers are rejected by the json module).
    return list(zip(anchor_layers.tolist(), aug_layers.tolist()))


def get_hidden_dims(
    anchor_model,
    aug_model,
    connection: tuple[int, int],
) -> tuple[int, int]:
    """Gets the hidden dimensions for the given layers.

    Args:
      anchor_model: Anchor model; must expose `model.layers[i].hidden_size`.
      aug_model: Augmenting model; must expose `model.layers[i].hidden_size`.
      connection: Tuple (anchor_layer_index, aug_layer_index).

    Returns:
      Tuple (anchor_hidden_dim, aug_hidden_dim) read from the two layers
      named by `connection`.
    """
    anchor_layer, aug_layer = connection
    anchor_hidden_dim = anchor_model.model.layers[anchor_layer].hidden_size
    aug_hidden_dim = aug_model.model.layers[aug_layer].hidden_size

    return anchor_hidden_dim, aug_hidden_dim

# --------------------------------------------------------------------------------
# /model/utils_test.py:
# --------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for utils.py."""

import unittest

from model import calm
from model import utils


class UtilsTest(unittest.TestCase):
    """Exercises the connection and hidden-dimension helpers in utils."""

    def test_check_connections(self):
        """Tests if a few example connections are valid."""
        pairs = [(0, 0), (1, 1)]
        # Both models have two layers, so indices 0 and 1 are in range.
        self.assertTrue(utils.check_connections(pairs, 2, 2))
        # Index 1 is out of range for a one-layer anchor model.
        self.assertFalse(utils.check_connections(pairs, 1, 2))
        # Index 1 is out of range for a one-layer augmenting model.
        self.assertFalse(utils.check_connections(pairs, 2, 1))

    def test_get_connections(self):
        """Tests that connections are formed correctly using get_connections."""
        # Each case is ((num_connections, num_anchor_layers, num_aug_layers),
        # expected connections), checked in order.
        cases = (
            ((2, 2, 2), [(0, 0), (1, 1)]),
            ((1, 2, 2), [(0, 0)]),
            ((2, 1, 2), [(0, 0), (0, 1)]),
            ((2, 2, 1), [(0, 0), (1, 0)]),
        )
        for args, expected in cases:
            self.assertEqual(utils.get_connections(*args), expected)

    def test_get_hidden_dims(self):
        """Tests that the hidden dimensions are set correctly."""
        composed = calm.CALM(
            calm.CALMConfig(
                anchor_model="google/gemma-2b",
                aug_model="google/gemma-2b",
                num_connections=2,
                num_heads=1,
            )
        )
        anchor_layers = composed.anchor_model.model.layers
        aug_layers = composed.aug_model.model.layers
        for connection in composed.connections:
            anchor_dim, aug_dim = utils.get_hidden_dims(
                composed.anchor_model, composed.aug_model, tuple(connection)
            )
            self.assertEqual(anchor_layers[connection[0]].hidden_size, anchor_dim)
            self.assertEqual(aug_layers[connection[1]].hidden_size, aug_dim)


if __name__ == "__main__":
    unittest.main()

# --------------------------------------------------------------------------------
# /model/calm_test.py:
# --------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for calm.py."""

import unittest

from model import calm
from model import utils
import torch


class CalmTest(unittest.TestCase):
    """Tests for the composed CALM model built from two gemma-2b models."""

    def setUp(self):
        """Sets up the CALM model for testing."""
        super().setUp()
        self.config = calm.CALMConfig(
            anchor_model="google/gemma-2b",
            aug_model="google/gemma-2b",
            num_connections=2,
            num_heads=1,
        )
        self.model = calm.CALM(self.config)

    def test_calm_config(self):
        """Tests that the CALM configuration is set correctly."""
        config = calm.CALMConfig(
            anchor_model="google/gemma-2b",
            aug_model="google/gemma-2b",
            num_connections=2,
            num_heads=1,
        )
        self.assertEqual(config.anchor_model, "google/gemma-2b")
        self.assertEqual(config.aug_model, "google/gemma-2b")
        self.assertEqual(config.num_connections, 2)

    def test_calm_forward(self):
        """Tests whether the CALM model returns the same output shape as the anchor model."""
        # Token ids must be an integer tensor: torch.ones() defaults to
        # float32, which an embedding lookup does not accept as indices.
        input_ids = torch.ones(1, 10, dtype=torch.long)
        attention_mask = torch.ones(1, 10, dtype=torch.long)
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        output_anchor_model = self.model.anchor_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        self.assertEqual(output[0].shape, output_anchor_model[0].shape)

    def test_calm_connections(self):
        """Tests that the CALM connections are set correctly."""
        # setUp() already builds a model from the same configuration, so no
        # second (expensive) model construction is needed here.
        self.assertEqual(self.model.connections, [(0, 0), (17, 17)])

    def test_get_hidden_dim(self):
        """Tests that the hidden dimensions are set correctly."""
        for connection in self.model.connections:
            anchor_hidden_dim, aug_hidden_dim = utils.get_hidden_dims(
                self.model.anchor_model, self.model.aug_model, tuple(connection)
            )
            self.assertEqual(
                self.model.anchor_model.model.layers[connection[0]].hidden_size,
                anchor_hidden_dim,
            )
            self.assertEqual(
                self.model.aug_model.model.layers[connection[1]].hidden_size,
                aug_hidden_dim,
            )

    def test_cross_attention_hook(self):
        """Tests that the cross attention hook's embed_dim is same as anchor model's hidden size."""
        for connection_idx, connection in enumerate(self.model.connections):
            anchor_hidden_dim, _ = utils.get_hidden_dims(
                self.model.anchor_model, self.model.aug_model, tuple(connection)
            )
            self.assertEqual(
                self.model.cross_attention_hooks[
                    connection_idx
                ].cross_attention.embed_dim,
                anchor_hidden_dim,
            )

    def test_calm_generate(self):
        """Tests if generate is working correctly."""
        input_ids = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.long)
        generate_ids = self.model.generate(input_ids, max_length=10)
        # One sequence per prompt, never longer than the requested max_length.
        self.assertEqual(generate_ids.shape[0], input_ids.shape[0])
        self.assertLessEqual(generate_ids.shape[1], 10)


if __name__ == "__main__":
    unittest.main()

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""CALM training script for finetuning.

Hugging Face Trainer is used to train the CALM model.
Reference:
https://huggingface.co/docs/transformers/main_classes/trainer
"""

from collections.abc import Sequence

from absl import app
from absl import flags
from absl import logging
import datasets
from model import calm
from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer
from transformers import TrainingArguments


_ANCHOR_MODEL_DIR = flags.DEFINE_string(
    'anchor_model_dir', None, 'anchor model path.'
)
_AUG_MODEL_DIR = flags.DEFINE_string('aug_model_dir', None, 'aug model path.')
_OUTPUT_DIR = flags.DEFINE_string('output_dir', None, 'output directory.')
_LEARNING_RATE = flags.DEFINE_float('learning_rate', 2e-5, 'learning rate.')
_EPOCHS = flags.DEFINE_integer('epochs', 3, 'number of epochs.')
_BATCH_SIZE = flags.DEFINE_integer('batch_size', 1, 'batch size.')
_NUM_HEADS = flags.DEFINE_integer('num_heads', 1, 'number of heads.')
_NUM_CONNECTIONS = flags.DEFINE_integer(
    'num_connections', 2, 'number of connections.'
)
# NOTE(review): this flag is defined but train() below never reads it; only
# num_connections is forwarded to CALMConfig.
_CONNECTIONS = flags.DEFINE_list(
    'connections',
    None,
    # The trailing space matters: adjacent string literals are concatenated
    # verbatim, and without it the help text reads "bothconnections".
    'connections between the anchor and aug model. You cannot provide both '
    'connections and num_connections simultaneously.',
)
_EVAL_STEPS = flags.DEFINE_integer('eval_steps', 50, 'eval steps.')
_LOGGING_STEPS = flags.DEFINE_integer('logging_steps', 50, 'logging steps.')
_SAVE_STEPS = flags.DEFINE_integer('save_steps', 50, 'save steps.')
_MAX_STEPS = flags.DEFINE_integer('max_steps', 100, 'max steps.')


def train(argv: Sequence[str]) -> None:
    """Trains the CALM model.

    Builds the composed model from the anchor/aug paths given on the command
    line, tokenizes the `wikitext-2-raw-v1` configuration of
    `Salesforce/wikitext`, runs Hugging Face `Trainer` on its `train` split
    (evaluating on its `test` split) and saves the result to `--output_dir`.

    Args:
      argv: Positional command-line arguments left over after flag parsing
        (unused).
    """
    del argv  # Unused.
    anchor_model_path = _ANCHOR_MODEL_DIR.value
    aug_model_path = _AUG_MODEL_DIR.value
    num_heads = _NUM_HEADS.value
    num_connections = _NUM_CONNECTIONS.value
    logging.info('anchor_model_path: %s', anchor_model_path)
    logging.info('aug_model_path: %s', aug_model_path)
    logging.info('Loading Tokenizer...')
    # The tokenizer is taken from the anchor model.
    tokenizer = AutoTokenizer.from_pretrained(anchor_model_path)
    logging.info('Loading Composed Model...')
    calm_config = calm.CALMConfig(
        anchor_model=anchor_model_path,
        aug_model=aug_model_path,
        anchor_config=None,
        aug_config=None,
        num_connections=num_connections,
        num_heads=num_heads,
    )

    model = calm.CALM(calm_config)
    train_data = datasets.load_dataset(
        path='Salesforce/wikitext', name='wikitext-2-raw-v1'
    )

    def preprocess_function(examples):
        """Tokenizes a batch, padding/truncating every text to 512 tokens."""
        return tokenizer(
            examples['text'], truncation=True, padding='max_length', max_length=512
        )

    train_data = train_data.map(preprocess_function, batched=True)
    # mlm=False selects the causal (next-token) language-modeling collator.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )

    epochs = _EPOCHS.value
    batch_size = _BATCH_SIZE.value
    learning_rate = _LEARNING_RATE.value
    output_dir = _OUTPUT_DIR.value
    eval_steps = _EVAL_STEPS.value
    logging_steps = _LOGGING_STEPS.value
    save_steps = _SAVE_STEPS.value
    max_steps = _MAX_STEPS.value
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        do_train=True,
        do_eval=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy='steps',  # pylint:disable=unexpected-keyword-arg
        eval_steps=eval_steps,
        logging_steps=logging_steps,
        save_steps=save_steps,
        max_steps=max_steps,
        learning_rate=learning_rate,
        label_names=[],
        report_to=['tensorboard'],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data['train'],
        eval_dataset=train_data['test'],
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    trainer.can_return_loss = True

    trainer.train()

    trainer.save_model(
        output_dir,
    )

    print(f'Training complete! Model saved to {output_dir}')


if __name__ == '__main__':
    app.run(train)

# --------------------------------------------------------------------------------
# /model/layers.py:
# --------------------------------------------------------------------------------
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Layer operation classes for CALM."""

from typing import Union

import torch
from transformers.models.gemma import modeling_gemma


def freeze_model(model):
    """Freezes the model.

    Args:
      model: Module whose parameters should stop receiving gradients; every
        parameter gets `requires_grad = False` in place.
    """
    for param in model.parameters():
        param.requires_grad = False


def process_hook_args(
    model: torch.nn.Module,  # pylint: disable=unused-argument
    inp: Union[torch.Tensor, tuple[torch.Tensor, ...]],  # pylint: disable=unused-argument
    out: Union[torch.Tensor, tuple[torch.Tensor, ...]],
):
    """Extracts the main output tensor from a PyTorch hook output.

    Args:
      model: The nn.Module object to which the hook is attached.
      inp: Input tensor to the layer (ignored).
      out: Output from the layer. This can be a tensor or a tuple containing the
        tensor.
    Reference:
      register_forward_hook in
      https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    Returns:
      A pair (main_tensor, out): the main output tensor from the hooked block
      (the first element when `out` is a tuple, `out` itself otherwise) and
      the untouched `out`.
    """
    anchor_hidden_state = out[0] if isinstance(out, tuple) else out
    query = anchor_hidden_state
    return query, out


class CrossAttentionHook(torch.nn.Module):
    """cross attention hook for CALM."""

    def __init__(
        self,
        anchor_hidden_dim: int,
        aug_hidden_dim: int,
        num_heads: int,
        rms_norm_eps: float = 1e-6,
    ):
        """Initializes the cross attention hook.

        Args:
          anchor_hidden_dim: The hidden dimension of the anchor model.
          aug_hidden_dim: The hidden dimension of the augmented model.
          num_heads: The number of attention heads in the hook
          rms_norm_eps: The epsilon value for the post-attention RMS norm layer

        Attributes:
          proj: The projection layer to project the augmented hidden state to the
            anchor hidden dimension.
          embed_dim: The hidden dimension of the anchor model.
          num_heads: The number of attention heads in the hook.
          post_attention_layernorm: RMS norm applied to the attention output.
          cross_attention: The cross attention layer.
          aug_hidden_state: The augmented hidden state tensor. This is set by
            forward_aug in CALM.
          aug_mask: The augmented mask tensor. This is set by forward_aug in CALM.
          attn_weights: The attention weights tensor. This is set by the forward
            pass of the cross attention hook.
        Example:
          hook = CrossAttentionHook(anchor_hidden_dim, aug_hidden_dim, num_heads)
          model.register_forward_hook(hook)
          model(input)
          print(hook.attn_weights)
        """
        super().__init__()
        self.proj = torch.nn.Linear(aug_hidden_dim, anchor_hidden_dim)
        self.embed_dim = anchor_hidden_dim
        self.num_heads = num_heads
        self.post_attention_layernorm = modeling_gemma.GemmaRMSNorm(
            self.embed_dim, eps=rms_norm_eps
        )
        self.cross_attention = torch.nn.MultiheadAttention(
            self.embed_dim,
            num_heads,
            kdim=self.embed_dim,
            vdim=self.embed_dim,
            batch_first=True,
        )
        self.aug_hidden_state = None
        self.aug_mask = None
        self.attn_weights = None

    def forward(self, *hook_args):
        """Forward pass of the cross attention hook.

        Args:
          *hook_args: The arguments passed to the hook.

        Raises:
          ValueError: If aug_hidden_state or aug_mask is None.

        The cross attention hook is registered to the anchor model. The hook
        extracts the hidden state from the anchor model and uses it as the query
        for the cross attention. The key and value for the cross attention are
        computed by projecting the hidden state from the augmented model. The
        augmented hidden state and mask are set by forward_aug in CALM.

        Returns:
          The modified output of the cross attention hook: a tuple whose first
          element is replaced when the hooked module returned a tuple, or the
          new tensor alone when it returned a bare tensor.
        """
        query, output = process_hook_args(*hook_args)
        # Explicit raises instead of asserts: asserts vanish under `python -O`
        # and the documented contract is ValueError.
        if self.aug_hidden_state is None:
            raise ValueError(
                'aug_hidden_state must be set before the cross attention hook runs.'
            )
        if self.aug_mask is None:
            raise ValueError(
                'aug_mask must be set before the cross attention hook runs.'
            )

        # Key and value are the same projection of the augmented hidden state,
        # so it is computed once and passed for both.
        projected_aug = self.proj(self.aug_hidden_state)

        # NOTE(review): the mask is converted to float and stored, but it is
        # not passed to the attention call below - confirm this is intended.
        self.aug_mask = self.aug_mask.float()
        attn_output, attn_weights = self.cross_attention(
            query, projected_aug, projected_aug, need_weights=True
        )
        self.attn_weights = attn_weights

        attn_output = self.post_attention_layernorm(attn_output)
        # Residual connection around the cross attention.
        output_fin = attn_output + query
        if isinstance(output, tuple):
            return (output_fin,) + output[1:]
        # The hooked module returned a bare tensor (a case process_hook_args
        # accepts); slicing it as if it were a tuple would be wrong.
        return output_fin


class ExtractHiddenStateHook(torch.nn.Module):
    """Extract hidden state hook for CALM."""

    def __init__(self):
        """Initializes the extract hidden state hook.

        Attributes:
          hidden_state: The hidden state tensor. This is set by the forward pass of
            the extract hidden state hook.
        Example:
          ```python
          hook = ExtractHiddenStateHook()
          model.register_forward_hook(hook)
          model(input)
          print(hook.hidden_state)
          ```
        """
        super().__init__()
        self.hidden_state = None

    def forward(self, *hook_args):
        """Stores the hooked module's main output tensor and returns its output unchanged.

        Args:
          *hook_args: The arguments passed to the hook.

        Returns:
          The original output of the hooked module.
        """
        hidden_state, out = process_hook_args(*hook_args)
        self.hidden_state = hidden_state
        return out


# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# 1 |
# 2 | # CALM: Expanding LLM Capabilities through Composition
# 3 |
# 4 | This repository provides the code for implementing the CALM (Composition to Augment Language Models) framework described in the paper ["LLMs Augmented LLMs: Expanding Capabilities through Composition"](https://arxiv.org/abs/2401.02412).
In this paper, we describe composing two language models by introducing cross-attention between models to compose their representations and enable new capabilities. The code currently supports combining any two models built with the Gemma architecture. 5 | 6 | ## Installation 7 | 8 | Clone the repo 9 | 10 | ``` 11 | git clone https://github.com/google-deepmind/calm.git 12 | cd calm 13 | ``` 14 | 15 | Create a virtual environment using virtualenv or conda depending on your preferences and install the requirements. We require Python 3.11 or above: 16 | 17 | ``` 18 | conda create -n calm python=3.11.11 && conda activate calm 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | Ensure you have logged in using a 🤗 read access token for using the Gemma models. For more information, see: [🤗 User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens). 23 | 24 | ``` 25 | huggingface-cli login 26 | ``` 27 | 28 | ## Usage 29 | 30 | Ex. Initialising a composed model. You can compose the models by providing the model paths under the arguments `anchor_model` and `aug_model`. Models that are loaded via `transformers.AutoModelForCausalLM` (autoregressive, decoder-only Gemma style models) in Huggingface / Local Directories are supported. 31 | 32 | ``` 33 | from model import calm 34 | 35 | calm_config = calm.CALMConfig( 36 | anchor_model="google/gemma-2b", 37 | aug_model="google/gemma-2b", 38 | connections=[(0,0),(1,1)], # Each element is a tuple (anchor_model_layer_index, aug_model_layer_index) 39 | num_heads=2, 40 | ) 41 | 42 | model = calm.CALM(calm_config) 43 | ``` 44 | You can also use the `num_connections` argument to initialize the composed model, in which case connections are created uniformly across anchor and augmenting models. 45 | 46 | ``` 47 | calm_config = calm.CALMConfig( 48 | anchor_model="google/gemma-2b", 49 | aug_model="google/gemma-2b", 50 | num_connections=2, 51 | num_heads=2, 52 | ) 53 | ``` 54 | 55 | Ex. 
Saving and Loading a model 56 | 57 | ``` 58 | calm_config.save_pretrained('./calm_config') 59 | model.save_pretrained('./calm') 60 | 61 | config = calm.CALMConfig.from_pretrained("./calm_config") 62 | loaded_model = calm.CALM.from_pretrained("./calm", config=config) 63 | ``` 64 | 65 | You can finetune the composed model using [🤗 Trainer](https://huggingface.co/docs/transformers/en/main_classes/trainer) 66 | 67 | ``` 68 | training_args = TrainingArguments( 69 | output_dir="./tmp", 70 | overwrite_output_dir=True, 71 | num_train_epochs=epochs, 72 | do_train=True, 73 | do_eval=True, 74 | per_device_train_batch_size=batch_size, 75 | per_device_eval_batch_size=batch_size, 76 | eval_strategy='steps', 77 | eval_steps=eval_steps, 78 | logging_steps=logging_steps, 79 | save_steps=save_steps, 80 | max_steps=max_steps, 81 | learning_rate=learning_rate, 82 | label_names=[], 83 | report_to=['tensorboard'], 84 | ) 85 | 86 | trainer = Trainer( 87 | model=model, 88 | args=training_args, 89 | train_dataset=data['train'], 90 | eval_dataset=data['test'], 91 | data_collator=data_collator, 92 | tokenizer=tokenizer, 93 | ) 94 | 95 | trainer.can_return_loss = True 96 | 97 | trainer.train() 98 | ``` 99 | 100 | An example multi-gpu training pipeline is given in `train.py` where we train a composed gemma-2b and gemma-7b using [Wikitext](https://huggingface.co/datasets/Salesforce/wikitext) data. 
You can run it using [🤗 Accelerate FSDP](https://huggingface.co/docs/accelerate/en/usage_guides/fsdp) 101 | 102 | An example accelerate config file is provided in `accelerate_config.yaml` 103 | 104 | ``` 105 | accelerate launch --config_file accelerate_config.yaml train.py \ 106 | --anchor_model_dir google/gemma-7b \ 107 | --aug_model_dir google/gemma-2b \ 108 | --num_heads 2 \ 109 | --num_connections 2 \ 110 | --learning_rate 3e-4 \ 111 | --batch_size 8 \ 112 | --output_dir './tmp' 113 | ``` 114 | 115 | You can generate from the model the same way as any transformers model 116 | 117 | ``` 118 | from transformers import AutoTokenizer 119 | tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") 120 | prompt = "I am going " 121 | inputs = tokenizer(prompt, return_tensors="pt") 122 | 123 | generate_ids = model.generate(inputs.input_ids, max_length=10) 124 | print(tokenizer.decode(generate_ids[0], skip_special_tokens=True)) 125 | ``` 126 | 127 | ## Citing this work 128 | 129 | 130 | ```latex 131 | @misc{bansal2024llmaugmentedllmsexpanding, 132 | title={LLM Augmented LLMs: Expanding Capabilities through Composition}, 133 | author={Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and 134 | Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Talukdar}, 135 | year={2024}, 136 | eprint={2401.02412}, 137 | archivePrefix={arXiv}, 138 | primaryClass={cs.LG}, 139 | url={https://arxiv.org/abs/2401.02412}, 140 | } 141 | ``` 142 | 143 | ## License and disclaimer 144 | 145 | Copyright 2024 DeepMind Technologies Limited 146 | 147 | For HuggingFace Transformers - Copyright 2018 Hugging Face, licensed under Apache 2.0 148 | 149 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 150 | you may not use this file except in compliance with the Apache 2.0 license. 
151 | You may obtain a copy of the Apache 2.0 license at: 152 | https://www.apache.org/licenses/LICENSE-2.0 153 | 154 | All other materials are licensed under the Creative Commons Attribution 4.0 155 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 156 | https://creativecommons.org/licenses/by/4.0/legalcode 157 | 158 | Unless required by applicable law or agreed to in writing, all software and 159 | materials distributed here under the Apache 2.0 or CC-BY licenses are 160 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 161 | either express or implied. See the licenses for the specific language governing 162 | permissions and limitations under those licenses. 163 | 164 | This is not an official Google product. 165 | 166 | ## Contact 167 | 168 | Please direct any questions at calm-contact@google.com 169 | 170 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /model/calm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================

"""CALM implementation."""

import os
from typing import Callable, List, Optional, Tuple, Union
from model import layers
from model import utils
import torch
import transformers


class _InvalidConnectionsError(ValueError, AssertionError):
  """Raised when the connection settings of a `CALMConfig` are invalid.

  Derives from `ValueError` (the error type documented by `CALM.__init__`) and
  from `AssertionError` (what the former `assert`-based checks raised), so
  handlers written against either type keep working.
  """


def _resolve_model_config(model_config, model_name: str):
  """Returns a usable config object for one of the composed models.

  Args:
    model_config: None, a dict, or an already constructed config object.
    model_name: HF Repo ID or path used to load the config when
      `model_config` is None.

  Returns:
    The config loaded from `model_name` when `model_config` is None, a
    `GemmaConfig` built from `model_config` when it is a dict, and
    `model_config` itself otherwise.
  """
  if model_config is None:
    return transformers.AutoConfig.from_pretrained(model_name)
  if isinstance(model_config, dict):
    return transformers.GemmaConfig.from_dict(model_config)
  return model_config


class CALMConfig(transformers.PretrainedConfig):
  """CALM configuration.

  Configuration file for CALM. Enables the user to specify the anchor and
  augmenting models, the connections between them, and the number of heads in
  each cross attention hook.
  """
  model_type = "calm"

  def __init__(
      self,
      anchor_model: str = "google/gemma-2b",
      aug_model: str = "google/gemma-2b",
      anchor_config: Optional[transformers.AutoConfig] = None,
      aug_config: Optional[transformers.AutoConfig] = None,
      connections: Optional[list[Tuple[int, int]]] = None,
      num_connections: Optional[int] = None,
      num_heads: int = 1,
      **kwargs,
  ):
    """CALM configuration.

    Args:
      anchor_model: HF Repo ID or Path to the anchor model.
      aug_model: HF Repo ID or Path to the augmenting model.
      anchor_config: Config for the anchor model. If None, the config will be
        loaded from the anchor model. If a dict is provided, it will be
        converted to a GemmaConfig.
      aug_config: Config for the augmenting model. If None, the config will be
        loaded from the augmenting model. If a dict is provided, it will be
        converted to a GemmaConfig.
      connections: The connections between the anchor and augmenting models.
        If None, num_connections must be set. Every connection is a tuple of
        (anchor_layer_idx, aug_layer_idx).
      num_connections: The number of connections between the anchor and
        augmenting models. If None, connections must be set.
      num_heads: The number of attention heads in each cross attention hook.
      **kwargs: Forwarded unchanged to `transformers.PretrainedConfig`.
    """

    # The attributes are stored as given; the "exactly one of connections /
    # num_connections" rule is enforced by `CALM.__init__`, not here.
    self.anchor_model = anchor_model
    self.aug_model = aug_model
    self.connections = connections
    self.num_connections = num_connections
    self.num_heads = num_heads
    self.anchor_config = anchor_config
    self.aug_config = aug_config
    super().__init__(**kwargs)


class CALM(transformers.PreTrainedModel):
  """CALM implementation.

  Class for composing the anchor and augmenting models. The class is designed
  to integrate with the transformers library. You can use the CALM object for
  training, evaluation, and inference just like any other transformers model.
  """

  config_class = CALMConfig

  @property
  def lm_head(self):
    """Returns the language model head of the anchor model."""
    return self.anchor_model.lm_head

  def __init__(self, config: CALMConfig):
    """CALM implementation.

    Initializes the CALM model by composing the anchor and augmenting models.
    Both models are frozen. Forward hooks on the augmenting model's layers
    capture hidden states, and cross attention hooks registered on the anchor
    model's layers consume them.

    Args:
      config: CALMConfig. `config.anchor_config` and `config.aug_config` are
        resolved in place when they are None or dicts.

    Raises:
      ValueError: If `config.connections` and `config.num_connections` are
        both None or both set, or if `config.connections` is rejected by
        `utils.check_connections`. The raised error is also an
        `AssertionError` for backward compatibility.
    """
    super().__init__(config)  # pylint: disable=too-many-function-args

    # Validate before loading any weights so a bad config fails fast. An
    # explicit raise (rather than `assert`) also survives `python -O`.
    if (config.connections is None) == (config.num_connections is None):
      raise _InvalidConnectionsError(
          "Exactly one of `connections` and `num_connections` must be set, "
          f"got connections={config.connections!r} and "
          f"num_connections={config.num_connections!r}."
      )

    config.anchor_config = _resolve_model_config(
        config.anchor_config, config.anchor_model
    )
    config.aug_config = _resolve_model_config(
        config.aug_config, config.aug_model
    )

    self.anchor_model = transformers.AutoModelForCausalLM.from_pretrained(
        config.anchor_model,
        config=config.anchor_config,
    )
    self.aug_model = transformers.AutoModelForCausalLM.from_pretrained(
        config.aug_model,
        config=config.aug_config,
    )
    self.vocab_size = self.anchor_model.config.vocab_size
    self.config = config
    self.num_anchor_layers = len(self.anchor_model.model.layers)
    self.num_aug_layers = len(self.aug_model.model.layers)

    if config.connections is not None:
      if not utils.check_connections(
          config.connections, self.num_anchor_layers, self.num_aug_layers
      ):
        raise _InvalidConnectionsError(
            f"Invalid connections {config.connections!r} for an anchor model "
            f"with {self.num_anchor_layers} layers and an augmenting model "
            f"with {self.num_aug_layers} layers."
        )
      self.connections = config.connections
      self.num_connections = len(config.connections)
    else:
      self.num_connections = config.num_connections
      self.connections = utils.get_connections(
          config.num_connections, self.num_anchor_layers, self.num_aug_layers
      )

    # One extraction hook per connection, attached to the augmenting layer of
    # that connection. Keys are tuples so that connections given as lists
    # (e.g. after a JSON round trip) remain hashable.
    self.extract_hidden_state_hooks = {}
    for connection in self.connections:
      extract_hook = layers.ExtractHiddenStateHook()
      self.extract_hidden_state_hooks[tuple(connection)] = extract_hook
      self.aug_model.model.layers[connection[1]].register_forward_hook(
          extract_hook
      )

    self.connection_hidden_dims = []
    for connection in self.connections:
      anchor_hidden_dim, aug_hidden_dim = utils.get_hidden_dims(
          self.anchor_model, self.aug_model, tuple(connection)
      )
      self.connection_hidden_dims.append((anchor_hidden_dim, aug_hidden_dim))

    # A ModuleList registers the cross attention hooks as submodules of CALM.
    self.cross_attention_hooks = torch.nn.ModuleList([])
    for anchor_hidden_dim, aug_hidden_dim in self.connection_hidden_dims:
      self.cross_attention_hooks.append(
          layers.CrossAttentionHook(
              anchor_hidden_dim=anchor_hidden_dim,
              aug_hidden_dim=aug_hidden_dim,
              num_heads=config.num_heads,
              rms_norm_eps=self.anchor_model.config.rms_norm_eps,
          )
      )

    # Only the two base models are frozen; the cross attention hooks are not
    # passed to `freeze_model`.
    layers.freeze_model(self.anchor_model)
    layers.freeze_model(self.aug_model)

    for connection_idx, connection in enumerate(self.connections):
      anchor_layer = self.anchor_model.model.layers[connection[0]]
      anchor_layer.register_forward_hook(
          self.cross_attention_hooks[connection_idx]
      )

  def release_memory(self):
    """Drops the tensors cached on the hooks by a forward pass.

    Resets the augmenting hidden state, mask and attention weights stored on
    every cross attention hook and the hidden state stored on every extraction
    hook. Nothing in this class calls it automatically.
    """

    for cross_attention_hook in self.cross_attention_hooks:
      cross_attention_hook.aug_hidden_state = None
      cross_attention_hook.aug_mask = None
      cross_attention_hook.attn_weights = None
    for extract_hidden_state_hook in self.extract_hidden_state_hooks.values():
      extract_hidden_state_hook.hidden_state = None

  def _forward_aug(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
      past_key_values: Optional[
          Union[transformers.Cache, List[torch.FloatTensor]]
      ] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.LongTensor] = None,
      use_cache: Optional[bool] = True,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
  ):
    """Forwards the sequence through the augmenting model.

    The augmenting model runs in eval mode without gradients. The hidden
    states captured by the extraction hooks during that pass, together with
    `attention_mask`, are then handed to the matching cross attention hooks.

    Args:
      input_ids: Input sequence.
      attention_mask: Input sequence mask.
      position_ids: Position ids.
      past_key_values: Past key values.
      inputs_embeds: Input embeddings.
      labels: Labels, forwarded to the augmenting model.
      use_cache: Use cache.
      output_attentions: Output attentions.
      output_hidden_states: Output hidden states.
      return_dict: Return dict. If True, the output will be a dict. If False,
        the output will be a tuple.
      cache_position: Cache position.

    Returns:
      output: Output of the augmenting model.
    """

    # NOTE(review): `forward` passes the same `past_key_values` object to this
    # call and to the anchor model -- confirm that the augmenting pass does not
    # modify a cache the anchor model relies on.
    with torch.no_grad():
      self.aug_model.eval()
      output = self.aug_model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          position_ids=position_ids,
          past_key_values=past_key_values,
          inputs_embeds=inputs_embeds,
          labels=labels,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
      )
    for connection_idx, connection in enumerate(self.connections):
      extract_hook = self.extract_hidden_state_hooks[tuple(connection)]
      cross_attention_hook = self.cross_attention_hooks[connection_idx]
      cross_attention_hook.aug_hidden_state = extract_hook.hidden_state
      cross_attention_hook.aug_mask = attention_mask
    return output

  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
      past_key_values: Optional[
          Union[transformers.Cache, List[torch.FloatTensor]]
      ] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.LongTensor] = None,
      use_cache: Optional[bool] = True,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
  ):
    """CALM forward pass.

    Runs the augmenting model first (its output is discarded; only the hidden
    states captured by the hooks are kept) and then the anchor model with the
    same arguments.

    Args:
      input_ids: Input sequence.
      attention_mask: Input sequence mask.
      position_ids: Position ids.
      past_key_values: Past key values.
      inputs_embeds: Input embeddings.
      labels: Labels, forwarded to both models.
      use_cache: Use cache.
      output_attentions: Output attentions.
      output_hidden_states: Output hidden states.
      return_dict: Return dict. If True, the output will be a dict. If False,
        the output will be a tuple.
      cache_position: Cache position.

    Returns:
      The output of the anchor model's forward call, i.e. the same class the
      anchor model returns on its own. When `labels` are given they are passed
      to the anchor model, which is expected to include the loss in its
      output.

    Example:
      config = CALMConfig(
          anchor_model='google/gemma-2b',
          aug_model='google/gemma-2b',
          num_connections=2,
          num_heads=1,
      )
      model = CALM(config)
      output = model(input_ids, attention_mask)
    """
    # Called for its side effect of populating the cross attention hooks; the
    # returned output is intentionally not kept.
    self._forward_aug(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    return self.anchor_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

  def save_pretrained(
      self,
      save_directory: Union[str, os.PathLike[str]],
      is_main_process: bool = True,
      state_dict: Optional[
          dict[str, dict[str, dict[str, torch.Tensor]]]
      ] = None,
      save_function: Callable[..., None] = torch.save,
      push_to_hub: bool = False,
      max_shard_size: Union[int, str] = "10GB",
      safe_serialization: bool = True,
      variant: Optional[str] = None,
      token: Optional[Union[bool, str]] = None,
      save_peft_format: bool = False,
      **kwargs,
  ):
    """Save the CALM model to a directory.

    This method overrides the default save_pretrained method to handle the
    shared weights issue: it always saves with safe_serialization=False to
    allow saving models with shared weights.

    Args:
      save_directory: The directory to save the model to.
      is_main_process: Whether this process is the main process.
      state_dict: The state dictionary to save.
      save_function: The function to use to save the state dictionary.
      push_to_hub: Whether to push the model to the hub.
      max_shard_size: The maximum shard size.
      safe_serialization: Accepted for signature compatibility only. The value
        is ignored and False is always passed to the parent implementation.
      variant: The variant of the model to save.
      token: The token to use to push the model to the hub.
      save_peft_format: Whether to save the model in the PEFT format.
      **kwargs: Additional keyword arguments.
    """
    super().save_pretrained(  # pytype: disable=attribute-error
        save_directory=save_directory,
        is_main_process=is_main_process,
        state_dict=state_dict,
        save_function=save_function,
        push_to_hub=push_to_hub,
        max_shard_size=max_shard_size,
        safe_serialization=False,
        variant=variant,
        token=token,
        save_peft_format=save_peft_format,
        **kwargs,
    )

  def prepare_inputs_for_generation(
      self,
      input_ids,
      past_key_values=None,
      attention_mask=None,
      inputs_embeds=None,
      cache_position=None,
      use_cache=True,
      **kwargs,
  ):
    """Prepares the inputs for generation.

    Trims `input_ids` to the tokens that are not yet covered by the cache,
    crops `attention_mask` when the cache has a maximum length, and derives
    `position_ids` and `cache_position` when they are not supplied.

    Args:
      input_ids: Input sequence.
      past_key_values: Past key values.
      attention_mask: Input sequence mask.
      inputs_embeds: Input embeddings.
      cache_position: Cache position.
      use_cache: Use cache.
      **kwargs: Additional keyword arguments. Only `position_ids` is read.

    Returns:
      A dict with either `inputs_embeds` or `input_ids`, plus `position_ids`,
      `cache_position`, `past_key_values`, `use_cache` and `attention_mask`.
    """
    past_length = 0
    if past_key_values is not None:
      if isinstance(past_key_values, transformers.Cache):
        past_length = (
            cache_position[0]
            if cache_position is not None
            else past_key_values.get_seq_length()
        )
        max_length = past_key_values.get_max_length()
        if max_length is None:
          max_cache_length = None
          cache_length = past_length
        else:
          max_cache_length = torch.tensor(
              max_length, device=input_ids.device
          )
          # `past_length` is a plain int when `cache_position` is None, and
          # `torch.min(tensor, int)` would select the reduce-over-dim
          # overload, so convert it to a tensor before comparing.
          cache_length = torch.min(
              max_cache_length,
              torch.as_tensor(past_length, device=input_ids.device),
          )
      else:
        cache_length = past_length = past_key_values[0][0].shape[2]
        max_cache_length = None

      # Keep only the unprocessed tokens:
      # 1 - If the length of the attention_mask exceeds the length of
      # input_ids, then we are in a setting where some of the inputs are
      # exclusively passed as part of the cache (e.g. when passing
      # input_embeds as input)
      if (
          attention_mask is not None
          and attention_mask.shape[1] > input_ids.shape[1]
      ):
        input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
      # 2 - If the past_length is smaller than input_ids.shape[1], then
      # input_ids holds all input tokens. We can discard input_ids based on
      # the past_length.
      elif past_length < input_ids.shape[1]:
        input_ids = input_ids[:, past_length:]
      # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume
      # input_ids only has unprocessed tokens.

      # If we are about to go beyond the maximum cache length, we need to crop
      # the input attention mask.
      if (
          max_cache_length is not None
          and attention_mask is not None
          and cache_length + input_ids.shape[1] > max_cache_length
      ):
        attention_mask = attention_mask[:, -max_cache_length :]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
      # create position_ids on the fly for batch generation; masked-out
      # positions get the placeholder value 1.
      position_ids = attention_mask.long().cumsum(-1) - 1
      position_ids.masked_fill_(attention_mask == 0, 1)
      if past_key_values:
        position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st
    # generation step
    if inputs_embeds is not None and past_key_values is None:
      model_inputs = {"inputs_embeds": inputs_embeds.contiguous()}
    else:
      model_inputs = {"input_ids": input_ids.contiguous()}

    input_length = (
        position_ids.shape[-1]
        if position_ids is not None
        else input_ids.shape[-1]
    )
    if cache_position is None:
      cache_position = torch.arange(
          past_length, past_length + input_length, device=input_ids.device
      )
    elif use_cache:
      cache_position = cache_position[-input_length:]

    model_inputs.update({
        "position_ids": position_ids,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "use_cache": use_cache,
        "attention_mask": attention_mask,
    })

    return model_inputs