├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── assets ├── hedgehog_llamas.png ├── hedgehog_llamas_big.png └── lolcats_and_tk_llamas.png ├── configs ├── experiment │ ├── distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml │ ├── distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml │ ├── eval_alpaca_clean.yaml │ ├── finetune_lora_fqkvo_alpaca_clean.yaml │ ├── finetune_lora_qkvo_alpaca_clean.yaml │ └── no_distill_alpaca_clean.yaml └── model │ ├── base_llama3_1_8b.yaml │ ├── base_llama3_8b.yaml │ ├── base_mistral_7b.yaml │ ├── chunked_experimental │ ├── distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml │ ├── distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml │ ├── distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml │ ├── distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml │ ├── distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml │ └── distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml │ ├── distill_llama3_1_8b_lk_smd_fd64.yaml │ ├── distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml │ ├── distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml │ ├── distill_llama3_1_8b_lk_t2r.yaml │ ├── distill_llama3_8b_lk_smd_fd64.yaml │ ├── distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml │ ├── distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml │ ├── distill_llama3_8b_lk_t2r.yaml │ ├── distill_mistral_7b_lk_smd_fd64.yaml │ ├── distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml │ ├── distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml │ └── distill_mistral_7b_lk_t2r.yaml ├── csrc ├── __init__.py ├── causal_attention.cpp ├── causal_attention.py ├── causal_attention_cuda.cu ├── causal_attention_kv_cuda.cu └── setup.py ├── demo_lolcats_llm.py ├── demos ├── README.md ├── benchmark_8b.sh ├── demo_8b.sh ├── demo_lolcats_hf.py └── llm_mmlu_eval │ ├── demo_405b.sh │ ├── demo_70b.sh │ ├── eval_mmlu.py │ └── mmlu.pkl.zip ├── distill_llama.py ├── environment.yaml ├── lm_eval_harness ├── README.md ├── __init__.py ├── eval_lm_harness.py ├── eval_lm_harness_big.py ├── models.py └── models_huggingface.py ├── lolcats_preprint_v0.pdf └── src ├── __init__.py ├── dataloaders ├── __init__.py ├── alpaca_clean.py ├── alpaca_clean_instruct.py └── utils │ ├── __init__.py │ ├── llama3.py │ ├── packing.py │ └── setup.py ├── finetune.py ├── model ├── __init__.py ├── convert_model.py ├── feature_map.py ├── linear_attention │ ├── __init__.py │ ├── linear_attention.py │ ├── linear_window_attention_sw.py │ ├── linear_window_attention_sw_linear.py │ ├── linear_window_attention_sw_long.py │ ├── linear_window_attention_tk.py │ ├── linear_window_attention_tk_gen.py │ ├── linear_window_attention_tk_long.py │ └── utils.py ├── load_model.py ├── load_model_for_eval.py ├── modeling_llama.py ├── modeling_llama_sharded.py ├── modeling_mistral.py ├── peft.py ├── pretrained.py ├── rotary.py └── utils.py ├── trainer ├── __init__.py ├── default_lm.py ├── distill_attention_mse_linear.py ├── distill_attention_xent_mse.py ├── finetune_seq2seq.py ├── optim.py └── utils.py └── utils ├── __init__.py ├── logging.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python 
script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | 4 | import torch 5 | import gradio as gr 6 | from omegaconf import OmegaConf 7 | from transformers import AutoTokenizer 8 | from huggingface_hub import hf_hub_download 9 | 10 | from src.utils.setup import seed_everything 11 | from src.utils.logging import print_header 12 | from src.model.pretrained import get_pretrained_loader 13 | from src.model.load_model import load_and_convert_attns, load_and_convert_finetune 14 | 15 | def load_model_from_checkpoint( 16 | attn_mlp_checkpoint_path: str = None, 17 | finetune_checkpoint_path: str = None, 18 | model_config_path: str = None, 19 | distill_config_path: str = None, 20 | finetune_config_path: str = None, 21 | config_dir: str = 'configs', 22 | print_model: bool = False, 23 | debug: bool = False, 24 | huggingface_token: str = None, 25 | use_cuda_kernels: bool = False, 26 | use_attention: bool = False 27 | ): 28 | 29 | is_local = attn_mlp_checkpoint_path.endswith(".pt") 30 | 31 | model_config = OmegaConf.load(model_config_path) 32 | distill_config = OmegaConf.load(distill_config_path) 33 | finetune_config = OmegaConf.load(finetune_config_path) 34 | 35 | model_loader = get_pretrained_loader(**model_config.model, 36 | huggingface_token=huggingface_token) 37 | tokenizer = model_loader.load_tokenizer() 38 | tokenizer.pad_token_id = tokenizer.eos_token_id 39 | tokenizer.padding_side = 'left' 40 | if use_attention: 41 | model = model_loader.load('softmax') 42 | return model, model_config, tokenizer 43 | 44 | model = model_loader.load(model_config['attention']['attention_type']) 45 | if use_cuda_kernels: 46 | print('*** Using TK CUDA kernels **') 47 | model_config['attention']['attention_type'] = 'lolcats_llama_window_tk_gen' 48 | 49 | if is_local: 50 | checkpoint_path = attn_mlp_checkpoint_path 51 | else: 52 | checkpoint_path = None 53 | model, distill_peft_config = load_and_convert_attns( 54 | model, model_config, 55 | attention_type=None, 56 | checkpoint_path=checkpoint_path, 57 | print_model=debug, 58 | merge_loras=False, 59 | peft_gradient_checkpointing=False, 60 | train_attention=False) 61 | 62 | if is_local: 63 | checkpoint_path = attn_mlp_checkpoint_path 64 | else: 65 | checkpoint_path = None 66 | model, ft_peft_config = load_and_convert_finetune( 67 | model, finetune_config, 68 | checkpoint_path=checkpoint_path, 69 | print_model=debug, 70 | merge_loras=False, 71 | peft_gradient_checkpointing=False) 72 | 73 | if not is_local: 74 | model = load_hf_weights( 75 | model, 76 | attn_mlp_checkpoint_path, finetune_checkpoint_path, 77 | filename="model.pt" 78 | ) 79 | if use_cuda_kernels: 80 | print('*** Using TK CUDA kernels ***') 81 | 82 | if print_model: 83 | print('*** Model after checkpoint load ***') 84 | print(model) 85 | 86 | return model, model_config, tokenizer 87 | 88 | def load_hf_weights(model, distill_repo_id, ft_repo_id, filename="model.pt"): 89 | for repo_id in [distill_repo_id, ft_repo_id]: 90 | if repo_id is None: continue 91 | 92 | print(f"Loading weights from {repo_id}") 93 | 94 | local_file_path = hf_hub_download(repo_id=repo_id, filename=filename) 95 | state_dict = torch.load(local_file_path) 96 | if 'model_state_dict' in state_dict: 97 | state_dict = state_dict['model_state_dict'] 98 | else: 99 | pass 100 | _keys = model.load_state_dict(state_dict, strict=False) 101 | if len(_keys.unexpected_keys) > 
0: 102 | new_state_dict = {k.replace('model.', 'model.model.'): v for k, v in state_dict.items()} 103 | _keys = model.load_state_dict(new_state_dict, strict=False) 104 | if len(_keys.unexpected_keys) > 0: 105 | new_state_dict = {k.replace('model.', 'base_model.model.model.'): v for k, v in state_dict.items()} 106 | _keys = model.load_state_dict(new_state_dict, strict=False) 107 | 108 | try: 109 | assert len(_keys.unexpected_keys) == 0 110 | print('*** All expected keys matched successfully ***') 111 | except Exception as e: 112 | print(e) 113 | print('*** Error: unexpected keys in checkpoint - please fix ***') 114 | print('Unexpected keys:') 115 | for k in _keys.unexpected_keys: 116 | print(k) 117 | exit() 118 | 119 | return model 120 | 121 | def load_model_and_tokenizer(): 122 | CONFIG_DIR = 'configs' # Update to your path 123 | 124 | model_config_path = f"{CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml" 125 | distill_config_path = f"{CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml" 126 | finetune_config_path = f"{CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml" 127 | attn_mlp_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-distill' 128 | finetune_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' 129 | 130 | model, model_config, tokenizer = load_model_from_checkpoint( 131 | attn_mlp_checkpoint_path=attn_mlp_checkpoint_path, 132 | finetune_checkpoint_path=finetune_checkpoint_path, 133 | model_config_path=model_config_path, 134 | distill_config_path=distill_config_path, 135 | finetune_config_path=finetune_config_path, 136 | config_dir=CONFIG_DIR, 137 | print_model=False, 138 | debug=False, 139 | huggingface_token=None, 140 | use_cuda_kernels=False, 141 | use_attention=False 142 | ) 143 | model = model.to('cuda') 144 | model.eval() 145 | return model, tokenizer 146 | 147 | model, tokenizer = load_model_and_tokenizer() 148 | 149 | def generate_response(prompt): 150 | all_prompts = [prompt] 151 | 152 | with torch.no_grad(): 153 | model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device) 154 | model_output = model.generate( 155 | **model_input, use_cache=True, 156 | max_new_tokens=50, 157 | do_sample=False, 158 | top_k=1, 159 | top_p=1.0, 160 | num_return_sequences=1, 161 | pad_token_id=tokenizer.eos_token_id) 162 | generated_tokens = model_output[0] 163 | input_len = model_input['input_ids'].shape[1] 164 | generated_tokens = generated_tokens[input_len:] 165 | generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) 166 | 167 | return generated_text 168 | 169 | iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="LOLcats Model Demo") 170 | 171 | iface.launch() 172 | -------------------------------------------------------------------------------- /assets/hedgehog_llamas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/assets/hedgehog_llamas.png -------------------------------------------------------------------------------- /assets/hedgehog_llamas_big.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/assets/hedgehog_llamas_big.png -------------------------------------------------------------------------------- /assets/lolcats_and_tk_llamas.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/assets/lolcats_and_tk_llamas.png -------------------------------------------------------------------------------- /configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: alpaca_clean 3 | dataset_config: 4 | name: default 5 | path: yahma/alpaca-cleaned 6 | chunk_size: 1024 # sequence length for distilling 7 | concat_data: true 8 | cache_dir: 'data/alpaca' # Change this to where you want to save 9 | pretrained_model_config: # will be updated based on model_config 10 | pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B' 11 | cache_dir: '/scratch/' 12 | preprocess_config: null 13 | 14 | dataloader: 15 | batch_size: 1 16 | num_workers: 2 17 | drop_last: false 18 | pin_memory: true 19 | 20 | optimizer: 21 | optim: adamw_torch_fused 22 | lr: 0.01 23 | weight_decay: 0.0 24 | 25 | lr_scheduler: 26 | lr_scheduler_type: reduce_lr_on_plateau 27 | mode: min 28 | factor: 0.1 29 | patience: 10 30 | min_lr: 0.00001 31 | 32 | trainer: # HuggingFace Trainer-like arguments 33 | name: distill_attention_xent_mse 34 | reverse_kl: false 35 | mse_factor: 1000 36 | xent_factor: 0 37 | 38 | bf16: true 39 | train_split: train 40 | val_split: validation 41 | num_train_epochs: 2 42 | gradient_accumulation_steps: 8 43 | seed: 42 44 | batch_size: 1 45 | load_best_model_at_end: true 46 | greater_is_better: false 47 | metric_for_best_model: distill/eval/loss 48 | logging_steps: 100 49 | evaluation_strategy: steps 50 | max_steps: -1 51 | eval_steps: 100 52 | max_eval_batches: null 53 | -------------------------------------------------------------------------------- /configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: alpaca_clean 3 | dataset_config: 4 | name: default 5 | path: yahma/alpaca-cleaned 6 | chunk_size: 1024 # sequence length for distilling 7 | concat_data: true 8 | cache_dir: 'data/alpaca' # Change this to where you want to save 9 | pretrained_model_config: # will be updated based on model_config 10 | pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B' 11 | cache_dir: '/data_persistent2/sim_data/llama-3_1-8b/' 12 | preprocess_config: null 13 | 14 | dataloader: 15 | batch_size: 1 16 | num_workers: 2 17 | drop_last: false 18 | pin_memory: true 19 | 20 | optimizer: 21 | optim: adamw_torch_fused 22 | lr: 0.01 23 | weight_decay: 0.0 24 | 25 | lr_scheduler: 26 | lr_scheduler_type: reduce_lr_on_plateau 27 | mode: min 28 | factor: 0.1 29 | patience: 10 30 | min_lr: 0.00001 31 | 32 | trainer: # HuggingFace Trainer-like arguments 33 | name: distill_attention_xent_mse 34 | reverse_kl: false 35 | mse_factor: 1000 36 | xent_factor: 1 37 | 38 | bf16: true 39 | train_split: train 40 | val_split: validation 41 | num_train_epochs: 2 42 | gradient_accumulation_steps: 8 43 | seed: 42 44 | batch_size: 1 45 | load_best_model_at_end: true 46 | greater_is_better: false 47 | metric_for_best_model: distill/eval/loss 48 | logging_steps: 100 49 | evaluation_strategy: steps 50 | max_steps: -1 51 | eval_steps: 100 52 | max_eval_batches: null 53 | -------------------------------------------------------------------------------- /configs/experiment/eval_alpaca_clean.yaml: 
-------------------------------------------------------------------------------- 1 | dataset: 2 | name: alpaca_clean 3 | dataset_config: 4 | name: alpaca 5 | path: yahma/alpaca-cleaned 6 | chunk_size: 1024 # sequence length for distilling 7 | concat_data: true 8 | cache_dir: 'data/alpaca' # Change this to where you want to save 9 | pretrained_model_config: 10 | pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config 11 | cache_dir: '/scratch/' 12 | preprocess_config: null 13 | 14 | dataloader: 15 | batch_size: 1 16 | num_workers: 2 17 | drop_last: false 18 | pin_memory: true 19 | 20 | optimizer: 21 | optim: adamw_torch_fused 22 | lr: 1e-4 23 | weight_decay: 0.0 24 | 25 | lr_scheduler: 26 | lr_scheduler_type: reduce_lr_on_plateau 27 | mode: min 28 | factor: 0.1 29 | patience: 10 30 | min_lr: 0.00001 31 | 32 | trainer: # HuggingFace Trainer-like arguments 33 | name: finetune_seq2seq 34 | bf16: true 35 | train_split: train 36 | val_split: test 37 | num_train_epochs: 2 38 | gradient_accumulation_steps: 8 39 | seed: 42 40 | batch_size: 1 41 | load_best_model_at_end: true 42 | greater_is_better: true 43 | metric_for_best_model: eval/rouge/geometric_mean 44 | logging_steps: 100 45 | evaluation_strategy: steps 46 | max_steps: -1 47 | eval_steps: 100 48 | max_eval_batches: null 49 | 50 | finetune: 51 | method: lora 52 | kwargs: 53 | r: 8 54 | lora_alpha: 16 55 | lora_dropout: 0 # 0.05 56 | target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj'] -------------------------------------------------------------------------------- /configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: alpaca_clean 3 | dataset_config: 4 | name: default 5 | path: yahma/alpaca-cleaned 6 | chunk_size: 1024 7 | concat_data: true 8 | cache_dir: "data/alpaca" 9 | pretrained_model_config: 10 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config 11 | cache_dir: "/data_persistent2/sim_data/" 12 | preprocess_config: null 13 | 14 | dataloader: 15 | batch_size: 1 16 | num_workers: 2 17 | drop_last: false 18 | pin_memory: true 19 | 20 | optimizer: 21 | optim: adamw_torch_fused 22 | lr: 1e-4 23 | weight_decay: 0.0 24 | 25 | lr_scheduler: 26 | lr_scheduler_type: reduce_lr_on_plateau 27 | mode: min 28 | factor: 0.1 29 | patience: 10 30 | min_lr: 0.00001 31 | 32 | trainer: # HuggingFace Trainer-like arguments 33 | name: default_lm 34 | bf16: true 35 | train_split: train 36 | val_split: validation 37 | num_train_epochs: 2 38 | gradient_accumulation_steps: 8 39 | seed: 42 40 | batch_size: 1 41 | load_best_model_at_end: true 42 | greater_is_better: false 43 | metric_for_best_model: eval/loss # eval/rouge/geometric_mean 44 | logging_steps: 100 45 | evaluation_strategy: steps 46 | max_steps: -1 47 | eval_steps: 100 48 | max_eval_batches: null 49 | num_save_ckpt_steps: 200 50 | 51 | finetune: 52 | method: lora 53 | kwargs: 54 | r: 8 55 | lora_alpha: 16 56 | lora_dropout: 0 # 0.05 57 | target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] 58 | trainable_weights: ['feature_map_q.mlp.layer', 'feature_map_k.mlp.layer', 'window_factors'] 59 | -------------------------------------------------------------------------------- /configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: alpaca_clean 3 | dataset_config: 4 | name: default 5 | path: 
yahma/alpaca-cleaned 6 | chunk_size: 1024 7 | concat_data: true 8 | cache_dir: "data/alpaca" 9 | pretrained_model_config: 10 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config 11 | cache_dir: "/scratch/" 12 | preprocess_config: null 13 | 14 | dataloader: 15 | batch_size: 1 16 | num_workers: 2 17 | drop_last: false 18 | pin_memory: true 19 | 20 | optimizer: 21 | optim: adamw_torch_fused 22 | lr: 1e-4 23 | weight_decay: 0.0 24 | 25 | lr_scheduler: 26 | lr_scheduler_type: reduce_lr_on_plateau 27 | mode: min 28 | factor: 0.1 29 | patience: 10 30 | min_lr: 0.00001 31 | 32 | trainer: # HuggingFace Trainer-like arguments 33 | name: default_lm 34 | bf16: true 35 | train_split: train 36 | val_split: validation 37 | num_train_epochs: 2 38 | gradient_accumulation_steps: 8 39 | seed: 42 40 | batch_size: 1 41 | load_best_model_at_end: true 42 | greater_is_better: false 43 | metric_for_best_model: eval/loss # eval/rouge/geometric_mean 44 | logging_steps: 100 45 | evaluation_strategy: steps 46 | max_steps: -1 47 | eval_steps: 100 48 | max_eval_batches: null 49 | 50 | finetune: 51 | method: lora 52 | kwargs: 53 | r: 8 54 | lora_alpha: 16 55 | lora_dropout: 0 # 0.05 56 | target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] 57 | -------------------------------------------------------------------------------- /configs/experiment/no_distill_alpaca_clean.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: alpaca_clean 3 | dataset_config: 4 | name: alpaca 5 | path: yahma/alpaca-cleaned 6 | chunk_size: 1024 # sequence length for distilling 7 | concat_data: true 8 | cache_dir: 'data/alpaca' # Change this to where you want to save 9 | pretrained_model_config: 10 | pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config 11 | cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1' 12 | preprocess_config: null 13 | 14 | dataloader: 15 | batch_size: 1 16 | num_workers: 2 17 | drop_last: false 18 | pin_memory: true 19 | 20 | optimizer: 21 | optim: adamw_torch_fused 22 | lr: 0.01 23 | weight_decay: 0.0 24 | 25 | lr_scheduler: 26 | lr_scheduler_type: none 27 | 28 | trainer: # HuggingFace Trainer-like arguments 29 | name: null 30 | -------------------------------------------------------------------------------- /configs/model/base_llama3_1_8b.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B' 4 | cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 12 | rope_theta: 500000.0 13 | 14 | attention: 15 | attention_type: softmax 16 | -------------------------------------------------------------------------------- /configs/model/base_llama3_8b.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B' 4 | cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 12 | rope_theta: 
500000.0 13 | 14 | attention: 15 | attention_type: softmax 16 | -------------------------------------------------------------------------------- /configs/model/base_mistral_7b.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" 4 | cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 12 | rope_theta: 10000.0 13 | 14 | attention: 15 | attention_type: softmax 16 | -------------------------------------------------------------------------------- /configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | # Experimental config for chunked linear attention 2 | name: llama 3 | model: 4 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" 5 | cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights 6 | return_dict: true 7 | load_in_8bit: false 8 | load_in_4bit: false 9 | device_map: auto 10 | low_cpu_mem_usage: true 11 | torch_dtype: bfloat16 12 | attn_implementation: flash_attention_2 13 | rope_theta: 500000.0 14 | rope_scaling: 15 | factor: 8.0 16 | low_freq_factor: 1.0 17 | high_freq_factor: 4.0 18 | original_max_position_embeddings: 8192 19 | rope_type: llama3 20 | 21 | attention: 22 | attention_type: lolcats_long_llama_window_sw 23 | state_chunk_len: 1024 24 | window_size: 64 25 | affine_attention_factors: false 26 | init_window_factor: -2.1972245773362196 27 | feature_map: softmax_dim 28 | feature_map_kwargs: 29 | eps: 1e-12 30 | # mlp: null # to set 31 | fullspace: true 32 | layer_idx: null # to set 33 | learned_kernel: untied_head_einsum 34 | learned_kernel_kwargs: 35 | feature_dim: 64 36 | skip_connection: false 37 | bias: false 38 | zero_init: false 39 | tie_qk_kernels: false 40 | train_qk: false 41 | -------------------------------------------------------------------------------- /configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | # Experimental config for chunked linear attention 2 | name: llama 3 | model: 4 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" 5 | cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights 6 | return_dict: true 7 | load_in_8bit: false 8 | load_in_4bit: false 9 | device_map: auto 10 | low_cpu_mem_usage: true 11 | torch_dtype: bfloat16 12 | attn_implementation: flash_attention_2 13 | rope_theta: 500000.0 14 | rope_scaling: 15 | factor: 8.0 16 | low_freq_factor: 1.0 17 | high_freq_factor: 4.0 18 | original_max_position_embeddings: 8192 19 | rope_type: llama3 20 | 21 | attention: 22 | attention_type: lolcats_long_llama_window_tk 23 | state_chunk_len: 1024 24 | window_size: 64 25 | affine_attention_factors: false 26 | init_window_factor: -2.1972245773362196 27 | feature_map: softmax_dim 28 | feature_map_kwargs: 29 | eps: 1e-12 30 | # mlp: null # to set 31 | fullspace: true 32 | layer_idx: null # to set 33 | learned_kernel: untied_head_einsum 34 | learned_kernel_kwargs: 35 | feature_dim: 64 36 | skip_connection: false 37 | bias: false 38 | zero_init: 
false 39 | tie_qk_kernels: false 40 | train_qk: false 41 | -------------------------------------------------------------------------------- /configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | # Experimental config for chunked linear attention 2 | name: llama 3 | model: 4 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" 5 | cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights 6 | return_dict: true 7 | load_in_8bit: false 8 | load_in_4bit: false 9 | device_map: auto 10 | low_cpu_mem_usage: true 11 | torch_dtype: bfloat16 12 | attn_implementation: flash_attention_2 13 | rope_theta: 500000.0 14 | 15 | attention: 16 | attention_type: lolcats_long_llama_window_sw 17 | state_chunk_len: 1024 18 | window_size: 64 19 | affine_attention_factors: false 20 | init_window_factor: -2.1972245773362196 21 | feature_map: softmax_dim 22 | feature_map_kwargs: 23 | eps: 1e-12 24 | # mlp: null # to set 25 | fullspace: true 26 | layer_idx: null # to set 27 | learned_kernel: untied_head_einsum 28 | learned_kernel_kwargs: 29 | feature_dim: 64 30 | skip_connection: false 31 | bias: false 32 | zero_init: false 33 | tie_qk_kernels: false 34 | train_qk: false 35 | -------------------------------------------------------------------------------- /configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | # Experimental config for chunked linear attention 2 | name: llama 3 | model: 4 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" 5 | cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights 6 | return_dict: true 7 | load_in_8bit: false 8 | load_in_4bit: false 9 | device_map: auto 10 | low_cpu_mem_usage: true 11 | torch_dtype: bfloat16 12 | attn_implementation: flash_attention_2 13 | rope_theta: 500000.0 14 | 15 | attention: 16 | attention_type: lolcats_long_llama_window_tk 17 | state_chunk_len: 1024 18 | window_size: 64 19 | affine_attention_factors: false 20 | init_window_factor: -2.1972245773362196 21 | feature_map: softmax_dim 22 | feature_map_kwargs: 23 | eps: 1e-12 24 | # mlp: null # to set 25 | fullspace: true 26 | layer_idx: null # to set 27 | learned_kernel: untied_head_einsum 28 | learned_kernel_kwargs: 29 | feature_dim: 64 30 | skip_connection: false 31 | bias: false 32 | zero_init: false 33 | tie_qk_kernels: false 34 | train_qk: false 35 | -------------------------------------------------------------------------------- /configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | # Experimental config for chunked linear attention 2 | name: llama 3 | model: 4 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" 5 | cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights 6 | return_dict: true 7 | load_in_8bit: false 8 | load_in_4bit: false 9 | device_map: auto 10 | low_cpu_mem_usage: true 11 | torch_dtype: bfloat16 12 | attn_implementation: flash_attention_2 # eager # so we can load attention weights 13 | rope_theta: 10000.0 14 | 15 | attention: 16 | attention_type: lolcats_long_llama_window_sw 17 | state_chunk_len: 512 # 1024 18 | window_size: 64 19 | affine_attention_factors: false 20 | 
init_window_factor: -2.1972245773362196 21 | train_window_factor: true 22 | train_attention_weights: false 23 | feature_map: softmax_dim 24 | feature_map_kwargs: 25 | eps: 1e-12 26 | # mlp: null # to set 27 | fullspace: true 28 | layer_idx: null # to set 29 | learned_kernel: untied_head_einsum 30 | learned_kernel_kwargs: 31 | feature_dim: 64 32 | skip_connection: false 33 | bias: false 34 | zero_init: false 35 | tie_qk_kernels: false 36 | train_qk: false 37 | -------------------------------------------------------------------------------- /configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" 4 | cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 # eager # so we can load attention weights 12 | rope_theta: 10000.0 13 | 14 | attention: 15 | attention_type: lolcats_long_llama_window_tk 16 | state_chunk_len: 512 # 1024 17 | window_size: 64 18 | affine_attention_factors: false 19 | init_window_factor: -2.1972245773362196 20 | train_window_factor: true 21 | train_attention_weights: false 22 | feature_map: softmax_dim 23 | feature_map_kwargs: 24 | eps: 1e-12 25 | # mlp: null # to set 26 | fullspace: true 27 | layer_idx: null # to set 28 | learned_kernel: untied_head_einsum 29 | learned_kernel_kwargs: 30 | feature_dim: 64 31 | skip_connection: false 32 | bias: false 33 | zero_init: false 34 | tie_qk_kernels: false 35 | train_qk: false 36 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" 4 | cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: eager 12 | rope_theta: 500000.0 13 | rope_scaling: 14 | factor: 8.0 15 | low_freq_factor: 1.0 16 | high_freq_factor: 4.0 17 | original_max_position_embeddings: 8192 18 | rope_type: llama3 19 | 20 | attention: 21 | attention_type: lolcats_llama 22 | feature_map: softmax_dim 23 | feature_map_kwargs: 24 | eps: 1e-12 25 | # mlp: null # to set 26 | fullspace: true 27 | layer_idx: null # to set 28 | learned_kernel: untied_head_einsum 29 | learned_kernel_kwargs: 30 | feature_dim: 64 31 | skip_connection: false 32 | bias: false 33 | zero_init: false 34 | tie_qk_kernels: false 35 | train_qk: false 36 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" 4 | cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | 
attn_implementation: eager 12 | rope_theta: 500000.0 13 | rope_scaling: 14 | factor: 8.0 15 | low_freq_factor: 1.0 16 | high_freq_factor: 4.0 17 | original_max_position_embeddings: 8192 18 | rope_type: llama3 19 | 20 | attention: 21 | attention_type: lolcats_llama_window_sw 22 | state_chunk_len: 1024 23 | window_size: 64 24 | affine_attention_factors: false 25 | init_window_factor: -2.1972245773362196 26 | feature_map: softmax_dim 27 | feature_map_kwargs: 28 | eps: 1e-12 29 | # mlp: null # to set 30 | fullspace: true 31 | layer_idx: null # to set 32 | learned_kernel: untied_head_einsum 33 | learned_kernel_kwargs: 34 | feature_dim: 64 35 | skip_connection: false 36 | bias: false 37 | zero_init: false 38 | tie_qk_kernels: false 39 | train_qk: false 40 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" 4 | cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: eager 12 | rope_theta: 500000.0 13 | rope_scaling: 14 | factor: 8.0 15 | low_freq_factor: 1.0 16 | high_freq_factor: 4.0 17 | original_max_position_embeddings: 8192 18 | rope_type: llama3 19 | 20 | attention: 21 | attention_type: lolcats_llama_window_tk 22 | state_chunk_len: 1024 23 | window_size: 64 24 | affine_attention_factors: false 25 | init_window_factor: -2.1972245773362196 26 | feature_map: softmax_dim 27 | feature_map_kwargs: 28 | eps: 1e-12 29 | # mlp: null # to set 30 | fullspace: true 31 | layer_idx: null # to set 32 | learned_kernel: untied_head_einsum 33 | learned_kernel_kwargs: 34 | feature_dim: 64 35 | skip_connection: false 36 | bias: false 37 | zero_init: false 38 | tie_qk_kernels: false 39 | train_qk: false 40 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_1_8b_lk_t2r.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" 4 | cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: eager 12 | rope_theta: 500000.0 13 | rope_scaling: 14 | factor: 8.0 15 | low_freq_factor: 1.0 16 | high_freq_factor: 4.0 17 | original_max_position_embeddings: 8192 18 | rope_type: llama3 19 | 20 | attention: 21 | attention_type: lolcats_llama 22 | feature_map: relu 23 | feature_map_kwargs: 24 | eps: 1e-12 25 | # mlp: null # to set 26 | fullspace: true 27 | layer_idx: null # to set 28 | learned_kernel: untied_head_einsum 29 | learned_kernel_kwargs: 30 | feature_dim: 128 31 | skip_connection: false 32 | bias: true 33 | zero_init: false 34 | tie_qk_kernels: false 35 | train_qk: false 36 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_8b_lk_smd_fd64.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: 
"meta-llama/Meta-Llama-3-8B" 4 | cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 12 | rope_theta: 500000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama 16 | feature_map: softmax_dim 17 | feature_map_kwargs: 18 | eps: 1e-12 19 | # mlp: null # to set 20 | fullspace: true 21 | layer_idx: null # to set 22 | learned_kernel: untied_head_einsum 23 | learned_kernel_kwargs: 24 | feature_dim: 64 25 | skip_connection: false 26 | bias: false 27 | zero_init: false 28 | tie_qk_kernels: false 29 | train_qk: false 30 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" 4 | cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 12 | rope_theta: 500000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama_window_sw 16 | state_chunk_len: 1024 17 | window_size: 64 18 | affine_attention_factors: false 19 | init_window_factor: -2.1972245773362196 20 | feature_map: softmax_dim 21 | feature_map_kwargs: 22 | eps: 1e-12 23 | # mlp: null # to set 24 | fullspace: true 25 | layer_idx: null # to set 26 | learned_kernel: untied_head_einsum 27 | learned_kernel_kwargs: 28 | feature_dim: 64 29 | skip_connection: false 30 | bias: false 31 | zero_init: false 32 | tie_qk_kernels: false 33 | train_qk: false 34 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" 4 | cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 12 | rope_theta: 500000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama_window_tk 16 | state_chunk_len: 1024 17 | window_size: 64 18 | affine_attention_factors: false 19 | init_window_factor: -2.1972245773362196 20 | feature_map: softmax_dim 21 | feature_map_kwargs: 22 | eps: 1e-12 23 | # mlp: null # to set 24 | fullspace: true 25 | layer_idx: null # to set 26 | learned_kernel: untied_head_einsum 27 | learned_kernel_kwargs: 28 | feature_dim: 64 29 | skip_connection: false 30 | bias: false 31 | zero_init: false 32 | tie_qk_kernels: false 33 | train_qk: false 34 | -------------------------------------------------------------------------------- /configs/model/distill_llama3_8b_lk_t2r.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" 4 | cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights 5 | return_dict: 
true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 12 | rope_theta: 500000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama 16 | feature_map: relu 17 | feature_map_kwargs: 18 | eps: 1e-12 19 | # mlp: null # to set 20 | fullspace: true 21 | layer_idx: null # to set 22 | learned_kernel: untied_head_einsum 23 | learned_kernel_kwargs: 24 | feature_dim: 128 25 | skip_connection: false 26 | bias: true 27 | zero_init: false 28 | tie_qk_kernels: false 29 | train_qk: false 30 | -------------------------------------------------------------------------------- /configs/model/distill_mistral_7b_lk_smd_fd64.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" 4 | cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 # eager # so we can load attention weights 12 | rope_theta: 10000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama 16 | feature_map: softmax_dim 17 | feature_map_kwargs: 18 | eps: 1e-12 19 | # mlp: null # to set 20 | fullspace: true 21 | layer_idx: null # to set 22 | learned_kernel: untied_head_einsum 23 | learned_kernel_kwargs: 24 | feature_dim: 64 25 | skip_connection: false 26 | bias: false 27 | zero_init: false 28 | tie_qk_kernels: false 29 | train_qk: false 30 | -------------------------------------------------------------------------------- /configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" 4 | cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 # eager # so we can load attention weights 12 | rope_theta: 10000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama_window_sw 16 | state_chunk_len: 512 # 1024 17 | window_size: 64 18 | affine_attention_factors: false 19 | init_window_factor: -2.1972245773362196 20 | train_window_factor: true 21 | train_attention_weights: false 22 | feature_map: softmax_dim 23 | feature_map_kwargs: 24 | eps: 1e-12 25 | # mlp: null # to set 26 | fullspace: true 27 | layer_idx: null # to set 28 | learned_kernel: untied_head_einsum 29 | learned_kernel_kwargs: 30 | feature_dim: 64 31 | skip_connection: false 32 | bias: false 33 | zero_init: false 34 | tie_qk_kernels: false 35 | train_qk: false 36 | -------------------------------------------------------------------------------- /configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" 4 | cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | 
torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 # eager # so we can load attention weights 12 | rope_theta: 10000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama_window_tk 16 | state_chunk_len: 512 # 1024 17 | window_size: 64 18 | affine_attention_factors: false 19 | init_window_factor: -2.1972245773362196 20 | train_window_factor: true 21 | train_attention_weights: false 22 | feature_map: softmax_dim 23 | feature_map_kwargs: 24 | eps: 1e-12 25 | # mlp: null # to set 26 | fullspace: true 27 | layer_idx: null # to set 28 | learned_kernel: untied_head_einsum 29 | learned_kernel_kwargs: 30 | feature_dim: 64 31 | skip_connection: false 32 | bias: false 33 | zero_init: false 34 | tie_qk_kernels: false 35 | train_qk: false 36 | -------------------------------------------------------------------------------- /configs/model/distill_mistral_7b_lk_t2r.yaml: -------------------------------------------------------------------------------- 1 | name: llama 2 | model: 3 | pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" 4 | cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights 5 | return_dict: true 6 | load_in_8bit: false 7 | load_in_4bit: false 8 | device_map: auto 9 | low_cpu_mem_usage: true 10 | torch_dtype: bfloat16 11 | attn_implementation: flash_attention_2 # eager # so we can load attention weights 12 | rope_theta: 10000.0 13 | 14 | attention: 15 | attention_type: lolcats_llama 16 | feature_map: relu 17 | feature_map_kwargs: 18 | eps: 1e-12 19 | # mlp: null # to set 20 | fullspace: true 21 | layer_idx: null # to set 22 | learned_kernel: untied_head_einsum 23 | learned_kernel_kwargs: 24 | feature_dim: 128 25 | skip_connection: false 26 | bias: true 27 | zero_init: false 28 | tie_qk_kernels: false 29 | train_qk: false 30 | -------------------------------------------------------------------------------- /csrc/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | from .causal_attention import causal_dot_product 7 | -------------------------------------------------------------------------------- /csrc/causal_attention.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | // Written by Angelos Katharopoulos , 4 | // Apoorv Vyas 5 | // 6 | 7 | #include 8 | 9 | 10 | /** 11 | * Compute a*b^T and save it into out. 
12 | * 13 | * a \in R^A 14 | * b \in R^B 15 | */ 16 | inline void vvt_dot(float *a, float *b, float *out, int A, int B) { 17 | for (int i=0; i(); 97 | auto ka = keys.accessor(); 98 | auto va = values.accessor(); 99 | auto pa = product.accessor(); 100 | 101 | #pragma omp parallel for collapse(2) 102 | for (int n=0; n(); 106 | for (int l=0; l(); 151 | auto ka = keys.accessor(); 152 | auto va = values.accessor(); 153 | auto ga = grad_out.accessor(); 154 | auto gqa = grad_queries.accessor(); 155 | auto gka = grad_keys.accessor(); 156 | auto gva = grad_values.accessor(); 157 | 158 | #pragma omp parallel for collapse(2) 159 | for (int n=0; n(); 163 | 164 | // Compute the gradient wrt the queries 165 | for (int l=0; l=0; l--) { 185 | vvt_dot( 186 | &qa[n][h][l][0], 187 | &ga[n][h][l][0], 188 | kvp, 189 | E, 190 | M 191 | ); 192 | vmt_dot( 193 | &va[n][h][l][0], 194 | kvp, 195 | &gka[n][h][l][0], 196 | E, 197 | M 198 | ); 199 | vm_dot( 200 | &ka[n][h][l][0], 201 | kvp, 202 | &gva[n][h][l][0], 203 | E, 204 | M 205 | ); 206 | } 207 | } 208 | } 209 | } 210 | 211 | 212 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 213 | m.def( 214 | "causal_dot_product", 215 | &causal_dot_product, 216 | "Compute the weighted sum of values but attending only to previous " 217 | "values." 218 | ); 219 | m.def( 220 | "causal_dot_backward", 221 | &causal_dot_backward, 222 | "Compute the gradient of queries, keys and values given the gradient " 223 | "of causal_dot_product." 224 | ); 225 | } -------------------------------------------------------------------------------- /csrc/causal_attention.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import torch 8 | 9 | try: 10 | from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda 11 | from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda 12 | except ImportError as e: 13 | print(e) 14 | causal_dot_product_cuda = causal_dot_backward_cuda = None 15 | 16 | 17 | class CausalDotProduct(torch.autograd.Function): 18 | """Compute the weighted sum of values but attending only to previous 19 | values.""" 20 | dot = { 21 | # "cpu": causal_dot_product_cpu, 22 | "cuda": causal_dot_product_cuda 23 | } 24 | dot_backward = { 25 | # "cpu": causal_dot_backward_cpu, 26 | "cuda": causal_dot_backward_cuda 27 | } 28 | 29 | @staticmethod 30 | def forward(ctx, Q, K, V): 31 | # Save the inputs for the gradient computation 32 | ctx.save_for_backward(Q, K, V) 33 | 34 | # Create the output tensor 35 | device = Q.device 36 | N, H, L, _ = Q.shape 37 | _, _, _, M = V.shape 38 | product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device) 39 | 40 | # Actually perform the dot product 41 | CausalDotProduct.dot[device.type]( 42 | Q.data, 43 | K.data, 44 | V.data, 45 | product 46 | ) 47 | # breakpoint() 48 | # CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product) 49 | 50 | return product 51 | 52 | @staticmethod 53 | def backward(ctx, grad_out): 54 | # Extract the saved tensors 55 | Q, K, V = ctx.saved_tensors 56 | 57 | # Allocate memory for the gradients 58 | grad_Q = torch.zeros_like(Q) 59 | grad_K = torch.zeros_like(K) 60 | grad_V = torch.zeros_like(V) 61 | 62 | # Actually compute the gradients 63 | CausalDotProduct.dot_backward[Q.device.type]( 64 | Q.data, 65 | K.data, 66 | V.data, 67 | grad_out, 68 | grad_Q, 69 | grad_K, 70 | grad_V 71 | ) 72 | 73 | return 
grad_Q, grad_K, grad_V 74 | 75 | 76 | # Alias the autograd functions to python style snake case naming 77 | causal_dot_product = CausalDotProduct.apply -------------------------------------------------------------------------------- /csrc/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Angelos Katharopoulos , 4 | # Apoorv Vyas 5 | # 6 | 7 | import torch 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME 10 | import subprocess 11 | 12 | def get_last_arch_torch(): 13 | arch = torch.cuda.get_arch_list()[-1] 14 | print(f"Found arch: {arch} from existing torch installation") 15 | return arch 16 | 17 | def get_cuda_bare_metal_version(cuda_dir): 18 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 19 | output = raw_output.split() 20 | release_idx = output.index("release") + 1 21 | release = output[release_idx].split(".") 22 | bare_metal_major = release[0] 23 | bare_metal_minor = release[1][0] 24 | return raw_output, bare_metal_major, bare_metal_minor 25 | 26 | def append_nvcc_threads(nvcc_extra_args): 27 | _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) 28 | if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: 29 | return nvcc_extra_args + ["--threads", "4"] 30 | return nvcc_extra_args 31 | 32 | arch = get_last_arch_torch() 33 | sm_num = arch[-2:] 34 | cc_flag = ['--generate-code=arch=compute_90,code=compute_90'] # for H100 35 | # cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100 36 | # cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090 37 | # cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090 38 | # cc_flag = ['--generate-code=arch=compute_75,code=compute_75'] 39 | 40 | setup( 41 | name='causal_attention_cuda_cpp', 42 | ext_modules=[ 43 | CUDAExtension('causal_attention_cuda', [ 44 | # 'causal_attention.cpp', 45 | 'causal_attention_cuda.cu', 46 | ], 47 | extra_compile_args={'cxx': ['-O3'], 48 | 'nvcc': append_nvcc_threads(['-O3', '-lineinfo', '--use_fast_math', '-std=c++17'] + cc_flag) 49 | }) 50 | ], 51 | cmdclass={ 52 | 'build_ext': BuildExtension 53 | }) 54 | -------------------------------------------------------------------------------- /demos/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Demos 3 | 4 | We describe how to use LoLCATS checkpoints. We include: 5 | 1. Demo script to talk to our models using Hugging Face checkpoints 6 | 2. Demo script to benchmark the pretrained 8B linearized versus base softmax attention models 7 | 3. Code to reproduce the 70B and 405B MMLU numbers using our uploaded Hugging Face checkpoints 8 | 4. Coming soon: VLLM integration with custom LoLCATS CUDA kernels! 9 | 10 | ### Talk to pre-trained LoLCATS LLMs 11 | 12 | Use the commands provided in `demo_8b.sh` to run inference with our LoLCATS - Llama 3.1 8B checkpoint, which will be downloaded from Hugging Face. The downloaded checkpoints take up less than 1 GB and are inserted into your local Meta Llama 3.1 model in 16-bit precision -- please ensure you have downloaded the base model and specify your path to it in the configs in `demo_8b.sh`.
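For reference, the snippet below is a minimal sketch of how those Hugging Face checkpoints get merged into the converted base model, mirroring the `load_hf_weights` logic in `app.py` at the repo root. The `load_lolcats_weights` wrapper name is ours, and the key-prefix remapping that `app.py` also performs is omitted for brevity; `demo_8b.sh` handles all of this for you.

```python
import torch
from huggingface_hub import hf_hub_download

def load_lolcats_weights(model,
                         distill_repo='hazyresearch/lolcats-llama-3.1-8b-distill',
                         ft_repo='hazyresearch/lolcats-llama-3.1-8b-ft-lora',
                         filename='model.pt'):
    """Hypothetical helper: fetch the linear-attention and LoRA checkpoints and load them into `model`."""
    for repo_id in (distill_repo, ft_repo):
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        state_dict = torch.load(local_path, map_location='cpu')
        # Some checkpoints wrap the weights under a 'model_state_dict' key
        state_dict = state_dict.get('model_state_dict', state_dict)
        # Non-strict load: only the newly trained parameters are present in these checkpoints
        model.load_state_dict(state_dict, strict=False)
    return model
```

The non-strict load works because the LoLCATS checkpoints only contain the small set of newly trained parameters (learned feature maps, window factors, LoRA weights); all remaining parameters stay at the original Llama 3.1 values.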
To run the demo: 13 | ```bash 14 | bash demo_8b.sh 15 | ``` 16 | 17 | ### Fast inference with custom CUDA kernels 18 | 19 | We provide a custom CUDA prefill kernel written in the [ThunderKittens framework](https://github.com/HazyResearch/ThunderKittens). 20 | 21 | To install the kernel: 22 | ```bash 23 | # Clone the repo 24 | git clone https://github.com/HazyResearch/ThunderKittens 25 | cd ThunderKittens 26 | # In config.py, select 'hedgehog', then run: 27 | source env.src 28 | python setup.py install 29 | ``` 30 | 31 | For a quick end-to-end comparison of the prefill speed of the linearized LoLCATS 8B model versus the base Llama 8B model, we provide a script at: 32 | ```bash 33 | bash benchmark_8b.sh 34 | ``` 35 | Our benchmarking implementation is currently restricted to prefill lengths that are multiples of 64. 36 | 37 | The code will print out the inference tokens per second for each method. 38 | 39 | ### 5-shot MMLU Eval 40 | 41 | First, get the 5-shot MMLU data. We directly saved the tokenized examples produced by the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) codebase to a pickle file: 42 | ``` 43 | cd lolcats/inference/ 44 | unzip mmlu.pkl.zip 45 | ``` 46 | 47 | We provide scripts to evaluate our 70B and 405B linearized LoLCATS checkpoints (hosted on Hugging Face) on MMLU: 48 | ```bash 49 | cd lolcats/ 50 | bash demos/llm_mmlu_eval/demo_70b.sh # runs on 1 8x80GB H100 node 51 | sbatch demos/llm_mmlu_eval/demo_405b.sh # set to use 2 8x80GB H100 nodes 52 | ``` 53 | 54 | These scripts call `demos/llm_mmlu_eval/eval_mmlu.py`, which loops through mmlu.pkl and uses the last-token model logits to get the predictions. 55 | 56 | 57 | ### VLLM Integration 58 | 59 | Coming Soon! 60 | -------------------------------------------------------------------------------- /demos/benchmark_8b.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CONFIG_DIR='/home/bfs/simran/attention/lolcats/configs/' # update to your path 4 | 5 | # Benchmarking the 8b model on the LOLCATS dataset 6 | 7 | # Run the linearized model with the ThunderKittens kernel 8 | CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ 9 | --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ 10 | --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ 11 | --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ 12 | --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ 13 | --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ 14 | --num_generations 1 \ 15 | --use_cuda_kernels 1 \ 16 | --benchmark \ 17 | --max_new_tokens 1 18 | 19 | # Run the linearized model *without* the ThunderKittens kernel 20 | CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ 21 | --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ 22 | --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ 23 | --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ 24 | --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ 25 | --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ 26 | --num_generations 1 \ 27 | --use_cuda_kernels 0 \ 28 | --benchmark \ 29 | --max_new_tokens 1 30 | 31 | # Run the base Llama model with Transformers SDPA attention 32 | CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ 33 |
--model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ 34 | --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ 35 | --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ 36 | --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ 37 | --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ 38 | --num_generations 1 \ 39 | --use_attention \ 40 | --benchmark \ 41 | --max_new_tokens 1 42 | -------------------------------------------------------------------------------- /demos/demo_8b.sh: -------------------------------------------------------------------------------- 1 | 2 | CONFIG_DIR='/home/bfs/simran/attention/lolcats/configs/' # update to your path 3 | 4 | # Using huggingface checkpoints 5 | CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ 6 | --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ 7 | --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ 8 | --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ 9 | --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ 10 | --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ 11 | --num_generations 1 \ 12 | --max_new_tokens 50 13 | 14 | 15 | # Reference script: 16 | # if you train your own LoLCATS weights, you can use the following command to run inference with your local checkpoints: 17 | # CHECKPOINT_DIR='/home/mzhang/projects/lolcats/checkpoints/' 18 | # CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ 19 | # --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ 20 | # --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ 21 | # --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ 22 | # --attn_mlp_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt \ 23 | # --finetune_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt \ 24 | # --num_generations 1 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /demos/llm_mmlu_eval/demo_405b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llama-405b 3 | #SBATCH --partition=sixhour 4 | #SBATCH --nodes=2 5 | #SBATCH --nodelist=mk-xii-05,mk-xii-06 # TODO: set to your nodenames 6 | #SBATCH --gres=gpu:8 7 | #SBATCH --cpus-per-task=19 8 | #SBATCH --time=5:59:00 9 | #SBATCH --output=/home/simarora/utils/slurm_logs/slurm-%j.out # TODO: make your own directory 10 | #SBATCH --error=/home/simarora/utils/slurm_logs/slurm-%j.err 11 | #SBATCH --ntasks=2 # Add this line 12 | #SBATCH --ntasks-per-node=1 # Add this line 13 | 14 | # Initialize HPC-X toolkit for high-performance computing 15 | . 
/opt/hpcx/hpcx-init.sh 16 | hpcx_load 17 | 18 | export NCCL_IGNORE_CPU_AFFINITY=1 # Ignore CPU affinity settings 19 | export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # Enable asynchronous error handling for PyTorch NCCL 20 | export CUDA_DEVICE_ORDER=PCI_BUS_ID # Set CUDA device order to PCI bus ID 21 | export NCCL_IB_DISABLE=0 # Enable InfiniBand if available 22 | export NCCL_NET_GDR_LEVEL=5 # Enable GPUDirect RDMA for faster GPU-to-GPU communication 23 | export NCCL_P2P_DISABLE=0 # Enable peer-to-peer communication between GPUs 24 | export NCCL_BUFFSIZE=2097152 # Set 2MB buffer size for NCCL operations 25 | export NCCL_IB_HCA=mlx5 # Specify the InfiniBand Host Channel Adapter to use 26 | 27 | export MASTER_HOSTNAME="mk-xii-05" # TODO: change to your node name 28 | export MASTER_ADDR=$(host $MASTER_HOSTNAME | awk '/has address/ { print $4 }') 29 | export MASTER_PORT=29500 30 | 31 | export PYTHONPATH=/home/simarora/code/lolcats/ # TODO change to your folder 32 | 33 | # Save the model outputs 34 | srun torchrun --nnodes 2 --node_rank $SLURM_NODEID --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --nproc_per_node 8 \ 35 | demos/llm_mmlu_eval/eval_mmlu.py \ 36 | --model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01 \ 37 | --distill_config llama3_1_405b/rp_distill_llama_405b_xent1_mse1000_lr1e-2 \ 38 | --finetune_config llama3_1_405b/finetune_rp_llama_405b_qkvo_e2 \ 39 | --verbose --replicate 0 --seed 0 --lk_zero_init \ 40 | --eval_steps 100 --dataset_chunk_size 1024 \ 41 | --enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ 42 | --tag hf_405b_mmlu \ 43 | --finetune_checkpoint_path hazyresearch/lolcats-llama-3.1-405b 44 | 45 | 46 | # Alternatively, you can run with your own locally trained checkpoints by passing in the checkpoint_path as follows: 47 | # --finetune_checkpoint_path /home/simarora/code/lolcats/checkpoints/ckpt_lora-dl-d=rp_distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=rp_finetune_llama_40b_qv_hparams-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2_rp-ft_lora=0-se=0-re=0-s=3550.pt 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /demos/llm_mmlu_eval/demo_70b.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=/home/simarora/code/lolcats/ 2 | 3 | # Use HF checkpoint paths (you can likely also get away with 2 GPUs, though longer contexts may not fit) 4 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes 1 --nproc_per_node 8 \ 5 | demos/llm_mmlu_eval/eval_mmlu.py \ 6 | --model_config llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ 7 | --distill_config llama3_1_70b/distill_rp_llama_70b_xent0_mse1000_lr1e-2 \ 8 | --finetune_config llama3_1_70b/finetune_rp_llama_70b_qkvo \ 9 | --eval_config eval_alpaca_clean \ 10 | --verbose --replicate 0 --seed 0 \ 11 | --lk_zero_init \ 12 | --eval_steps 100 --dataset_chunk_size 1024 \ 13 | --enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ 14 | --experiment_tag lolcats_hf_70b \ 15 | --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-70b' 16 | 17 | # Example using local paths, in case you train your own model.
18 | # CUDA_VISIBLE_DEVICES=6,7 torchrun --nnodes 1 --nproc_per_node 2 \ 19 | # demos/llm_mmlu_eval/eval_mmlu.py \ 20 | # --model_config llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ 21 | # --distill_config llama3_1_70b/distill_rp_llama_70b_xent0_mse1000_lr1e-2 \ 22 | # --finetune_config llama3_1_70b/finetune_rp_llama_70b_qkvo \ 23 | # --eval_config eval_alpaca_clean \ 24 | # --verbose --replicate 0 --seed 0 \ 25 | # --lk_zero_init \ 26 | # --eval_steps 100 --dataset_chunk_size 1024 \ 27 | # --enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ 28 | # --experiment_tag my_lolcats_70b \ 29 | # --finetune_checkpoint_path ckpt_lora-dl-d=distill_rp_llama_70b_xent0_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_rp_llama_70b_qkvo-fac=1-se=0-re=0-se=0-re=0.pt 30 | 31 | -------------------------------------------------------------------------------- /demos/llm_mmlu_eval/mmlu.pkl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/demos/llm_mmlu_eval/mmlu.pkl.zip -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: lolcats-env 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - huggingface 6 | - nvidia 7 | dependencies: 8 | - python>=3.10 9 | - pip 10 | - pytorch 11 | - pytorch-cuda=12.1 # change depending on your supported CUDA version 12 | - transformers=4.43.1 # Llama 3.1 support (updated RoPE) 13 | - peft=0.9.0 14 | - datasets=2.15.0 15 | - nltk 16 | - rouge-score 17 | - sentencepiece 18 | - ninja 19 | - numpy 20 | - pandas 21 | - einops 22 | - opt_einsum 23 | - matplotlib 24 | - omegaconf 25 | - openai 26 | - rich 27 | - tqdm 28 | - wandb 29 | -------------------------------------------------------------------------------- /lm_eval_harness/README.md: -------------------------------------------------------------------------------- 1 | # LM Evaluation Harness Setup + Sample Scripts 2 | 3 | To set up the evaluations, we clone the Language Model Evaluation Harness from [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/b281b0921b636bc36ad05c0b0b0763bd6dd43463) to a separate directory (e.g., outside the lolcats directory). 4 | 5 | - Note that we use the `b281b09` commit, following Hugging Face's [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard). 6 | 7 | ```bash 8 | git checkout b281b09 9 | ``` 10 | 11 | We then point to this path in `./lm_eval_harness/eval_lm_harness.py`, e.g. 12 | 13 | ```python 14 | LM_EVALUATION_HARNESS_PATH = '/juice2/scr2/mzhang/projects/lm-evaluation-harness' # Change this to where you clone LM eval harness from 15 | ``` 16 | 17 | You may also need to install the following packages: 18 | 19 | ```bash 20 | pip install --upgrade --force-reinstall sacrebleu 21 | pip install evaluate sqlitedict scikit-learn pycountry 22 | ``` 23 | 24 | Finally, we'll want to replace the current file in `lm-evaluation-harness/lm_eval/models/huggingface.py` with `lolcats/lm_eval_harness/models_huggingface.py` to better support loading our linearized checkpoints (the original is missing some keyword arguments).
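For reference, a minimal sketch of that file swap; both paths below are assumptions -- substitute wherever you cloned the two repositories:

```python
# Overwrite the harness's huggingface.py with the lolcats version described above.
# Both paths are assumptions -- adjust to your local clone locations.
import shutil

src = "lolcats/lm_eval_harness/models_huggingface.py"
dst = "lm-evaluation-harness/lm_eval/models/huggingface.py"
shutil.copyfile(src, dst)
print(f"Replaced {dst} with {src}")
```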
25 | --- 26 | 27 | ## Running the evaluations 28 | 29 | All evaluation scripts take the following template: 30 | 31 | ```bash 32 | python lm_eval_harness/eval_lm_harness.py \ 33 | --model_type lolcats_ckpt \ 34 | --attn_mlp_checkpoint_path \ 35 | --finetune_checkpoint_path \ 36 | --task --num_shots --no_cache --verbose 37 | ``` 38 | 39 | We provide examples of such below. 40 | 41 | --- 42 | 43 | ### PiQA (zero-shot) 44 | 45 | ```bash 46 | python lm_eval_harness/eval_lm_harness.py \ 47 | --model_type lolcats_ckpt \ 48 | --attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1_distill.pt \ 49 | --finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1-bs=1-gas=1-nte=2-se=0-re=614_ft.pt \ 50 | --task piqa --num_shots 0 --no_cache --verbose 51 | 52 | ``` 53 | 54 | ### ARC-Easy (zero-shot) 55 | 56 | ```bash 57 | python lm_eval_harness/eval_lm_harness.py \ 58 | --model_type lolcats_ckpt \ 59 | --attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1_distill.pt \ 60 | --finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1-bs=1-gas=1-nte=2-se=0-re=614_ft.pt \ 61 | --task arc_easy --num_shots 0 --no_cache --verbose 62 | ``` 63 | 64 | ### ARC-Challenge (zero-shot) 65 | 66 | ```bash 67 | python lm_eval_harness/eval_lm_harness.py \ 68 | --model_type lolcats_ckpt \ 69 | --attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1_distill.pt \ 70 | --finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1-bs=1-gas=1-nte=2-se=0-re=614_ft.pt \ 71 | --task arc_challenge --num_shots 0 --no_cache --verbose 72 | ``` 73 | 74 | ### Hellaswag (zero-shot) 75 | 76 | ```bash 77 | python lm_eval_harness/eval_lm_harness.py --model_type lolcats_ckpt \ 78 | --attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1_distill.pt \ 79 | --finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1-bs=1-gas=1-nte=2-se=0-re=614_ft.pt \ 80 | 
--task hellaswag --num_shots 0 --no_cache --verbose 81 | ``` 82 | 83 | ### Winogrande (zero-shot) 84 | 85 | ```bash 86 | python lm_eval_harness/eval_lm_harness.py --model_type lolcats_ckpt \ 87 | --attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1_distill.pt \ 88 | --finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1-bs=1-gas=1-nte=2-se=0-re=614_ft.pt \ 89 | --task winogrande --num_shots 0 --no_cache --verbose 90 | ``` 91 | 92 | ### MMLU (5-shot) 93 | 94 | ```bash 95 | python lm_eval_harness/eval_lm_harness.py --model_type lolcats_ckpt \ 96 | --attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1_distill.pt \ 97 | --finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192-s=0-gas=1-nte=2-se=0-re=614-scl=1024-lzi=1-bs=1-gas=1-nte=2-se=0-re=614_ft.pt \ 98 | --task hendrycksTest --num_shots 5 --no_cache --verbose 99 | ``` 100 | 101 | ### 70B Evaluation 102 | 103 | For running 70B evals, we can use the following example 104 | 105 | ```bash 106 | python lm_eval_harness/eval_lm_harness_big.py \ 107 | --model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ 108 | --distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ 109 | --finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ 110 | --lk_zero_init \ 111 | --verbose --replicate 0 --seed 0 \ 112 | --eval_steps 10 --dataset_chunk_size 256 \ 113 | --enable_fsdp --low_cpu_fsdp \ 114 | --task piqa --num_shots 0 --no_cache \ 115 | --attn_mlp_checkpoint_path /scr-ssd/mzhang/projects/lolcats/checkpoints/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01/distill-dl-d=distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean_llama3_1_70b-dcs=512-se=0-re=0-at=lolcats_llama_window_tk_bf16 \ 116 | --finetune_checkpoint_path /scr-ssd/mzhang/projects/lolcats/checkpoints/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01/finetune-dl-d=distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean_llama3_1_70b-dcs=256-se=0-re=0-at=lolcats_llama_window_tk_bf16-dcs=256-se=0-re=0 117 | ``` 118 | 119 | where `attn_mlp_checkpoint_path` and `finetune_checkpoint_path` now point to the directory paths where the sharded attention and finetune checkpoints are saved. 
120 | -------------------------------------------------------------------------------- /lm_eval_harness/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/lm_eval_harness/__init__.py -------------------------------------------------------------------------------- /lm_eval_harness/eval_lm_harness.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate models with lm-evaluation-harness 3 | """ 4 | import sys 5 | import os 6 | from os.path import join 7 | os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' 8 | 9 | import argparse 10 | import torch 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from src.model.load_model_for_eval import load_model_from_checkpoint, load_model_from_config 15 | 16 | LM_EVALUATION_HARNESS_PATH = '/juice2/scr2/mzhang/projects/lm-evaluation-harness' # Change this to where you clone LM eval harness from 17 | 18 | RESULTS_PATH = '/scr-ssd/mzhang/projects/lolcats/lm_eval_harness/results_lm_eval.csv' 19 | 20 | 21 | OPEN_LLM = [ # task, shots 22 | ('arc-challenge', 25), 23 | ('hellaswag', 10), 24 | ('truthfulqa-mc', 0), 25 | ('winogrande', 5), 26 | ('gsm8k', 5), 27 | ] 28 | ZERO_SHOT = [ 29 | ('hellaswag', 0), 30 | ('piqa', 0), 31 | ('arc-challenge', 0), 32 | ('arc-easy', 0), 33 | ('winogrande', 0), 34 | ('hendrycksTest', 5), 35 | ] 36 | 37 | 38 | def get_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--project_name", type=str, default='lolcats-eval') 41 | parser.add_argument("--model_type", type=str, default=None, 42 | choices=['lolcats_ckpt', 'model_config', 'huggingface']) 43 | parser.add_argument("--model_config", type=str, default=None) 44 | parser.add_argument("--cache_dir", type=str, default=None) 45 | 46 | parser.add_argument("--attn_mlp_checkpoint_path", type=str, default=None) 47 | parser.add_argument("--finetune_checkpoint_path", type=str, default=None) 48 | parser.add_argument("--config_dir", type=str, default='./configs') 49 | 50 | parser.add_argument("--task", type=str, default=None) 51 | parser.add_argument("--num_shots", type=int, default=0) 52 | 53 | # LM Evaluation Harness args 54 | parser.add_argument("--batch_size", type=int, default=1) 55 | parser.add_argument("--max_batch_size", type=int, default=1) 56 | parser.add_argument("--no_cache", action='store_true', default=None) 57 | parser.add_argument("--limit", type=float, default=None, # helpful for debugging 58 | help="Limit the number of examples per task. 
" 59 | "If <1, limit is a percentage of the total number of examples.") 60 | # Miscellaneous 61 | parser.add_argument("--verbose", action='store_true', default=False) 62 | parser.add_argument("--debug", action='store_true', default=False) 63 | parser.add_argument("--no_wandb", action='store_true', default=False) 64 | parser.add_argument("--wandb_entity", type=str, default='hazy-research') 65 | parser.add_argument("--replicate", type=int, default=None) 66 | 67 | args = parser.parse_args() 68 | 69 | args.run_name = f'd={args.task}-ns={args.num_shots}' 70 | if args.limit is not None: 71 | args.run_name += f'-li={args.limit}' 72 | if args.model_type == 'lolcats_ckpt': 73 | if args.finetune_checkpoint_path is not None: 74 | args.run_name += f"-c={args.finetune_checkpoint_path.split('/')[-1]}" 75 | elif args.model_type == 'model_config': 76 | args.run_name += f"-c={args.model_config}" 77 | if args.replicate is not None: 78 | args.run_name += f"-r={args.replicate}" 79 | return args 80 | 81 | 82 | def init_wandb(args): 83 | if args.no_wandb: 84 | wandb = None 85 | else: 86 | import wandb 87 | wandb.init(config={}, 88 | entity=args.wandb_entity, 89 | name=args.run_name, 90 | project=args.project_name) 91 | return wandb 92 | 93 | 94 | def create_new_save_dict(results_path): 95 | # Save locally 96 | if not os.path.isfile(results_path): 97 | results_dict = { 98 | 'task': [], 99 | 'shots': [], 100 | 'acc': [], 101 | 'acc_norm': [], 102 | 'acc_stderr': [], 103 | 'acc_norm_stderr': [], 104 | 'attn_mlp_path': [], 105 | 'finetune_path': [], 106 | } 107 | print(f'Creating new results dict at {results_path}') 108 | pd.DataFrame(results_dict).to_csv(results_path, index=False) 109 | results_dict = pd.read_csv(results_path).to_dict(orient='list') 110 | return results_dict 111 | 112 | def save_results_to_dict(results, results_dict, results_path, args): 113 | # Save results locally 114 | # results are lm_eval results 115 | results_dict['task'].append(args.task) 116 | results_dict['shots'].append(args.num_shots) 117 | if args.task in ['mmlu', 'hendrycksTest', 'mmlu_cloze', 'mmlu_2']: 118 | try: 119 | acc = sum(mmlu_accs) / len(mmlu_accs) 120 | acc_stderr = np.std(mmlu_acc) # stdev over samples 121 | except: 122 | acc = 0 123 | acc_stderr = 0 124 | acc_norm = 0 125 | acc_norm_stderr = 0 126 | else: 127 | acc = results['results'][args.task]['acc'] 128 | acc_stderr = results['results'][args.task]['acc_stderr'] 129 | try: 130 | acc_norm = results['results'][args.task]['acc_norm'] 131 | acc_norm_stderr = results['results'][args.task]['acc_norm_stderr'] 132 | except: 133 | acc_norm = 0 134 | acc_norm_stderr = 0 135 | results_dict['acc'].append(acc) 136 | results_dict['acc_stderr'].append(acc_stderr) 137 | results_dict['acc_norm'].append(acc_norm) 138 | results_dict['acc_norm_stderr'].append(acc_norm_stderr) 139 | results_dict['attn_mlp_path'].append(args.attn_mlp_checkpoint_path) 140 | results_dict['finetune_path'].append(args.finetune_checkpoint_path) 141 | pd.DataFrame(results_dict).to_csv(results_path, index=False) 142 | 143 | 144 | def main(): 145 | sys.path.append(LM_EVALUATION_HARNESS_PATH) 146 | from lm_eval import evaluator 147 | 148 | args = get_args() 149 | 150 | try: 151 | # Save locally 152 | results_dict = create_new_save_dict(RESULTS_PATH) 153 | if 'dl-d=drxmldl8lswfwflqr000_lzi=1_distill_' in args.finetune_checkpoint_path: 154 | finetune_flag = args.finetune_checkpoint_path.split('dl-d=drxmldl8lswfwflqr000_lzi=1_distill_')[-1].split('-')[0] 155 | _RESULTS_PATH = RESULTS_PATH.replace('.csv', 
f'-{finetune_flag}.csv') 156 | _results_dict = create_new_save_dict(_RESULTS_PATH) 157 | else: 158 | _RESULTS_PATH = None 159 | _results_dict = None 160 | except: 161 | pass 162 | 163 | if args.model_type == 'lolcats_ckpt': # load hedgehog model 164 | model, model_config, tokenizer = load_model_from_checkpoint( 165 | attn_mlp_checkpoint_path=args.attn_mlp_checkpoint_path, 166 | finetune_checkpoint_path=args.finetune_checkpoint_path, 167 | config_dir=args.config_dir, 168 | print_model=args.verbose, 169 | debug=args.debug, 170 | lm_eval_model=True, 171 | path_to_lm_eval_harness=LM_EVALUATION_HARNESS_PATH, 172 | ) 173 | elif args.model_type == 'model_config': 174 | model, model_config, tokenizer = load_model_from_config( 175 | model_config_name=args.model_config, 176 | config_dir=args.config_dir, 177 | lm_eval_model=True, 178 | path_to_lm_eval_harness=LM_EVALUATION_HARNESS_PATH, 179 | ) 180 | elif args.model_type == 'huggingface': 181 | from lm_eval.models import get_model 182 | model = get_model('hf-causal-experimental').create_from_arg_string( 183 | '', {'cache_dir': args.cache_dir} 184 | ) 185 | 186 | try: 187 | device = model.device 188 | except: 189 | try: 190 | device = model.model.device 191 | except: 192 | device = torch.device('cuda:0') 193 | 194 | # WandB logging 195 | wandb = init_wandb(args) 196 | if wandb is not None: 197 | attn_mlp_checkpoint = (args.attn_mlp_checkpoint_path.split('/')[-1] 198 | if args.attn_mlp_checkpoint_path is not None else None) 199 | finetune_checkpoint = (args.finetune_checkpoint_path.split('/')[-1] 200 | if args.finetune_checkpoint_path is not None else None) 201 | wandb.config.update({ 202 | 'model_type': args.model_type, 203 | 'model_config': args.model_config, 204 | 'attn_mlp_checkpoint': attn_mlp_checkpoint, 205 | 'finetune_checkpoint': finetune_checkpoint, 206 | 'task': args.task, 207 | 'num_shots': args.num_shots, 208 | 'batch_size': args.batch_size, 209 | 'max_batch_size': args.max_batch_size, 210 | }) 211 | 212 | if args.task in ['mmlu', 'hendrycksTest', 'mmlu_cloze', 'mmlu_2']: 213 | from lm_eval.tasks import TASK_REGISTRY 214 | tasks = sorted([k for k in TASK_REGISTRY.keys() if f'{args.task}-' in k]) 215 | else: 216 | tasks = [args.task] 217 | 218 | results = evaluator.simple_evaluate( 219 | model=model, 220 | model_args='', 221 | tasks=tasks, 222 | num_fewshot=args.num_shots, 223 | batch_size=args.batch_size, 224 | max_batch_size=args.max_batch_size, 225 | device=device, 226 | no_cache=args.no_cache, 227 | limit=args.limit, 228 | description_dict={}, # description_dict, 229 | decontamination_ngrams_path=None, # args.decontamination_ngrams_path, 230 | check_integrity=None, # args.check_integrity, 231 | write_out=False, # args.write_out, 232 | output_base_path=None, # args.output_base_path, 233 | ) 234 | 235 | if args.task in ['mmlu', 'hendrycksTest', 'mmlu_cloze', 'mmlu_2']: 236 | mmlu_accs = [] 237 | for k, v in results['results'].items(): 238 | if args.task in k: 239 | mmlu_accs.append(v['acc']) 240 | print(mmlu_accs) 241 | if len(mmlu_accs) > 0: 242 | results['results']['mmlu'] = {'acc': sum(mmlu_accs) / len(mmlu_accs)} 243 | 244 | print('MMLU RESULT:', results['results']['mmlu']) 245 | print(results) 246 | 247 | if wandb is not None: 248 | wandb.log(results) 249 | 250 | save_results_to_dict(results, results_dict, RESULTS_PATH, args) 251 | if _results_dict is not None: 252 | save_results_to_dict(results, _results_dict, _RESULTS_PATH, args) 253 | 254 | 255 | if __name__ == '__main__': 256 | main() 257 | 
-------------------------------------------------------------------------------- /lm_eval_harness/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inherit from lm-evaluation-harness/lm_eval/models/huggingface.py to load linearized models 3 | """ 4 | from lm_eval.models.huggingface import AutoCausalLM 5 | from src.model.modeling_llama import LolcatsLlamaForCausalLM as LOLCATS_LLAMA_MODEL_CLASS 6 | from src.model.modeling_mistral import LolcatsMistralForCausalLM as LOLCATS_MISTRAL_MODEL_CLASS 7 | 8 | from src.model.modeling_llama import LooooolcatsLlamaForCausalLM as LOOOOOLCATS_LLAMA_MODEL_CLASS 9 | from src.model.modeling_mistral import LooooolcatsMistralForCausalLM as LOOOOOLCATS_MISTRAL_MODEL_CLASS 10 | 11 | from src.model.modeling_llama_sharded import ShardedLolcatsLlamaForCausalLM as SHARDED_LOLCATS_LLAMA_MODEL_CLASS 12 | 13 | 14 | class LolcatsLlamaForCausalLM(AutoCausalLM): 15 | """ 16 | Wrapper for Llama-like autoregressive language model 17 | """ 18 | AUTO_MODEL_CLASS = LOLCATS_LLAMA_MODEL_CLASS 19 | @property 20 | def add_special_tokens(self) -> bool: 21 | """Whether to include special tokens in encoded text. This should be 22 | determined by whether or not the model was trained with special tokens. 23 | TODO: Remove these conditionals once HuggingFace supports a way to 24 | check whether or not an arbitrary model was trained with special tokens. 25 | """ 26 | if self._add_special_tokens is not None: 27 | return self._add_special_tokens 28 | else: 29 | return False 30 | 31 | 32 | class LolcatsMistralForCausalLM(AutoCausalLM): 33 | """ 34 | Wrapper for Mistral-like autoregressive language model 35 | """ 36 | AUTO_MODEL_CLASS = LOLCATS_MISTRAL_MODEL_CLASS 37 | @property 38 | def add_special_tokens(self) -> bool: 39 | """Whether to include special tokens in encoded text. This should be 40 | determined by whether or not the model was trained with special tokens. 41 | TODO: Remove these conditionals once HuggingFace supports a way to 42 | check whether or not an arbitrary model was trained with special tokens. 43 | """ 44 | if self._add_special_tokens is not None: 45 | return self._add_special_tokens 46 | else: 47 | return False 48 | 49 | 50 | class ShardedLolcatsLlamaForCausalLM(AutoCausalLM): 51 | """ 52 | Wrapper for Llama or Mistral-like autoregressive language model 53 | """ 54 | AUTO_MODEL_CLASS = SHARDED_LOLCATS_LLAMA_MODEL_CLASS 55 | @property 56 | def add_special_tokens(self) -> bool: 57 | """Whether to include special tokens in encoded text. This should be 58 | determined by whether or not the model was trained with special tokens. 59 | TODO: Remove these conditionals once HuggingFace supports a way to 60 | check whether or not an arbitrary model was trained with special tokens. 61 | """ 62 | if self._add_special_tokens is not None: 63 | return self._add_special_tokens 64 | else: 65 | return False 66 | 67 | 68 | # class ShardedRollLolcatsLlamaForCausalLM(AutoCausalLM): 69 | # """ 70 | # Wrapper for Llama or Mistral-like autoregressive language model 71 | # """ 72 | # AUTO_MODEL_CLASS = SHARDED_ROLL_LOLCATS_LLAMA_MODEL_CLASS 73 | # @property 74 | # def add_special_tokens(self) -> bool: 75 | # """Whether to include special tokens in encoded text. This should be 76 | # determined by whether or not the model was trained with special tokens. 77 | # TODO: Remove these conditionals once HuggingFace supports a way to 78 | # check whether or not an arbitrary model was trained with special tokens. 
79 | # """ 80 | # if self._add_special_tokens is not None: 81 | # return self._add_special_tokens 82 | # else: 83 | # return False 84 | 85 | 86 | class LooooolcatsLlamaForCausalLM(AutoCausalLM): 87 | """ 88 | Wrapper for Llama-like autoregressive language model 89 | """ 90 | AUTO_MODEL_CLASS = LOOOOOLCATS_LLAMA_MODEL_CLASS 91 | @property 92 | def add_special_tokens(self) -> bool: 93 | """Whether to include special tokens in encoded text. This should be 94 | determined by whether or not the model was trained with special tokens. 95 | TODO: Remove these conditionals once HuggingFace supports a way to 96 | check whether or not an arbitrary model was trained with special tokens. 97 | """ 98 | if self._add_special_tokens is not None: 99 | return self._add_special_tokens 100 | else: 101 | return False 102 | 103 | 104 | class LooooolcatsMistralForCausalLM(AutoCausalLM): 105 | """ 106 | Wrapper for Mistral-like autoregressive language model 107 | """ 108 | AUTO_MODEL_CLASS = LOOOOOLCATS_MISTRAL_MODEL_CLASS 109 | @property 110 | def add_special_tokens(self) -> bool: 111 | """Whether to include special tokens in encoded text. This should be 112 | determined by whether or not the model was trained with special tokens. 113 | TODO: Remove these conditionals once HuggingFace supports a way to 114 | check whether or not an arbitrary model was trained with special tokens. 115 | """ 116 | if self._add_special_tokens is not None: 117 | return self._add_special_tokens 118 | else: 119 | return False 120 | -------------------------------------------------------------------------------- /lolcats_preprint_v0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/lolcats_preprint_v0.pdf -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/src/__init__.py -------------------------------------------------------------------------------- /src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load dataloaders 3 | """ 4 | import importlib 5 | 6 | 7 | def load_data(dataset_config: dict, dataloader_config: dict): 8 | """Return dataloaders from dataset_config""" 9 | try: 10 | dataset_module = importlib.import_module(f'dataloaders.{dataset_config["name"]}') 11 | except Exception: 12 | try: 13 | dataset_module = importlib.import_module(f'src.dataloaders.{dataset_config["name"]}') 14 | except Exception as e2: 15 | print(e2) 16 | try: # e.g., tasks like GLUE where name is benchmark and path specifies the dataset / task 17 | dataset_module = importlib.import_module(f'dataloaders.{dataset_config["path"]}') 18 | except Exception as e3: 19 | print(f'Error from {dataset_config}') 20 | raise e3 21 | _load_data = getattr(dataset_module, 'load_data') 22 | return _load_data(**dataset_config, **dataloader_config) -------------------------------------------------------------------------------- /src/dataloaders/alpaca_clean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Alpaca training dataloaders 3 | 4 | We adopt the original prompt template; goes something like: 5 | ``` 6 | Below is an instruction that describes a task. 
7 | Write a response that appropriately completes the request. 8 | ### Instruction: 9 | {instruction} 10 | 11 | ### Response: 12 | {response} 13 | ``` 14 | See `PROMPT_DICT` for more. 15 | """ 16 | from functools import partial 17 | from os.path import join 18 | 19 | from datasets import load_metric, load_dataset 20 | 21 | from .utils import ( 22 | get_lm_loader, get_seq2seq_loader, 23 | convert_to_hf_dataset, 24 | get_tokenizer_from_config, 25 | download_scrolls_metric as download_metric 26 | ) 27 | from .utils.packing import ConcatDataset 28 | 29 | 30 | PROMPT_DICT = { 31 | "prompt_input": ( 32 | "Below is an instruction that describes a task, paired with an input that provides further context. " 33 | "Write a response that appropriately completes the request.\n\n" 34 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" 35 | ), 36 | "prompt_no_input": ( 37 | "Below is an instruction that describes a task. " 38 | "Write a response that appropriately completes the request.\n\n" 39 | "### Instruction:\n{instruction}\n\n### Response:\n" 40 | ), 41 | } 42 | 43 | 44 | def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, 45 | preprocess_config: dict, **loader_kwargs: any): 46 | """ 47 | Shared function to load dataset from experiment config 48 | -> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml 49 | """ 50 | # Misc. setup 51 | cache_dir = dataset_config['cache_dir'] 52 | input_len = dataset_config['chunk_size'] 53 | concat_data = dataset_config['concat_data'] 54 | 55 | tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] 56 | tokenizer_name = tokenizer_name.split('/')[-1] 57 | # save_path = join(cache_dir, f'{name}_{tokenizer_name}') 58 | 59 | # Setup tokenizer 60 | tokenizer = get_tokenizer_from_config(pretrained_model_config) 61 | if tokenizer.pad_token is None: 62 | tokenizer.pad_token = tokenizer.eos_token 63 | print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') 64 | 65 | tokenizer.padding_side = 'left' # for decoder-only generation 66 | # Get initial data 67 | ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs'] 68 | dataset = load_dataset( 69 | **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs} 70 | ) 71 | if dataset_config['name'] == 'samsum': # hack 72 | dataset = dataset.rename_column('dialogue', 'input') 73 | dataset = dataset.rename_column('summary', 'output') 74 | _instruction = 'Summarize this dialogue.' 
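        # Add the fixed instruction to every split so samsum samples match the Alpaca-style (instruction, input, output) schema used by the prompt templates above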
75 | for split in dataset.keys(): 76 | dataset[split] = dataset[split].add_column( 77 | 'instruction', [_instruction] * len(dataset[split]) 78 | ) 79 | train_set, val_set, test_set = dataset['train'], dataset['validation'], dataset['test'] 80 | dataset = train_set # hack to work with below code 81 | else: 82 | dataset = dataset['train'] 83 | train_set = convert_to_hf_dataset([dataset[ix] for ix in range(200, len(dataset))], cache_dir) 84 | val_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir) 85 | test_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir) 86 | 87 | # Convert to dicts of {input_ids, attention_mask, labels} 88 | train_set = train_set.map( 89 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), 90 | remove_columns=list(dataset.features),) # load_from_cache_file=False) 91 | val_set = val_set.map( 92 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), 93 | remove_columns=list(dataset.features),) # load_from_cache_file=False) 94 | test_set = test_set.map( 95 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=False), 96 | remove_columns=list(dataset.features),) # load_from_cache_file=False) 97 | 98 | # Chunk together train and val sets 99 | if concat_data: 100 | train_set = ConcatDataset(train_set, chunk_size=input_len) 101 | val_set = ConcatDataset(val_set, chunk_size=input_len) 102 | 103 | # Get dataloaders 104 | dataloaders = { 105 | 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs), 106 | 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs), 107 | 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs), 108 | } 109 | # Evaluation metric 110 | try: 111 | metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge 112 | except Exception as e: 113 | print(f'Error loading metric: {e}') 114 | metric = None 115 | 116 | # Finishing touches 117 | for k, v in dataloaders.items(): # Make tokenizer accessible 118 | dataloaders[k].dataset.tokenizer = tokenizer 119 | dataloaders[k].dataset.metric = metric 120 | return dataloaders 121 | 122 | 123 | def template_and_tokenize(sample, tokenizer, include_label: bool = True): 124 | """ 125 | Format dataset context and answers into single-sequence prompts 126 | """ 127 | if sample.get('input', '') == '': 128 | prompt = PROMPT_DICT["prompt_no_input"].format_map(sample) 129 | else: 130 | prompt = PROMPT_DICT["prompt_input"].format_map(sample) 131 | 132 | prompt = tokenizer.encode(prompt, add_special_tokens=True) 133 | if include_label: 134 | answer = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}', 135 | add_special_tokens=False) 136 | target = None 137 | else: 138 | answer = [] 139 | target = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}', 140 | add_special_tokens=False) 141 | input_ids = prompt + answer 142 | attn_mask = [1] * len(input_ids) 143 | 144 | sample = { 145 | "input_ids": input_ids, 146 | "attention_mask" : attn_mask, 147 | "labels": [-100] * len(prompt) + answer if include_label else target, 148 | } 149 | return sample 150 | -------------------------------------------------------------------------------- /src/dataloaders/alpaca_clean_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Alpaca Clean dataset with Llama3-Instruct prompt formatting 3 | """ 4 | 5 | from functools import partial 6 | from os.path import join 7 | 8 | import numpy as np 9 | 
from tqdm import tqdm 10 | 11 | import torch 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | from datasets import load_metric, load_dataset 15 | from transformers import AutoTokenizer 16 | from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, DataCollatorWithPadding 17 | 18 | from .utils import ( 19 | get_lm_loader, get_seq2seq_loader, 20 | convert_to_hf_dataset, 21 | get_tokenizer_from_config, 22 | download_scrolls_metric as download_metric 23 | ) 24 | from .utils.packing import ConcatDataset 25 | 26 | 27 | SYSTEM_PROMPT = "You are a helpful AI assistant who always responds to appropriately complete a user's request." 28 | 29 | 30 | def encode_response(response: str, tokenizer) -> list[int]: 31 | tokens = tokenizer.encode(response.strip(), add_special_tokens=False) 32 | # For Llama 3 Instruct: tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"]) 33 | tokens.append(tokenizer.eos_token_id) 34 | try: # Llama 3 Instruct 35 | tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"]) 36 | except KeyError: 37 | pass 38 | return tokens 39 | 40 | 41 | def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, 42 | preprocess_config: dict, **loader_kwargs: any): 43 | 44 | # Misc. setup 45 | cache_dir = dataset_config['cache_dir'] 46 | input_len = dataset_config['chunk_size'] 47 | concat_data = dataset_config['concat_data'] 48 | load_from_cache_file = False # False if want to retokenize dataset 49 | 50 | # Hard-code system prompt handling 51 | if 'istral' in pretrained_model_config['pretrained_model_name_or_path']: 52 | system_prompt = '' 53 | else: 54 | system_prompt = SYSTEM_PROMPT 55 | 56 | tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] 57 | tokenizer_name = tokenizer_name.split('/')[-1] 58 | save_path = join(cache_dir, f'{name}_{tokenizer_name}') 59 | 60 | # Setup tokenizer 61 | tokenizer = get_tokenizer_from_config(pretrained_model_config) 62 | if tokenizer.pad_token is None: 63 | tokenizer.pad_token = tokenizer.eos_token 64 | print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') 65 | 66 | tokenizer.padding_side = 'left' # for decoder-only generation 67 | 68 | # Get initial data 69 | ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs', 'system_prompt', 'name'] 70 | train_set = load_dataset( 71 | **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}, 72 | split='train[100:-100]', 73 | ) 74 | val_set = load_dataset( # we just use this dataset as a validation set 75 | **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}, 76 | split='train[:100]+train[-100:]', 77 | ) 78 | test_set = load_dataset( 79 | **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}, 80 | split='train[:100]+train[-100:]', 81 | ) 82 | 83 | # Convert to dicts of {input_ids, attention_mask, labels} 84 | train_set = train_set.map(partial(template_and_tokenize, tokenizer=tokenizer, 85 | include_label=True, system_prompt=system_prompt), 86 | remove_columns=list(train_set.features), 87 | load_from_cache_file=load_from_cache_file) 88 | val_set = val_set.map(partial(template_and_tokenize, tokenizer=tokenizer, 89 | include_label=True, system_prompt=system_prompt), 90 | remove_columns=list(val_set.features), 91 | load_from_cache_file=load_from_cache_file) 92 | test_set = test_set.map(partial(template_and_tokenize, tokenizer=tokenizer, 93 | include_label=False, system_prompt=system_prompt), 94 | remove_columns=list(test_set.features), 95 | load_from_cache_file=load_from_cache_file) 96 | 97 
| # Chunk together train and val sets 98 | if concat_data: 99 | train_set = ConcatDataset(train_set, chunk_size=input_len) 100 | val_set = ConcatDataset(val_set, chunk_size=input_len) 101 | 102 | # Get dataloaders 103 | dataloaders = { 104 | 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs), 105 | 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs), 106 | 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs), 107 | } 108 | # Evaluation metric 109 | metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge 110 | 111 | # Finishing touches 112 | for k, v in dataloaders.items(): # Make tokenizer accessible 113 | dataloaders[k].dataset.tokenizer = tokenizer 114 | dataloaders[k].dataset.metric = metric 115 | return dataloaders 116 | 117 | 118 | def template_and_tokenize(sample, tokenizer, include_label: bool = True, 119 | system_prompt: str = None): 120 | if system_prompt is None: 121 | system_prompt = SYSTEM_PROMPT 122 | 123 | prompt = sample['instruction'] 124 | if sample['input'] != '': 125 | prompt += f"\n\n{sample['input']}" 126 | 127 | messages = [ 128 | {"role": "system", "content": system_prompt}, 129 | ] if system_prompt != '' else [] 130 | messages.append({"role": "user", "content": prompt}) 131 | prompt_ids = tokenizer.apply_chat_template( 132 | messages, tokenize=True, add_generation_prompt=True, 133 | ) 134 | if include_label: 135 | answer = encode_response(sample['output'], tokenizer) 136 | else: 137 | answer = [] 138 | target = encode_response(sample['output'], tokenizer) 139 | 140 | input_ids = prompt_ids + answer 141 | attn_mask = [1] * len(input_ids) 142 | sample = { 143 | "input_ids": input_ids, 144 | "attention_mask" : attn_mask, 145 | "labels": [-100] * len(prompt_ids) + answer if include_label else target, 146 | } 147 | return sample 148 | 149 | -------------------------------------------------------------------------------- /src/dataloaders/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions dataset setup and loading 3 | """ 4 | from .setup import * 5 | -------------------------------------------------------------------------------- /src/dataloaders/utils/llama3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data utils for Llama3 3 | """ 4 | 5 | def encode_header(message: str, tokenizer) -> list[int]: 6 | tokens = [] 7 | tokens.append(tokenizer.get_added_vocab()["<|start_header_id|>"]) 8 | tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False)) 9 | tokens.append(tokenizer.get_added_vocab()["<|end_header_id|>"]) 10 | tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False)) 11 | return tokens 12 | 13 | 14 | def encode_message(message: str, tokenizer, include_header: bool = True) -> list[int]: 15 | tokens = encode_header(message, tokenizer) if include_header else [] 16 | tokens.extend( 17 | tokenizer.encode(message["content"].strip(), add_special_tokens=False) 18 | ) 19 | tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"]) 20 | return tokens 21 | 22 | 23 | def template_and_tokenize(sample, tokenizer, include_label: bool = True, 24 | system_prompt: str = None): 25 | if system_prompt is not None: 26 | dialog = [{'role': 'system', 'content': system_prompt}] 27 | else: 28 | dialog = [] 29 | 30 | chat = [] 31 | instruction = sample['instruction'] 32 | if sample['input'] != '': 33 | instruction += 
f"\n\n{sample['input']}" 34 | dialog.extend([ 35 | {'role': 'user', 'content': instruction}, 36 | {'role': 'assistant', 'content': sample['output']}, 37 | ]) 38 | 39 | prompt = [] 40 | prompt.append(tokenizer.get_added_vocab()["<|begin_of_text|>"]) 41 | for message in dialog[:-1]: 42 | prompt.extend(encode_message(message, tokenizer)) 43 | 44 | if include_label: 45 | answer = encode_message(dialog[-1], tokenizer) 46 | answer.append(tokenizer.get_added_vocab()["<|end_of_text|>"]) 47 | else: 48 | answer = [] 49 | target = encode_message(dialog[-1], tokenizer, include_header=False) 50 | target.append(tokenizer.get_added_vocab()["<|end_of_text|>"]) 51 | # Add the start of an assistant message for the model to complete. 52 | prompt.extend(encode_header({"role": "assistant", "content": ""}, tokenizer)) 53 | 54 | input_ids = prompt + answer 55 | attn_mask = [1] * len(input_ids) 56 | 57 | sample = { 58 | "input_ids": input_ids, 59 | "attention_mask" : attn_mask, 60 | "labels": [-100] * len(prompt) + answer if include_label else target, 61 | } 62 | return sample -------------------------------------------------------------------------------- /src/dataloaders/utils/packing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | """ 4 | Copied from https://github.com/meta-llama/llama-recipes/blob/9b3dabcaac78980eae40005bbc8b1a8276c82af3/src/llama_recipes/data/concatenator.py#L1 5 | """ 6 | import random 7 | from itertools import chain 8 | from tqdm import tqdm 9 | 10 | 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class Concatenator(object): 15 | def __init__(self, chunk_size=2048): 16 | self.chunk_size=chunk_size 17 | self.residual = {"input_ids": [], "attention_mask": []} 18 | 19 | def __call__(self, batch): 20 | concatenated_samples = { 21 | k: v + list(chain(*batch[k])) for k, v in self.residual.items() 22 | } 23 | 24 | total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]]) 25 | 26 | if total_length >= self.chunk_size: 27 | chunk_num = total_length // self.chunk_size 28 | result = { 29 | k: [ 30 | v[i : i + self.chunk_size] 31 | for i in range(0, chunk_num * self.chunk_size, self.chunk_size) 32 | ] 33 | for k, v in concatenated_samples.items() 34 | } 35 | self.residual = { 36 | k: v[(chunk_num * self.chunk_size) :] 37 | for k, v in concatenated_samples.items() 38 | } 39 | else: 40 | result = concatenated_samples 41 | self.residual = {k: [] for k in concatenated_samples.keys()} 42 | 43 | result["labels"] = result["input_ids"].copy() 44 | 45 | return result 46 | 47 | class ConcatDataset(Dataset): 48 | """ 49 | Concatenates or packs samples of a dataset into chunks of size `chunk_size` 50 | """ 51 | def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None: 52 | self.dataset = dataset 53 | self.chunk_size = chunk_size 54 | self.samples = [] 55 | buffer = { 56 | "input_ids": [], 57 | "attention_mask": [], 58 | "labels": [], 59 | } 60 | random.seed(seed) 61 | for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): 62 | buffer = {k: v + sample[k] for k,v in buffer.items()} 63 | 64 | while len(next(iter(buffer.values()))) > self.chunk_size: 65 | self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) 66 | buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} 67 | # Slow hack, but filter out any 
samples without valid labels (all -100) 68 | self.filtered_samples = [] 69 | for s in self.samples: 70 | if sum(s['labels']) != chunk_size * -100: 71 | self.filtered_samples.append(s) 72 | if len(self.filtered_samples) < len(self.samples): 73 | print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}') 74 | print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples') 75 | 76 | def __getitem__(self, idx): 77 | return self.filtered_samples[idx] 78 | 79 | def __len__(self): 80 | return len(self.filtered_samples) 81 | -------------------------------------------------------------------------------- /src/dataloaders/utils/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions dataset setup and loading 3 | """ 4 | import os 5 | from os.path import join 6 | import shutil 7 | import numpy as np 8 | 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | from datasets import Dataset as HFDataset 12 | from huggingface_hub import hf_hub_download 13 | from transformers import AutoTokenizer, LlamaTokenizer 14 | from transformers import DataCollatorForSeq2Seq 15 | # from transformers import DefaultDataCollator, DataCollatorWithPadding 16 | 17 | 18 | def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer, 19 | split: str, **loader_kwargs: any): 20 | """ 21 | Get dataloader for seq2seq tasks (evaluation) 22 | """ 23 | tokenizer.padding_side = 'right' 24 | collate_fn = DataCollatorForSeq2Seq( 25 | tokenizer, label_pad_token_id=-100, return_tensors='pt') 26 | return DataLoader( 27 | dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs) 28 | 29 | 30 | def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer, 31 | split: str, max_length: int = None, **loader_kwargs: any): 32 | """ 33 | Get dataloader for language modeling (training) 34 | -> Currently this ends up being the same as get_seq2seq_loader 35 | """ 36 | # collate_fn = DefaultDataCollator(return_tensors='pt') 37 | # collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, 38 | # max_length=max_length, return_tensors='pt') 39 | collate_fn = DataCollatorForSeq2Seq( 40 | tokenizer, label_pad_token_id=-100, return_tensors='pt') 41 | return DataLoader( 42 | dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs) 43 | 44 | 45 | def convert_to_hf_dataset(dataset, cache_dir: str): 46 | """ 47 | Convert iterable dataset to HuggingFace HFDataset object 48 | """ 49 | def gen(): 50 | for _, sample in enumerate(dataset): 51 | yield sample # dataset[idx] 52 | return HFDataset.from_generator(gen, cache_dir=cache_dir) 53 | 54 | 55 | def get_tokenizer_from_config(model_config): 56 | """ 57 | Get pretrained tokenizer based on (pretrained) model config 58 | """ 59 | # Get tokenizer 60 | if 'llama' in model_config['pretrained_model_name_or_path']: 61 | try: # if we store locally 62 | model_path = join(model_config['cache_dir'], 63 | model_config['pretrained_model_name_or_path']) 64 | tokenizer = LlamaTokenizer.from_pretrained(model_path) 65 | except Exception as e: 66 | try: 67 | tokenizer = AutoTokenizer.from_pretrained(**model_config) 68 | print("-> Bad LlamaTokenizer.from_pretrained(model_path)", e) 69 | print("-> But resolved with: AutoTokenizer.from_pretrained(**model_config)") 70 | except Exception as e2: 71 | print("-> Error with AutoTokenizer.from_pretrained(**model_config)", e2) 72 | # tokenizer = LlamaTokenizer.from_pretrained(**model_config) # v4.43 errors 
with `*** TypeError: not a string` 73 | elif 'Mistral-7B-Instruct-v0.3' in model_config['pretrained_model_name_or_path']: 74 | tokenizer = LlamaTokenizer.from_pretrained(**model_config) # hack where AutoTokenizer doesn't recognize 75 | elif 'Mistral-7B' in model_config['pretrained_model_name_or_path']: 76 | tokenizer = AutoTokenizer.from_pretrained(**model_config) 77 | else: 78 | tokenizer = AutoTokenizer.from_pretrained(**model_config) 79 | return tokenizer 80 | 81 | 82 | def add_special_tokens_to_dataset(dataset, tokenizer): 83 | """ 84 | Add special tokens as attributes to a dataset object 85 | """ 86 | token_map = {k: v for k, v in tokenizer.special_tokens_map.items()} 87 | special_ids = tokenizer.all_special_ids 88 | for idx, k in enumerate(tokenizer.special_tokens_map.keys()): 89 | token_map[f'{k}_id'] = special_ids[idx] 90 | for k, v in token_map.items(): 91 | setattr(dataset, k, v) 92 | return dataset 93 | 94 | 95 | def train_test_split(samples: any, train_size: int, test_size: int, seed: int): 96 | """ 97 | Split samples into train and test sets 98 | """ 99 | try: 100 | assert len(samples) == train_size + test_size 101 | except Exception as e: 102 | print(len(samples), train_size + test_size) 103 | raise e 104 | arange = np.arange(len(samples)) 105 | np.random.seed(seed) 106 | test_idx = np.random.choice(arange, size=test_size, replace=False) 107 | train_idx = np.setdiff1d(arange, test_idx) 108 | return samples[train_idx], samples[test_idx] 109 | 110 | 111 | def download_scrolls_metric(): 112 | """ 113 | Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset 114 | """ 115 | scrolls_metric_path = hf_hub_download( 116 | repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset" 117 | ) 118 | updated_scrolls_metric_path = ( 119 | os.path.dirname(scrolls_metric_path) + 120 | os.path.basename(scrolls_metric_path).replace(".", "_") + ".py" 121 | ) 122 | shutil.copy(scrolls_metric_path, updated_scrolls_metric_path) 123 | return updated_scrolls_metric_path 124 | -------------------------------------------------------------------------------- /src/finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finetuning functions to do post-distillation 3 | """ 4 | from os.path import join 5 | from omegaconf import OmegaConf 6 | 7 | import torch 8 | from torch.nn import Module 9 | 10 | from src.utils.setup import update_config_from_args 11 | from src.dataloaders import load_data 12 | from src.trainer import get_trainer, get_optimizer, get_scheduler 13 | 14 | 15 | def prepare_finetune_configs(args, model_config: dict, 16 | finetune_config_name: str = None, 17 | finetune_checkpoint_name: str = None, 18 | config_dir='./configs/experiment'): 19 | """ 20 | Prepare finetuning configs 21 | """ 22 | # Load finetuning config 23 | finetune_config = (finetune_config_name if finetune_config_name is not None else 24 | finetune_checkpoint_name.split('-f=')[-1].split('-')[0]) 25 | finetune_config_path = join(config_dir, f'{finetune_config}.yaml') 26 | finetune_config = OmegaConf.load(finetune_config_path) 27 | finetune_config = update_config_from_args(finetune_config, args, 28 | ignore_args=['lr', 'weight_decay']) 29 | # Update data tokenizer to match model 30 | if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None: 31 | for k in ['pretrained_model_name_or_path', 'cache_dir']: 32 | finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k] 33 | # Set finetuning args 34 | for 
arg, argv in finetune_config.trainer.items(): 35 | if arg != 'name': 36 | setattr(args, arg, argv) 37 | for _config in ['dataloader', 'optimizer', 'lr_scheduler']: 38 | setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config))) 39 | return finetune_config, args 40 | 41 | 42 | def get_finetuner(model: Module, finetune_config: dict, device: torch.device, 43 | args: any, wandb: any, initial_eval: bool = False): 44 | """ 45 | Initialize finetuning trainer 46 | """ 47 | model.to(device) # if using a fused optimizer 48 | model.train() 49 | 50 | # Initialize optimizer and scheduler 51 | optimizer = get_optimizer(model=model, **finetune_config.optimizer) 52 | scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler) 53 | 54 | dataloaders = load_data(finetune_config.dataset, finetune_config.dataloader) 55 | train_loader = dataloaders[finetune_config.trainer.train_split] 56 | eval_loader = dataloaders[finetune_config.trainer.val_split] 57 | 58 | OurTrainer = get_trainer(finetune_config.trainer.name) 59 | trainer = OurTrainer(model=model, 60 | args=args, 61 | train_loader=train_loader, 62 | eval_loader=eval_loader, 63 | optimizer_and_scheduler=(optimizer, scheduler), 64 | device=device, 65 | wandb=wandb, 66 | checkpoint_suffix='_ft', 67 | **finetune_config.trainer) 68 | return trainer 69 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/convert_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Attention conversion helpers 3 | """ 4 | from functools import partial 5 | from tqdm import tqdm 6 | import torch.nn as nn 7 | 8 | 9 | def convert_attention(model: nn.Module, 10 | attention_config: dict, 11 | train_attention: bool = False, 12 | remove_base_attn: bool = True,): 13 | """ 14 | Call to convert all attention layers 15 | """ 16 | softmax_attns = [] 17 | if 'softmax_attentions' in attention_config: 18 | softmax_attns = attention_config['softmax_attentions'] 19 | if attention_config.attention_type != 'softmax': 20 | layers = traverse_layers(model) 21 | for layer_idx, layer in enumerate(tqdm(layers, desc='Converting attentions...')): 22 | if layer_idx not in softmax_attns: 23 | layer.self_attn = convert_llama_attention( 24 | layer, attention_config, layers, train_attention, remove_base_attn, 25 | ) 26 | layer.self_attn.converted = True 27 | else: # Freeze any preserved softmax attention layers 28 | for p in layer.parameters(): 29 | p.requires_grad = False 30 | else: 31 | print(f'-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions') 32 | return model 33 | 34 | 35 | def toggle_attention(llama_model: nn.Module, train: bool = False): 36 | """ 37 | Make attentions trainable if train is True 38 | -> Set train_attention = False when finetuning 39 | """ 40 | for layer in traverse_layers(llama_model): 41 | layer.self_attn.train_attention = train 42 | return llama_model 43 | 44 | 45 | def remove_base_attention(llama_model: nn.Module): 46 | """ 47 | Remove teacher attention after distillation (if we keep it) 48 | """ 49 | for layer in traverse_layers(llama_model): 50 | if getattr(layer.self_attn, 'base_attn', False): 51 | del 
layer.self_attn.base_attn 52 | return llama_model 53 | 54 | 55 | def traverse_layers(model: nn.Module, verbose: bool = False): 56 | """ 57 | Return list of model layers 58 | """ 59 | try: 60 | layers = model.model.layers 61 | if verbose: 62 | print('-> Loading from model.model.layers') 63 | except AttributeError as e: # if base model 64 | if verbose: 65 | print(e) 66 | try: 67 | layers = model.layers 68 | if verbose: 69 | print('-> Loading from model.layers') 70 | except AttributeError as e1: # If we make a PEFT model 71 | if verbose: 72 | print(e1) 73 | layers = model.base_model.model.model.layers 74 | if verbose: 75 | print('-> Loading from model.base_model.model.model.layers') 76 | return layers 77 | 78 | 79 | def convert_llama_attention(layer: nn.Module, 80 | attention_config: dict, 81 | layers: list[nn.Module], # list of layers 82 | train_attention: bool = False, 83 | remove_base_attn: bool = True): 84 | """ 85 | Converts a single layer's attention layer as specified by attention_config 86 | """ 87 | return get_attention(**attention_config)( 88 | base_attn=layer.self_attn, 89 | layer_idx=layer.self_attn.layer_idx, # Transformers v4.36 90 | max_layer_idx=len(layers) - 1, 91 | train_attention=train_attention, 92 | remove_base_attn=remove_base_attn, 93 | ) 94 | 95 | 96 | def get_attention(attention_type: str, **kwargs: any): 97 | """ 98 | Get the linear attention class; either purely linear or linear with sliding window 99 | -> 'linear' == 'lolcats_llama' 100 | -> 'linear and sliding_window' == 'lolcats_llama_window_*' 101 | """ 102 | kwargs['attention_type'] = attention_type 103 | 104 | if attention_type == 'lolcats_llama': 105 | from .linear_attention import LolcatsLinearAttention 106 | return partial(LolcatsLinearAttention, **kwargs) 107 | 108 | elif attention_type == 'lolcats_llama_window_tk': 109 | from .linear_attention import LolcatsTKWindowAttention 110 | return partial(LolcatsTKWindowAttention, **kwargs) 111 | 112 | elif attention_type == 'lolcats_llama_window_sw': 113 | from .linear_attention import LolcatsSlidingWindowAttention 114 | return partial(LolcatsSlidingWindowAttention, **kwargs) 115 | 116 | elif attention_type == 'lolcats_llama_window_sw_linear': 117 | from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention 118 | return partial(LolcatsLinearSlidingWindowAttention, **kwargs) 119 | 120 | ## Experimental chunked linear attentions below 121 | elif attention_type == 'lolcats_long_llama_window_tk': 122 | from .linear_attention import LolcatsTKWindowLongAttention 123 | return partial(LolcatsTKWindowLongAttention, **kwargs) 124 | 125 | elif attention_type == 'lolcats_long_llama_window_sw': 126 | from .linear_attention import LolcatsSlidingWindowLongAttention 127 | return partial(LolcatsSlidingWindowLongAttention, **kwargs) 128 | 129 | ## TK generation build (requires Thunderkittens) 130 | elif attention_type == 'lolcats_llama_window_tk_gen': 131 | from .linear_attention import LolcatsWindowAttentionTKGen 132 | return partial(LolcatsWindowAttentionTKGen, **kwargs) 133 | 134 | else: 135 | print(f'-> attention_type {attention_type} not handled... 
returning None') 136 | return None 137 | 138 | 139 | def get_attention_cache(attention_type: str, past_key_values: any = None): 140 | """ 141 | Determine how we store past keys and values when generating 142 | """ 143 | if attention_type is None: 144 | return past_key_values 145 | 146 | # print(f'Returning attention cache based on attention_type == {attention_type}') 147 | elif 'lolcats_llama_window_tk_gen' in attention_type: 148 | from .linear_attention import LinearAttentionTKWindowGenerationCache 149 | return LinearAttentionTKWindowGenerationCache() 150 | 151 | elif 'llama_window_tk' in attention_type: 152 | from .linear_attention import LinearAttentionTKWindowCache 153 | return LinearAttentionTKWindowCache() 154 | 155 | elif 'llama_window_sw' in attention_type: 156 | from .linear_attention import LinearAttentionSlidingWindowCache 157 | return LinearAttentionSlidingWindowCache() 158 | 159 | elif 'llama_window_sw_linear' in attention_type: 160 | from .linear_attention import LinearAttentionSlidingWindowCache 161 | return LinearAttentionSlidingWindowCache() 162 | 163 | ## TK generation build (requires Thunderkittens) 164 | elif attention_type == 'lolcats_llama_window_tk_gen': 165 | from .linear_attention.linear_window_attention_tk_gen import LinearAttentionTKWindowGenerationCache 166 | return LinearAttentionTKWindowGenerationCache() 167 | 168 | elif 'softmax' in attention_type: 169 | return past_key_values 170 | 171 | else: 172 | from .linear_attention import LinearAttentionState 173 | return LinearAttentionState() 174 | -------------------------------------------------------------------------------- /src/model/linear_attention/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Linear and linear attention + sliding window classes 3 | """ 4 | from .linear_attention import ( 5 | LolcatsLinearAttention, LinearAttentionState 6 | ) 7 | from .linear_window_attention_tk import ( 8 | LolcatsTKWindowAttention, LinearAttentionTKWindowCache 9 | ) 10 | from .linear_window_attention_sw import ( 11 | LolcatsSlidingWindowAttention, LinearAttentionSlidingWindowCache 12 | ) 13 | # Experimental chunk linear attentions 14 | from .linear_window_attention_tk_long import ( 15 | LolcatsTKWindowLongAttention, 16 | ) 17 | from .linear_window_attention_sw_long import ( 18 | LolcatsSlidingWindowLongAttention, 19 | ) 20 | from .linear_window_attention_tk_gen import ( 21 | LolcatsWindowAttentionTKGen, 22 | LinearAttentionTKWindowGenerationCache 23 | ) 24 | -------------------------------------------------------------------------------- /src/model/linear_attention/linear_window_attention_sw_long.py: -------------------------------------------------------------------------------- 1 | """ 2 | LoLCATs attention combining sliding window and linear attentions 3 | - Using standard sliding window arrangement 4 | - Training over long sequences with fixed memory with recurrent view 5 | - During attention transfer, use Flash Attention to compute softmax attention outputs 6 | 7 | For each layer: 8 | - We first compute (softmax) attention over sliding windows 9 | - We then compute standard linear attention to "fill in" the earlier parts 10 | - We combine to model the entire sequence 11 | """ 12 | from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention 13 | from .linear_window_attention_sw import hybrid_attention_quadratic 14 | 15 | 16 | class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention): 17 | """ 18 | Lolcats attention combining sliding 
window and linear attention 19 | """ 20 | def __init__(self, remove_base_attn=True, **kwargs): 21 | # keep self.base_attn for Flash Attention inference 22 | super().__init__(remove_base_attn=True, **kwargs) 23 | self.quadratic_attention = hybrid_attention_quadratic 24 | -------------------------------------------------------------------------------- /src/model/linear_attention/linear_window_attention_tk_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | LoLCATs + ThunderKittens linear attention + sliding window for generation 3 | """ 4 | from typing import Optional, Tuple, List 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | try: 9 | from thunderkittens import hedgehog as tk_window_hedgehog_attention 10 | print(f"Successfully imported ThunderKittens for TK window attention") 11 | except: 12 | print(f"Failed to import ThunderKittens for TK window attention") 13 | 14 | from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention 15 | from .linear_attention import LinearAttentionState 16 | 17 | class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention): 18 | def __init__(self, *args, window_size: int = 64, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.train_attention = False 21 | self.base_inference = False 22 | self.window_size = 64 # hard-coded support for TK kernel 23 | self.decode_window_size = 64 24 | 25 | b, h, l, d = 1, 32, 8192, 128 26 | self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device='cuda') 27 | self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device='cuda') 28 | self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device='cuda') 29 | 30 | def forward(self, 31 | hidden_states: torch.Tensor, 32 | attention_mask: Optional[torch.Tensor] = None, 33 | position_ids: Optional[torch.LongTensor] = None, 34 | past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None, # “legacy” cache approach 35 | output_attentions: bool = False, 36 | use_cache: bool = False, 37 | **kwargs, 38 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 39 | """ 40 | Forward pass with the option to compute attention weights multiple ways 41 | if self.train_attention is True 42 | -> Consistent with HuggingFace Transformers for easy use with their pretrained models 43 | """ 44 | b, l, _ = hidden_states.size() 45 | assert past_key_value is not None, "past_key_value must be provided for generation" 46 | assert self.train_attention is False, "train_attention is not supported for generation" 47 | assert self.base_inference is False, "base_inference is not supported for generation" 48 | assert use_cache is True, "use_cache must be True for generation" 49 | past_key_value.window_size = self.decode_window_size 50 | q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, 51 | position_ids, past_key_value) 52 | if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill 53 | f_q = self.feature_map_q(q) 54 | _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, 55 | self.feature_map_k) 56 | k_cache, v_cache, kv_state, k_state = _kv 57 | # Sliding window + linear attention decode 58 | window_factors = F.sigmoid(self.window_factors) 59 | linear_factors = 1 - window_factors if self.affine_attention_factors else 1 60 | 61 | # Softmax attention terms 62 | a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) 63 | a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) 64 | a_sm = window_factors * 
torch.exp(a_sm - a_sm_max) 65 | sum_sm = a_sm.sum(dim=-1, keepdim=True) 66 | 67 | # Combine with linear attention terms 68 | y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) 69 | + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())) 70 | sum_ln = linear_factors * torch.einsum( 71 | 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] 72 | self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) 73 | 74 | else: # Process prefill 75 | # Use TK-implemented linear + terrace window attention 76 | b, h, l, d = q.shape 77 | device = q.device 78 | # tk.hedgehog arguments 79 | # y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device) 80 | # kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device) 81 | # k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device) 82 | betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32)) 83 | alphas = (1 - betas if self.affine_attention_factors else 84 | torch.ones(betas.shape, dtype=torch.float32, device=device)) 85 | q_map = self.feature_map_q.mlp.layer 86 | k_map = self.feature_map_k.mlp.layer 87 | # Saves outputs to y_pred, k_state, kv_state, where we fuse: 88 | # 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) 89 | # 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d 90 | # 3. kv_state = torch.einsum(‘bhlf,bhld->bhfd’, 91 | # f_k[:, :, :-self.window_size], 92 | # v[:, :, :-self.window_size]) # b, h, f, d 93 | # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d 94 | 95 | tk_window_hedgehog_attention(q.contiguous(), k.contiguous(), v.contiguous(), 96 | self.y_true, self.k_state, self.kv_state, 97 | q_map, k_map, alphas, betas) 98 | 99 | past_key_value.update_with_kv(self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx) 100 | 101 | # Concatenate heads and apply output projection 102 | y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) 103 | y_true = self.o_proj(y_true) 104 | return y_true, None, past_key_value 105 | 106 | 107 | class LinearAttentionTKWindowGenerationCache(LinearAttentionState): 108 | """ 109 | Class for `past_key_values` 110 | -> Alternative to KV cache; here we only maintain a “KV state” and “K state” 111 | -> Modified from transformers.cache_utils.DynamicCache (v4.36) 112 | """ 113 | def __init__(self, window_size: int = 64) -> None: 114 | super().__init__() 115 | self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 116 | self._seen_tokens_by_layer: List[int] = [] 117 | self.window_size = window_size 118 | 119 | self.decode_kv_states: List[torch.Tensor] = [] 120 | self.decode_k_states: List[torch.Tensor] = [] 121 | self.k_cache: List[torch.Tensor] = [] 122 | self.v_cache: List[torch.Tensor] = [] 123 | 124 | def update_with_kv(self, 125 | kv_state: torch.Tensor, k_state: torch.Tensor, 126 | k: torch.Tensor, v: torch.Tensor, 127 | layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 128 | """ 129 | Update the cache with new KV and K states 130 | """ 131 | if layer_idx == 0: 132 | self._seen_tokens += k.shape[2] 133 | self._seen_tokens_by_layer.append(k.shape[2]) 134 | 135 | # Initialize KV and K states 136 | if len(self.decode_k_states) <= layer_idx: 137 | self.decode_kv_states.append(kv_state) 138 | self.decode_k_states.append(k_state) 139 | else: # Update KV and K states 140 | self.decode_kv_states[layer_idx] = self.decode_kv_states[layer_idx] + kv_state 141 | self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state 142 | 143 | 
self.k_cache.append(k[:, :, -self.window_size:, :]) 144 | self.v_cache.append(v[:, :, -self.window_size:, :]) 145 | 146 | def update_for_decoding(self, k: torch.Tensor, v: torch.Tensor, 147 | layer_idx: int, feature_map_k: callable) -> None: 148 | """ 149 | Update the cache for decoding 150 | """ 151 | k_cache = self.k_cache[layer_idx] 152 | v_cache = self.v_cache[layer_idx] 153 | k_state = feature_map_k(k_cache[:, :, :1, :]) 154 | v_state = v_cache[:, :, :1, :] 155 | kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(k.dtype) 156 | 157 | self.decode_kv_states[layer_idx] += kv_state 158 | self.decode_k_states[layer_idx] += k_state 159 | 160 | self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2) 161 | self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2) 162 | if layer_idx == 0: 163 | self._seen_tokens += k.shape[-2] 164 | self._seen_tokens_by_layer[layer_idx] += k.shape[-2] 165 | return (self.k_cache[layer_idx], self.v_cache[layer_idx], 166 | self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) -------------------------------------------------------------------------------- /src/model/linear_attention/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shared attention helpers 3 | """ 4 | import torch 5 | 6 | 7 | # Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) 8 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 9 | """ 10 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 11 | The hidden states go from: 12 | (batch, num_key_value_heads, seqlen, head_dim) to 13 | (batch, num_attention_heads, seqlen, head_dim) 14 | """ 15 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 16 | if n_rep == 1: 17 | return hidden_states 18 | hidden_states = hidden_states[:, :, None, :, :].expand( 19 | batch, num_key_value_heads, n_rep, slen, head_dim) 20 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 21 | 22 | 23 | def mask_attention(qk_dot: torch.Tensor, attn_mask: torch.tensor, 24 | mask_value: float = -10000) -> torch.Tensor: 25 | """ 26 | Apply attention mask (e.g., for padding) 27 | """ 28 | if len(attn_mask.shape) == 4: # attn_mask either (b, h, l, d) or (b, l) 29 | return qk_dot.masked_fill(~attn_mask.bool(), mask_value) 30 | else: 31 | return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value) -------------------------------------------------------------------------------- /src/model/load_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to load checkpoints for learned feature maps (attentions) or other parameters 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from omegaconf import OmegaConf 7 | 8 | from src.utils.logging import print_header, _format_arg 9 | from .convert_model import convert_attention 10 | from .peft import create_peft_config 11 | 12 | 13 | def load_and_convert_attns(model: nn.Module, 14 | model_config: dict, 15 | attention_type: str = None, 16 | checkpoint_path: str = None, 17 | print_model: bool = False, 18 | merge_loras: bool = False, 19 | train_converted: bool = True, # Should be false if loading distill checkpoint by default 20 | peft_gradient_checkpointing: bool = None, 21 | train_attention: bool = False, # Should be true if converting attentions for first time, 22 | freeze_weights: bool = True, 23 | rank: int = 0, 24 | 
remove_base_attn: bool = True, 25 | ) -> nn.Module: 26 | """ 27 | Load trained attention kernel parameter weights 28 | """ 29 | if freeze_weights: 30 | for p in model.parameters(): 31 | p.requires_grad = False 32 | 33 | if attention_type is not None: # override default 34 | model_config['attention']['attention_type'] = attention_type 35 | model_config['attention']['rank'] = rank # multi-gpu debugging 36 | 37 | model = convert_attention(model, model_config['attention'], 38 | train_attention, remove_base_attn) 39 | 40 | # Add low-rank adapters 41 | peft_key = 'peft' # inconsistency across configs... why do this to myself 42 | if 'peft_config' in model_config['attention']: 43 | peft_key = 'peft_config' 44 | if peft_key in model_config['attention']: 45 | peft_config = model_config['attention'][peft_key] 46 | model, peft_config = create_peft_config(model, peft_config, 47 | model_config['model']['torch_dtype'], 48 | preserve_requires_grad=train_converted, 49 | use_gradient_checkpointing=peft_gradient_checkpointing) 50 | else: 51 | peft_config = None 52 | 53 | if print_model and rank == 0: # Look at model 54 | print_header('*** Model before checkpoint load ***') 55 | print(model) 56 | 57 | # Load any trained attentions 58 | if checkpoint_path is not None: 59 | print(f'Loading weights from {checkpoint_path}...') 60 | state_dict = torch.load(checkpoint_path)['model_state_dict'] 61 | _keys = model.load_state_dict(state_dict, strict=False) 62 | try: 63 | assert len(_keys.unexpected_keys) == 0 64 | if rank == 0: 65 | print_header('*** All expected keys matched successfully ***') 66 | if print_model: 67 | for k in state_dict.keys(): 68 | print(k) 69 | except Exception as e: 70 | if rank == 0: 71 | print(e) 72 | print_header('*** Error: unexpected keys in checkpoint ***') 73 | print('Unexpected keys:') 74 | for k in _keys.unexpected_keys: 75 | print(k) 76 | if print_model and rank == 0: # Look at model 77 | print_header('*** Model ***') 78 | print(model) 79 | if merge_loras: 80 | model = model.merge_and_unload() 81 | if print_model and rank == 0: 82 | print_header('*** Model (after merging adapters) ***') 83 | print(model) 84 | if print_model and rank == 0: # Look at model 85 | print_header('*** Trainable Parameters ***') 86 | for n, p in model.named_parameters(): 87 | if p.requires_grad: 88 | print(f'├── {n} (dtype = {p.dtype})') 89 | return model, peft_config 90 | 91 | 92 | def load_and_convert_finetune(model: nn.Module, 93 | finetune_config: dict, 94 | checkpoint_path: str = None, 95 | print_model: bool = False, 96 | merge_loras: bool = False, 97 | peft_gradient_checkpointing: bool = None, 98 | rank: int = 0, 99 | **peft_kwargs: any): 100 | """ 101 | Load trained adapter / model weights 102 | """ 103 | # Add low-rank adapters 104 | peft_config = None 105 | if finetune_config.finetune.method == 'lora': 106 | if getattr(finetune_config.finetune, 'kwargs', None) is not None: 107 | model, peft_config = create_peft_config( 108 | model, finetune_config.finetune, 109 | use_gradient_checkpointing=peft_gradient_checkpointing, 110 | **peft_kwargs, 111 | ) 112 | # Keep specified weights trainable 113 | if 'trainable_weights' in finetune_config.finetune: 114 | for name in finetune_config.finetune['trainable_weights']: 115 | for n, p in model.named_parameters(): 116 | if name in n: 117 | p.requires_grad = True 118 | else: 119 | for p in model.parameters(): 120 | p.requires_grad = False 121 | # Keep specified weights trainable 122 | if 'trainable_weights' in finetune_config.finetune: 123 | for name in 
finetune_config.finetune['trainable_weights']: 124 | for n, p in model.named_parameters(): 125 | if name in n: 126 | if 'layers_to_ignore' in finetune_config.finetune: 127 | layer = int(n.split('layers.')[-1].split('.')[0]) 128 | if layer not in finetune_config.finetune['layers_to_ignore']: 129 | p.requires_grad = True 130 | else: 131 | p.requires_grad = True 132 | 133 | 134 | # Load weights 135 | if checkpoint_path: 136 | state_dict = torch.load(checkpoint_path)['model_state_dict'] 137 | _keys = model.load_state_dict(state_dict, strict=False) 138 | try: 139 | assert len(_keys.unexpected_keys) == 0 140 | if rank == 0: 141 | print_header('*** All expected keys matched successfully ***') 142 | except Exception as e: 143 | if rank == 0: 144 | print(e) 145 | print_header('*** Error: unexpected keys in checkpoint ***') 146 | print('Unexpected keys:') 147 | for k in _keys.unexpected_keys: 148 | print(k) 149 | 150 | if print_model and rank == 0: # Look at model 151 | print_header('*** Model ***') 152 | print(model) 153 | 154 | if merge_loras: 155 | try: 156 | model = model.merge_and_unload() 157 | if print_model and rank == 0: 158 | print_header('*** Model (after merging adapters) ***') 159 | print(model) 160 | except Exception as e: 161 | print(e) 162 | 163 | if print_model and rank == 0: # Look at model 164 | print_header('*** Trainable Parameters ***') 165 | count = 0 166 | for n, p in model.named_parameters(): 167 | if p.requires_grad: 168 | print(f'├── {n}.requires_grad: {p.requires_grad}') 169 | count += 1 170 | if count == 0: 171 | print('(none)') 172 | 173 | return model, peft_config 174 | -------------------------------------------------------------------------------- /src/model/modeling_llama_sharded.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
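The sharded wrapper defined below is what `PretrainedLlamaLoader.load()` returns when `model_type` contains 'lolcats_llama_sharded' (see src/model/pretrained.py). A minimal loading sketch, hedged as an illustration only: the config path, the 'model' key, and the variable names are assumptions, not a guaranteed interface.

from omegaconf import OmegaConf
from src.model.pretrained import get_pretrained_loader

# Hypothetical config path for illustration; use one of this repo's model YAMLs.
model_config = OmegaConf.load('./configs/model/my_model_config.yaml')
loader = get_pretrained_loader(**model_config['model'])    # device_map='auto' by default
model = loader.load(model_type='lolcats_llama_sharded')    # -> ShardedLolcatsLlamaForCausalLM
tokenizer = loader.load_tokenizer()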
20 | """ 21 | Thin wrappers and replacement classes for LlamaForCausalLM 22 | - Simple sharding across multiple GPUs; will be slow but good for quality evals 23 | - May need to update for Llama 405B 24 | """ 25 | from typing import Optional, Tuple, List, Union 26 | 27 | import warnings 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | 32 | from transformers.models.llama.modeling_llama import ( 33 | LlamaModel, LlamaForCausalLM, LLAMA_INPUTS_DOCSTRING, 34 | ) 35 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 36 | from transformers.cache_utils import Cache, DynamicCache 37 | from transformers.utils import ( 38 | add_start_docstrings_to_model_forward, logging, 39 | ) 40 | 41 | from .convert_model import get_attention_cache 42 | 43 | logger = logging.get_logger(__name__) 44 | 45 | # Modified from transformers.models.llama.modeling_llama.LlamaModel (v4.43) 46 | class ShardedLolcatsLlamaModel(LlamaModel): 47 | """ 48 | Wrapper for Llama or Mistral-like base model 49 | 50 | Modified from transformers.models.llama.modeling_llama.LlamaModel 51 | -> Only difference is using KV state for past_key_values instead of cache 52 | """ 53 | def __init__(self, *args: any, **kwargs: any): 54 | super().__init__(*args, **kwargs) 55 | self.layerwise_cpu = False 56 | 57 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 58 | def forward( 59 | self, 60 | input_ids: torch.LongTensor = None, 61 | attention_mask: Optional[torch.Tensor] = None, 62 | position_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 64 | inputs_embeds: Optional[torch.FloatTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | return_dict: Optional[bool] = None, 69 | cache_position: Optional[torch.LongTensor] = None, 70 | ) -> Union[Tuple, BaseModelOutputWithPast]: 71 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 72 | output_hidden_states = ( 73 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 74 | ) 75 | use_cache = use_cache if use_cache is not None else self.config.use_cache 76 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 77 | 78 | if (input_ids is None) ^ (inputs_embeds is not None): 79 | raise ValueError( 80 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 81 | ) 82 | 83 | if self.gradient_checkpointing and self.training and use_cache: 84 | logger.warning_once( 85 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
86 | ) 87 | use_cache = False 88 | 89 | batch_size, seq_length = input_ids.shape 90 | 91 | if inputs_embeds is None: 92 | inputs_embeds = self.embed_tokens(input_ids) 93 | 94 | return_legacy_cache = False 95 | if use_cache: 96 | if past_key_values is None or isinstance(past_key_values, DynamicCache): # Determine and setup our KV cache or state 97 | attention_type = getattr(self.layers[0].self_attn, 'attention_type', None) 98 | past_key_values = get_attention_cache(attention_type, past_key_values) 99 | else: 100 | past_key_values.get_usable_length(seq_length) 101 | 102 | if cache_position is None: 103 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 104 | cache_position = torch.arange( 105 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 106 | ) 107 | 108 | if position_ids is None: 109 | position_ids = cache_position.unsqueeze(0) 110 | 111 | causal_mask = self._update_causal_mask( 112 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions 113 | ) 114 | hidden_states = inputs_embeds 115 | 116 | # create position embeddings to be shared across the decoder layers 117 | # - ignored for linearized models 118 | position_embeddings = None 119 | # position_embeddings = self.rotary_emb(hidden_states, position_ids.to(hidden_states.device)) 120 | 121 | # decoder layers 122 | all_hidden_states = () if output_hidden_states else None 123 | all_self_attns = () if output_attentions else None 124 | next_decoder_cache = None 125 | 126 | for decoder_layer in self.layers: 127 | # Move output to right device 128 | device = decoder_layer.self_attn.q_proj.weight.device 129 | hidden_states = hidden_states.to(device) 130 | position_ids = position_ids.to(device) 131 | if attention_mask is not None: 132 | attention_mask = attention_mask.to(device) 133 | 134 | if output_hidden_states: 135 | all_hidden_states += (hidden_states,) 136 | 137 | 138 | if getattr(decoder_layer.self_attn, 'converted', False): 139 | if self.gradient_checkpointing and self.training: 140 | layer_outputs = self._gradient_checkpointing_func( 141 | decoder_layer.__call__, 142 | hidden_states, 143 | causal_mask, 144 | position_ids, 145 | past_key_values, 146 | output_attentions, 147 | use_cache, 148 | cache_position, 149 | position_embeddings, 150 | ) 151 | else: 152 | layer_outputs = decoder_layer( 153 | hidden_states, 154 | attention_mask=causal_mask, 155 | position_ids=position_ids, 156 | past_key_value=past_key_values, 157 | output_attentions=output_attentions, 158 | use_cache=use_cache, 159 | cache_position=cache_position, 160 | position_embeddings=position_embeddings, 161 | ) 162 | else: 163 | with torch.no_grad(): 164 | layer_outputs = decoder_layer( 165 | hidden_states, 166 | attention_mask=attention_mask, 167 | position_ids=position_ids, 168 | past_key_value=past_key_values, 169 | output_attentions=output_attentions, 170 | use_cache=use_cache, 171 | cache_position=cache_position, 172 | position_embeddings=position_embeddings, 173 | ) 174 | 175 | hidden_states = layer_outputs[0] 176 | 177 | if use_cache: 178 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 179 | 180 | if output_attentions: 181 | all_self_attns += (layer_outputs[1],) 182 | 183 | hidden_states = self.norm(hidden_states.to(self.norm.weight.device)) 184 | 185 | # add hidden states from the last decoder layer 186 | if output_hidden_states: 187 | all_hidden_states += (hidden_states,) 188 | 189 | next_cache = next_decoder_cache if use_cache else None 
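# Note: when the layers have been converted, `past_key_values` was created above via
# get_attention_cache(), so `next_cache` here is a fixed-size linear-attention state (or a
# sliding-window cache) rather than a growing softmax KV cache. `return_legacy_cache` is never
# set to True in this sharded forward, so the `to_legacy_cache()` branch below only exists for
# parity with the upstream Hugging Face LlamaModel this class was modified from.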
190 | if return_legacy_cache: 191 | next_cache = next_cache.to_legacy_cache() 192 | 193 | if not return_dict: 194 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 195 | return BaseModelOutputWithPast( 196 | last_hidden_state=hidden_states, 197 | past_key_values=next_cache, 198 | hidden_states=all_hidden_states, 199 | attentions=all_self_attns, 200 | ) 201 | 202 | 203 | class ShardedLolcatsLlamaForCausalLM(LlamaForCausalLM): 204 | """ 205 | Wrapper for Llama-like autoregressive language model 206 | """ 207 | def __init__(self, config): 208 | # Adapt config to LlamaConfig 209 | if getattr(config, 'attention_bias', None) is None: 210 | config.attention_bias = False 211 | if getattr(config, 'rope_scaling', None) is None: 212 | config.rope_scaling = None 213 | if getattr(config, 'pretraining_tp', None) is None: 214 | config.pretraining_tp = 1 215 | super().__init__(config) 216 | self.model = ShardedLolcatsLlamaModel(config) 217 | self.vocab_size = config.vocab_size 218 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 219 | 220 | # Initialize weights and apply final processing 221 | self.post_init() 222 | 223 | def forward(self, *args: any, labels: Optional[torch.LongTensor] = None, **kwargs: any): 224 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 225 | outputs = self.model(*args, **kwargs) 226 | hidden_states = outputs[0] 227 | if getattr(self.model.layers[0].self_attn, 'train_attention', False): 228 | logits = None 229 | else: # regular training 230 | if self.config.pretraining_tp > 1: 231 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 232 | logits = [F.linear(hidden_states, lm_head_slices[i]) 233 | for i in range(self.config.pretraining_tp)] 234 | logits = torch.cat(logits, dim=-1) 235 | else: 236 | logits = self.lm_head(hidden_states) 237 | logits = logits.float() 238 | 239 | return CausalLMOutputWithPast( 240 | logits=logits, 241 | past_key_values=outputs.past_key_values, 242 | hidden_states=outputs.hidden_states, 243 | attentions=outputs.attentions, 244 | ) 245 | -------------------------------------------------------------------------------- /src/model/modeling_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
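The LooooolcatsMistralForCausalLM class below handles long inputs by splitting them into chunks of `state_chunk_len` tokens and carrying a recurrent linear-attention state across chunks, so memory stays roughly fixed as sequence length grows. A condensed sketch of that loop follows; it paraphrases the forward() further down and is not a standalone script (`model`, `input_ids`, and `attention_type` are assumed to be set up as in the surrounding code).

import torch
from src.model.convert_model import get_attention_cache

past_key_values = get_attention_cache(attention_type)      # recurrent state, not a growing KV cache
all_logits = []
for chunk in torch.split(input_ids, model.state_chunk_len, dim=-1):
    out = model.chunk_forward(chunk, past_key_values=past_key_values,
                              use_cache=True, return_dict=True)
    past_key_values = out.past_key_values                   # carry the state into the next chunk
    all_logits.append(out.logits)
logits = torch.cat(all_logits, dim=-2)                      # (batch, seq_len, vocab_size)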
20 | """ 21 | Thin wrappers and replacement classes for MistralForCausalLM 22 | """ 23 | from typing import Optional, Tuple, List, Union 24 | 25 | import warnings 26 | import torch 27 | import torch.nn as nn 28 | from transformers import MistralModel, MistralForCausalLM 29 | from transformers.modeling_outputs import CausalLMOutputWithPast 30 | 31 | from .modeling_llama import LolcatsLlamaModel 32 | from .convert_model import get_attention_cache 33 | 34 | 35 | # Modified from transformers.models.llama.modeling_llama.LlamaModel 36 | class LolcatsMistralModel(LolcatsLlamaModel, MistralModel): 37 | """ 38 | Wrapper for Mistral-like autoregressive language model 39 | """ 40 | def forward(self, *args, **kwargs): 41 | return super().forward(*args, **kwargs) 42 | 43 | 44 | class LolcatsMistralForCausalLM(MistralForCausalLM): 45 | """ 46 | Wrapper for Llama or Mistral-like autoregressive language model 47 | """ 48 | def __init__(self, config): 49 | # Adapt config to LlamaConfig 50 | if getattr(config, 'attention_bias', None) is None: 51 | config.attention_bias = False 52 | if getattr(config, 'rope_scaling', None) is None: 53 | config.rope_scaling = None 54 | if getattr(config, 'pretraining_tp', None) is None: 55 | config.pretraining_tp = 1 56 | if getattr(config, 'pretraining_tp', None) is None: 57 | config.pretraining_tp = 1 58 | if getattr(config, 'mlp_bias', None) is None: 59 | config.mlp_bias = False 60 | super().__init__(config) 61 | self.model = LolcatsMistralModel(config) 62 | self.vocab_size = config.vocab_size 63 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 64 | 65 | # Initialize weights and apply final processing 66 | self.post_init() 67 | 68 | 69 | class LooooolcatsMistralForCausalLM(LolcatsMistralForCausalLM): 70 | """ 71 | Wrapper for Llama or Mistral-like autoregressive language model 72 | -> Experimental / WIP; but goal is to combine chunked linear attention during training 73 | to process long contexts with minimally-growing memory usage 74 | """ 75 | def chunk_forward(self, *args: any, **kwargs: any): 76 | """Call this when training / processing one chunk""" 77 | return super().forward(*args, **kwargs) 78 | 79 | def forward( 80 | self, 81 | input_ids: torch.LongTensor = None, 82 | attention_mask: Optional[torch.Tensor] = None, 83 | position_ids: Optional[torch.LongTensor] = None, 84 | past_key_values: Optional[List[torch.FloatTensor]] = None, 85 | inputs_embeds: Optional[torch.FloatTensor] = None, 86 | labels: Optional[torch.LongTensor] = None, 87 | use_cache: Optional[bool] = None, 88 | output_attentions: Optional[bool] = None, 89 | output_hidden_states: Optional[bool] = None, 90 | return_dict: Optional[bool] = None, 91 | cache_position: Optional[torch.LongTensor] = None, # Ignored for now, new Transformers >4.36 92 | ) -> Union[Tuple, CausalLMOutputWithPast]: 93 | """ 94 | Forward pass where we chunk inputs 95 | """ 96 | self.generating = False 97 | if use_cache is not True: 98 | use_cache = True 99 | 100 | if attention_mask is not None and use_cache: 101 | warnings.warn( 102 | f"Sorry padding currently not supported. Setting attention_mask to None (will still be causal)." 
103 | ) 104 | attention_mask = None 105 | 106 | if past_key_values is None: 107 | # Determine and setup our KV cache or state 108 | attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None) 109 | past_key_values = get_attention_cache(attention_type) 110 | # past_key_values = LinearAttentionState() 111 | 112 | if input_ids.shape[-1] == 1 and not self.training: # Heuristic to detect generating 113 | return super().forward(input_ids, attention_mask, position_ids, 114 | past_key_values, inputs_embeds, labels, 115 | use_cache, output_attentions, output_hidden_states, 116 | return_dict) 117 | else: 118 | if self.generating: # Heuristic to detect new sample 119 | self.generating = False 120 | # Determine and setup our KV cache or state 121 | attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None) 122 | past_key_values = get_attention_cache(attention_type) 123 | print(f'-> attention_type:', attention_type) 124 | 125 | # Make it so we keep track of gradients in kv_state computation 126 | for idx in range(len(self.model.layers)): 127 | self.model.layers[idx].self_attn.state_grad_enabled = self.training 128 | 129 | # Split inputs into chunks, and do linear attention over each (passing the states) 130 | input_ids = torch.split(input_ids, self.state_chunk_len, dim=-1) 131 | if position_ids is not None: 132 | position_ids = torch.split(position_ids, self.state_chunk_len, dim=-1) 133 | 134 | all_logits = [] # save these 135 | for _idx, _input_ids in enumerate(input_ids): 136 | outputs = super().forward(_input_ids, None, 137 | position_ids[_idx] if position_ids is not None else None, 138 | past_key_values, inputs_embeds, 139 | labels=None, 140 | use_cache=True, 141 | output_attentions=False, 142 | output_hidden_states=False, 143 | return_dict=True,) 144 | past_key_values = outputs.past_key_values 145 | all_logits.append(outputs.logits) 146 | 147 | # Comment in / adjust to do gradient accumulation over chunks 148 | # if self.training: 149 | # loss = outputs.loss 150 | # loss.backward() # accumulate gradients over chunks 151 | # else: 152 | # del outputs.loss 153 | 154 | if _idx == len(input_ids) - 1: 155 | self.generating = True # time to generate; if no generation will reset 156 | 157 | return CausalLMOutputWithPast( 158 | # loss=loss, 159 | logits=torch.cat(all_logits, dim=-2), # b, l, d 160 | past_key_values=past_key_values, 161 | ) -------------------------------------------------------------------------------- /src/model/peft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for parameter-efficient finetuning via low-rank adapters (LoRA) 3 | -> Mainly follow PEFT / llama recipes 4 | 5 | Right now quantization not super tested 6 | """ 7 | import torch 8 | from torch.nn import Module 9 | 10 | 11 | # Modified from https://github.com/facebookresearch/llama-recipes/blob/main/examples/quickstart.ipynb 12 | def create_peft_config(model: Module, 13 | peft_config: dict, 14 | target_dtype: str = 'bfloat16', 15 | preserve_requires_grad: bool = False, 16 | use_gradient_checkpointing: bool = None, 17 | add_self_attn_prefix: bool = True): 18 | """ 19 | Create a parameter-efficient finetuning model (e.g., attaching LoRAs) 20 | -> Assumes that all non-trainable weights have been frozen already. 21 | If not, freeze them before calling this function. 
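Example (illustrative sketch; the rank/alpha values and module names below are placeholders, not repo defaults):
    peft_settings = {'method': 'lora',
                     'kwargs': {'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.0,
                                'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj']}}
    model, lora_config = create_peft_config(model, peft_settings, target_dtype='bfloat16')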
22 | """ 23 | if peft_config['method'] == 'lora': 24 | from peft import ( 25 | get_peft_model, 26 | LoraConfig, 27 | TaskType, 28 | prepare_model_for_kbit_training, 29 | ) 30 | try: 31 | target_modules = [] # hack to only do self_attn terms 32 | for module_name in peft_config['kwargs']['target_modules']: 33 | if ('_proj' in module_name and 'self_attn' not in module_name 34 | and add_self_attn_prefix): 35 | target_modules.append(f'self_attn.{module_name}') 36 | elif '_proj' in module_name: 37 | target_modules.append(module_name) 38 | peft_config['kwargs']['target_modules'] = target_modules 39 | except Exception as e: 40 | print(e) 41 | target_modules = [] 42 | 43 | if 'layers_to_ignore' in peft_config: 44 | peft_config['kwargs']['layers_to_transform'] = [ 45 | i for i in range(len(model.model.layers)) 46 | if i not in peft_config['layers_to_ignore'] 47 | ] 48 | 49 | peft_config = LoraConfig( 50 | task_type=TaskType.CAUSAL_LM, 51 | inference_mode=False, 52 | **peft_config['kwargs'], 53 | ) 54 | # Save parameters that did not have frozen weights before to unfreeze later 55 | trainable_weights = [ 56 | n for n, p in model.named_parameters() if p.requires_grad 57 | ] 58 | # Prepare int-8 or int-4 model for training 59 | loaded_in_kbit = (getattr(model, "is_loaded_in_8bit", False) or 60 | getattr(model, "is_loaded_in_4bit", False)) 61 | if loaded_in_kbit: # From https://huggingface.co/docs/peft/en/package_reference/peft_model: 62 | # This method wraps the entire protocol for preparing a model before running a training. 63 | # 1- Cast the layernorm in fp32 64 | # 2- making output embedding layer require grads 65 | # 3- Add the upcasting of the lm head to fp32 66 | model.enable_input_require_grads() 67 | ugc = (use_gradient_checkpointing 68 | if use_gradient_checkpointing is not None else True) 69 | print('-> use_gradient_checkpointing:', ugc) 70 | # model.gradient_checkpointing_enable() 71 | model = prepare_model_for_kbit_training( 72 | model, use_gradient_checkpointing=ugc, 73 | gradient_checkpointing_kwargs={'use_reentrant': False}, 74 | ) 75 | 76 | model = get_peft_model(model, peft_config) 77 | model.print_trainable_parameters() 78 | 79 | for n, p in model.named_parameters(): 80 | # Unfreeze weights frozen by get_peft_model() 81 | if preserve_requires_grad: 82 | if n[len('base_model.model.'):] in trainable_weights: 83 | p.requires_grad = True 84 | 85 | # prepare_model_for_kbit_training will cast all non INT8 parameters to fp32 86 | # -> https://github.com/huggingface/peft/blob/7e84dec20b3106bdd0a90ba8e80187f0aec835b7/src/peft/utils/other.py#L103 87 | # So we'll cast these back to their prior dtype 88 | if p.requires_grad and loaded_in_kbit: 89 | p.data = p.data.to(getattr(torch, target_dtype)) 90 | 91 | if not loaded_in_kbit: 92 | model.to(dtype=getattr(torch, target_dtype)) 93 | 94 | return model, peft_config 95 | else: 96 | raise NotImplementedError(f"Sorry PEFT method {peft_config['method']} not implemented yet.") 97 | -------------------------------------------------------------------------------- /src/model/pretrained.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes for loading pretrained models 3 | """ 4 | from os.path import join 5 | from omegaconf import OmegaConf 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import transformers 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer 12 | # from transformers import BitsAndBytesConfig 13 | from peft import prepare_model_for_kbit_training 14 | 15 
| 16 | def get_pretrained_loader(pretrained_model_name_or_path: str, 17 | huggingface_token: str = None, 18 | **model_kwargs: any): 19 | """ 20 | Return the appropriate loader for the pretrained model 21 | """ 22 | 23 | if 'lama' in pretrained_model_name_or_path: # Llama or llama 24 | return PretrainedLlamaLoader( 25 | pretrained_model_name_or_path=pretrained_model_name_or_path, 26 | huggingface_token=huggingface_token, 27 | **model_kwargs, 28 | ) 29 | elif 'istral' in pretrained_model_name_or_path: # Mistral or mistral; 30 | return PretrainedMistralLoader( 31 | pretrained_model_name_or_path=pretrained_model_name_or_path, 32 | huggingface_token=huggingface_token, 33 | **model_kwargs, 34 | ) 35 | else: 36 | print(f'-> {pretrained_model_name_or_path} using default pretrained model loader') 37 | return PretrainedModelLoader( 38 | pretrained_model_name_or_path=pretrained_model_name_or_path, 39 | huggingface_token=huggingface_token, 40 | **model_kwargs, 41 | ) 42 | 43 | 44 | class PretrainedModelLoader(): 45 | """ 46 | Class for loading a pretrained model. 47 | Example: 48 | model_loader = PretrainedModelLoader(**model_kwargs) 49 | model = model_loader.load() 50 | """ 51 | def __init__(self, 52 | pretrained_model_name_or_path: str, 53 | cache_dir: str = None, 54 | return_dict: bool = True, # False 55 | device_map: str = 'auto', 56 | low_cpu_mem_usage: bool = True, 57 | torch_dtype: str = 'bfloat16', 58 | rope_theta: float = 10000., 59 | attn_implementation: str = 'sdpa', # eager 60 | load_in_8bit: bool = False, 61 | load_in_4bit: bool = False, 62 | huggingface_token: str = None, 63 | peft_id: str = None, 64 | rope_scaling: dict = None, 65 | **other_kwargs: any) -> None: 66 | 67 | print(f'-> Using {attn_implementation} attention') 68 | 69 | self.loading_kwargs = { 70 | 'pretrained_model_name_or_path': pretrained_model_name_or_path, 71 | 'cache_dir': cache_dir, 72 | 'return_dict': return_dict, 73 | 'load_in_8bit': load_in_8bit, 74 | 'load_in_4bit': load_in_4bit, 75 | 'device_map': device_map, 76 | 'low_cpu_mem_usage': low_cpu_mem_usage, 77 | 'torch_dtype': getattr(torch, torch_dtype), 78 | 'rope_theta': rope_theta, 79 | 'attn_implementation': attn_implementation, 80 | } 81 | if rope_scaling is not None: # Llama 3.1 patch 82 | rope_scaling = OmegaConf.to_container(rope_scaling) 83 | self.loading_kwargs['rope_scaling'] = rope_scaling 84 | for k, v in other_kwargs.items(): 85 | self.loading_kwargs[k] = v 86 | 87 | self.quantization = load_in_8bit or load_in_4bit 88 | self.peft_id = peft_id 89 | self.gradient_checkpointing = False 90 | if huggingface_token is not None: # for gated models, e.g., Llama 3 91 | self.loading_kwargs['token'] = huggingface_token 92 | 93 | if self.quantization: 94 | raise NotImplementedError 95 | # bnb_config = BitsAndBytesConfig( 96 | # load_in_8bit=load_in_8bit, 97 | # load_in_4bit=load_in_4bit, 98 | # bnb_4bit_compute_dtype=torch.bfloat16, 99 | # bnb_4bit_use_double_quant=True, 100 | # bnb_4bit_quant_type="nf4", 101 | # ) 102 | # del self.loading_kwargs['load_in_8bit'] 103 | # del self.loading_kwargs['load_in_4bit'] 104 | # self.loading_kwargs['quantization_config'] = bnb_config 105 | 106 | def load(self) -> nn.Module: 107 | """ 108 | Load pretrained model 109 | """ 110 | model = AutoModelForCausalLM.from_pretrained(**self.loading_kwargs) 111 | if self.quantization: 112 | model = prepare_model_for_kbit_training( 113 | model, use_gradient_checkpointing=self.gradient_checkpointing, 114 | gradient_checkpointing_kwargs={'use_reentrant': False}, 115 | ) 116 | return model 117 
| 118 | def load_tokenizer(self): 119 | """ 120 | Load pretrained tokenizer 121 | """ 122 | try: 123 | return AutoTokenizer.from_pretrained(**self.loading_kwargs) 124 | except Exception as e: 125 | print("-> Error with `AutoTokenizer.from_pretrained(**self.loading_kwargs)`:", e) 126 | print("-> Trying `LlamaTokenizer.from_pretrained(**self.loading_kwargs)`") 127 | # MZ 6/1: Mistral-7B-Instruct-v0.3 in Transformers v4.36 doesn't work with the above 128 | return LlamaTokenizer.from_pretrained(**self.loading_kwargs) 129 | 130 | 131 | class PretrainedLlamaLoader(PretrainedModelLoader): 132 | def load(self, model_type: str = None, ): 133 | llama3_1 = float('.'.join(transformers.__version__.split('.')[:2])) > 4.42 # 'Meta-Llama-3.1' in self.loading_kwargs['pretrained_model_name_or_path'] 134 | if model_type is None: 135 | from transformers import LlamaForCausalLM as model_class 136 | 137 | elif 'lolcats_llama_sharded' in model_type: 138 | from .modeling_llama_sharded import ShardedLolcatsLlamaForCausalLM as model_class 139 | 140 | elif 'lolcats_long_llama' in model_type: 141 | from .modeling_llama import LooooolcatsLlamaForCausalLM as model_class 142 | 143 | elif 'lolcats_llama' in model_type: 144 | from .modeling_llama import LolcatsLlamaForCausalLM as model_class 145 | 146 | else: 147 | if model_type == 'flash_attention_2': 148 | self.loading_kwargs['attn_implementation'] = model_type 149 | from transformers import AutoModelForCausalLM as model_class 150 | print('-> Loading from AutoModelForCausalLM') 151 | 152 | model = model_class.from_pretrained(**self.loading_kwargs) 153 | if self.peft_id is not None: 154 | from peft import PeftModel 155 | print('-> Loading PEFT checkpoint') 156 | model = PeftModel.from_pretrained( 157 | model, 158 | self.peft_id, 159 | torch_dtype=self.loading_kwargs['torch_dtype'], 160 | device_map='auto', 161 | cache_dir=self.loading_kwargs['cache_dir'] 162 | ).merge_and_unload() 163 | 164 | if self.quantization: 165 | model = prepare_model_for_kbit_training( 166 | model, use_gradient_checkpointing=self.gradient_checkpointing, 167 | gradient_checkpointing_kwargs={'use_reentrant': False}, 168 | ) 169 | return model 170 | 171 | def load_tokenizer(self): 172 | return AutoTokenizer.from_pretrained(**self.loading_kwargs) 173 | 174 | 175 | class PretrainedMistralLoader(PretrainedModelLoader): 176 | def load(self, model_type: str = None): 177 | if model_type is None: 178 | from transformers import MistralForCausalLM as model_class 179 | elif 'lolcats_long_llama' in model_type: 180 | from .modeling_mistral import LooooolcatsMistralForCausalLM as model_class 181 | elif 'lolcats_llama' in model_type: 182 | from .modeling_mistral import LolcatsMistralForCausalLM as model_class 183 | else: 184 | if model_type == 'flash_attention_2': 185 | self.loading_kwargs['attn_implementation'] = model_type 186 | from transformers import AutoModelForCausalLM as model_class 187 | print('-> Loading from AutoModelForCausalLM') 188 | 189 | model = model_class.from_pretrained(**self.loading_kwargs) 190 | if self.peft_id is not None: 191 | from peft import PeftModel 192 | model = PeftModel.from_pretrained( 193 | model, 194 | self.peft_id, 195 | torch_dtype=self.loading_kwargs['torch_dtype'], 196 | device_map='auto', 197 | cache_dir=self.loading_kwargs['cache_dir'], 198 | ).merge_and_unload() 199 | 200 | if self.quantization: 201 | model = prepare_model_for_kbit_training( 202 | model, use_gradient_checkpointing=self.gradient_checkpointing, 203 | gradient_checkpointing_kwargs={'use_reentrant': False}, 
204 | ) 205 | return model 206 | -------------------------------------------------------------------------------- /src/model/rotary.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ 21 | Rotary embeddings. Same as usual for Transformer models. 22 | 23 | Note these are modified from HF Transformers v4.36, from: 24 | - transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py 25 | - i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123 26 | """ 27 | import torch 28 | import torch.nn as nn 29 | 30 | 31 | def get_rotary_embeddings(rope_scaling_type: str = None, 32 | head_dim: int = 128, 33 | max_position_embeddings: int = 4096, 34 | rope_theta: float = 10000.0, 35 | rope_scaling_factor: float = 1.0, 36 | device: torch.device = None, 37 | ) -> nn.Module: 38 | """Return rotary embedding object""" 39 | if rope_scaling_type is None: 40 | return RotaryEmbedding( 41 | head_dim, 42 | max_position_embeddings=max_position_embeddings, 43 | base=rope_theta, 44 | device=device, 45 | ) 46 | elif rope_scaling_type == "linear": 47 | return LinearScalingRotaryEmbedding( 48 | head_dim, 49 | max_position_embeddings=max_position_embeddings, 50 | scaling_factor=rope_scaling_factor, 51 | base=rope_theta, 52 | device=device, 53 | ) 54 | elif rope_scaling_type == "dynamic": 55 | return DynamicNTKScalingRotaryEmbedding( 56 | head_dim, 57 | max_position_embeddings=max_position_embeddings, 58 | scaling_factor=rope_scaling_factor, 59 | base=rope_theta, 60 | device=device, 61 | ) 62 | else: 63 | raise NotImplementedError(f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.') 64 | 65 | 66 | # Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) 67 | def rotate_half(x): 68 | """Rotates half the hidden dims of the input.""" 69 | x1 = x[..., : x.shape[-1] // 2] 70 | x2 = x[..., x.shape[-1] // 2 :] 71 | return torch.cat((-x2, x1), dim=-1) 72 | 73 | 74 | # Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) 75 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 76 | """Applies Rotary Position Embedding to the query and key tensors.""" 77 | if position_ids is not None: 78 | cos, sin = cos[position_ids], sin[position_ids] 79 | cos = cos.unsqueeze(unsqueeze_dim) 80 | sin = sin.unsqueeze(unsqueeze_dim) 81 | q_embed = (q * cos) + (rotate_half(q) * sin) 
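# Note: rotate_half pairs dimension i with dimension i + dim/2, so the line above (and the matching
# k_embed line that follows) applies the 2-D rotation [[cos, -sin], [sin, cos]] to each
# (x_i, x_{i + dim/2}) pair -- standard RoPE, just with a different dimension pairing than the
# original paper, matching the "Different from paper" comments in the cache-building code below.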
82 | k_embed = (k * cos) + (rotate_half(k) * sin) 83 | return q_embed, k_embed 84 | 85 | 86 | # Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) 87 | class RotaryEmbedding(nn.Module): 88 | """Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864""" 89 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 90 | super().__init__() 91 | 92 | self.dim = dim 93 | self.max_position_embeddings = max_position_embeddings 94 | self.base = base 95 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 96 | self.register_buffer("inv_freq", inv_freq, persistent=False) 97 | 98 | # Build here to make `torch.jit.trace` work. 99 | self._set_cos_sin_cache( 100 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 101 | ) 102 | 103 | def _set_cos_sin_cache(self, seq_len, device, dtype): 104 | self.max_seq_len_cached = seq_len 105 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 106 | 107 | freqs = torch.outer(t, self.inv_freq) 108 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 109 | emb = torch.cat((freqs, freqs), dim=-1) 110 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 111 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 112 | 113 | def forward(self, x, seq_len=None): 114 | """ 115 | Compute rotary embeddings 116 | """ 117 | # x: [bs, num_attention_heads, seq_len, head_size] 118 | if seq_len > self.max_seq_len_cached: 119 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 120 | 121 | return ( 122 | self.cos_cached[:seq_len].to(dtype=x.dtype), 123 | self.sin_cached[:seq_len].to(dtype=x.dtype), 124 | ) 125 | 126 | 127 | # Copied from transformers/models/llama/modeling_llama.py at v4.36 128 | class LinearScalingRotaryEmbedding(RotaryEmbedding): 129 | """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 130 | 131 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 132 | self.scaling_factor = scaling_factor 133 | super().__init__(dim, max_position_embeddings, base, device) 134 | 135 | def _set_cos_sin_cache(self, seq_len, device, dtype): 136 | self.max_seq_len_cached = seq_len 137 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 138 | t = t / self.scaling_factor 139 | 140 | freqs = torch.outer(t, self.inv_freq) 141 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 142 | emb = torch.cat((freqs, freqs), dim=-1) 143 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 144 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 145 | 146 | 147 | # Copied from transformers/models/llama/modeling_llama.py at v4.36 148 | class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): 149 | """RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" 150 | 151 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 152 | self.scaling_factor = scaling_factor 153 | super().__init__(dim, max_position_embeddings, base, device) 154 | 155 | def _set_cos_sin_cache(self, seq_len, device, dtype): 156 | self.max_seq_len_cached = seq_len 157 | 158 | if seq_len > self.max_position_embeddings: 159 | base = self.base * ( 160 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 161 | ) ** (self.dim / (self.dim - 2)) 162 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 163 | self.register_buffer("inv_freq", inv_freq, persistent=False) 164 | 165 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 166 | 167 | freqs = torch.outer(t, self.inv_freq) 168 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 169 | emb = torch.cat((freqs, freqs), dim=-1) 170 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 171 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 172 | 173 | -------------------------------------------------------------------------------- /src/model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def count_parameters(model, requires_grad: bool = True): 5 | """ 6 | Return total # of trainable parameters 7 | """ 8 | if requires_grad: 9 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 10 | else: 11 | model_parameters = model.parameters() 12 | try: 13 | return sum([np.prod(p.size()) for p in model_parameters]).item() 14 | except: 15 | return sum([np.prod(p.size()) for p in model_parameters]) 16 | -------------------------------------------------------------------------------- /src/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from .optim import get_optimizer, get_scheduler 3 | 4 | 5 | def get_trainer(name: str): 6 | """ 7 | Return our trainer class 8 | """ 9 | try: 10 | module = importlib.import_module(f'src.trainer.{name}') 11 | except ModuleNotFoundError as e: 12 | print(e) 13 | print('-> Using default trainer') 14 | module = importlib.import_module('src.trainer.default') 15 | return getattr(module, 'OurTrainer') 16 | -------------------------------------------------------------------------------- /src/trainer/distill_attention_mse_linear.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom trainer class for distilling attentions ("attention transfer") over long sequences with recurrent linear attention view. Can substitute for Hugging Face trainer. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from tqdm import tqdm 8 | 9 | from src.model.modeling_llama import get_attention_cache 10 | from src.model.convert_model import traverse_layers 11 | from .default_lm import OurTrainer as DefaultTrainer 12 | 13 | 14 | class OurTrainer(DefaultTrainer): 15 | """ 16 | Custom trainer class for distilling attentions. 
17 | - We compute and store the attention outputs and/or weights for each head and layer, 18 | for both the "teacher" softmax attentions and "student" learnable subquadratic attentions 19 | - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) 20 | """ 21 | def __init__(self, 22 | model: nn.Module, 23 | metric_for_best_model: str = 'distill/eval/loss', 24 | mse_factor: float = 1e3, 25 | **kwargs: any): 26 | super().__init__(model=model, 27 | metric_for_best_model=metric_for_best_model, 28 | **kwargs) 29 | self.criterion_mse = nn.MSELoss(reduction='mean') 30 | self.mse_factor = mse_factor 31 | self.xent_factor = 0 32 | self.compute_loss_backprop = False # Whether we backprop in self.compute_loss 33 | 34 | 35 | def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], 36 | sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: 37 | """ 38 | Attention distillation ("attention transfer") 39 | - For each layer and head, get attentions and train to 40 | minimize some combo of MSE and cross-entropy loss 41 | """ 42 | input_seq_len = data['input_ids'].shape[-1] 43 | inputs = {'input_ids': data['input_ids'].to(model.device)} # assume all inputs good 44 | 45 | # Get softmax attention outputs 46 | with torch.no_grad(): 47 | # Set base_inference to True to use FlashAttention 48 | for layer in traverse_layers(model): 49 | layer.self_attn.base_inference = True 50 | # Get hidden states 51 | true_outputs = model(**inputs, output_attentions=True, 52 | use_cache=False,) 53 | # no_logit_float=True,) 54 | # Hack were we save attention layer inputs and outputs in outputs.attentions 55 | # -> see model/hedgehog_attention_tk_long.py 56 | # attn_inputs = [a[0] for a in true_outputs.get('attentions')] 57 | # attn_outputs = [a[1] for a in true_outputs.get('attentions')] 58 | true_attn_io = true_outputs.get('attentions') # layer-wise attn inputs and outputs 59 | true_outputs = true_outputs.get('logits').cpu() 60 | for layer in traverse_layers(model): 61 | layer.self_attn.base_inference = False 62 | inputs = {k: v.cpu() for k, v in inputs.items()} 63 | torch.cuda.empty_cache() 64 | 65 | # Get trainable subquadratic attention outputs 66 | attention_type = getattr(layer.self_attn, 'attention_type', None) 67 | past_key_values = get_attention_cache(attention_type) 68 | 69 | total_seq_len = 0 70 | position_ids = torch.arange(input_seq_len).view(1, -1) 71 | 72 | loss_mse = 0 73 | for layer_idx, layer in enumerate(tqdm(traverse_layers(model), desc='Processing layer', 74 | leave=False)): 75 | attn_input, attn_output = true_attn_io[layer_idx] 76 | attn_preds = layer.self_attn(attn_input.to(model.device), 77 | attention_mask=None, 78 | position_ids=position_ids.to(model.device), 79 | past_key_value=past_key_values)[1] 80 | if self.mse_factor > 0: # MSE on layer outputs 81 | loss_mse += self.criterion_mse(attn_preds, attn_output.to(model.device)) 82 | del attn_input; del attn_output 83 | loss_mse = loss_mse / (layer_idx + 1) * self.mse_factor 84 | loss = loss_mse 85 | torch.cuda.empty_cache() 86 | 87 | if 'position_ids' in data: 88 | outputs = {'loss_mse': loss_mse.item(), 89 | 'loss_xent': 0, 90 | 'mse_factor': self.mse_factor, 91 | 'xent_factor': self.xent_factor, 92 | 'input_len': data['position_ids'].shape[1], 93 | 'position_ids': data['position_ids'][0],} 94 | else: 95 | outputs = {'loss_mse': loss_mse.item(), 96 | 'loss_xent': 0, 97 | 'mse_factor': self.mse_factor, 98 | 'xent_factor': self.xent_factor,} 99 | return loss, outputs 
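For reference, the objective above reduces to an average MSE between each student layer's attention output and the frozen teacher's output on the same layer input, scaled by mse_factor. A minimal standalone sketch of that per-layer loss follows; the module and variable names (attention_transfer_mse, student_attns, teacher_io) are illustrative, not part of the repo:

import torch.nn as nn

def attention_transfer_mse(student_attns, teacher_io, position_ids, mse_factor=1e3):
    # student_attns: one trainable (subquadratic) attention module per layer
    # teacher_io: [(attn_input, attn_output), ...] saved from the softmax "teacher" pass
    criterion = nn.MSELoss(reduction='mean')
    loss = 0.
    for student_attn, (attn_input, attn_output) in zip(student_attns, teacher_io):
        # The student re-processes the teacher's layer input; indexing [1] mirrors the
        # return convention used in compute_loss above
        pred = student_attn(attn_input, attention_mask=None, position_ids=position_ids)[1]
        loss = loss + criterion(pred, attn_output)
    return loss / len(teacher_io) * mse_factor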
-------------------------------------------------------------------------------- /src/trainer/distill_attention_xent_mse.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer. 3 | 4 | In this implementation we support using either just the softmax attention outputs, or the softmax attention weights. 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .default_lm import OurTrainer as DefaultTrainer 10 | 11 | 12 | class OurTrainer(DefaultTrainer): 13 | """ 14 | Custom trainer class for distilling attentions. 15 | - We compute and store the attention outputs and/or weights for each head and layer, 16 | for both the "teacher" softmax attentions and "student" learnable subquadratic attentions 17 | - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) 18 | """ 19 | def __init__(self, 20 | model: nn.Module, 21 | metric_for_best_model: str = 'distill/eval/loss', 22 | mse_factor: float = 1e3, 23 | xent_factor: float = 0, 24 | **kwargs: any): 25 | super().__init__(model=model, 26 | metric_for_best_model=metric_for_best_model, 27 | **kwargs) 28 | self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') 29 | self.criterion_mse = nn.MSELoss(reduction='mean') 30 | self.mse_factor = mse_factor 31 | self.xent_factor = xent_factor 32 | self.compute_loss_backprop = False # Whether we backprop in self.compute_loss 33 | 34 | def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], 35 | sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: 36 | """ 37 | Attention distillation ("attention transfer") 38 | - For each layer and head, get attentions and train to 39 | minimize some combo of MSE and cross-entropy loss 40 | """ 41 | inputs = {k: v.to(model.device) for k, v in data.items() if k != 'labels'} 42 | outputs = model(**inputs, output_attentions=True, use_cache=False) 43 | outputs = outputs.get('attentions') 44 | 45 | # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]] 46 | # n_layers x (predicted_attns, true_attns) 47 | # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len) 48 | loss_mse = 0 49 | loss_xent = 0 50 | n_layers = 0 # Number of layers to distill 51 | softmax_layers = [] 52 | for layer_idx, attns in enumerate(outputs): 53 | if attns is not None: 54 | if len(attns) != 2: 55 | attns = attns.cpu() 56 | else: 57 | if self.xent_factor > 0: 58 | # Cross-entropy loss 59 | a_pred, a_true = attns[0] 60 | a_pred = a_pred.clamp(min=1e-12).log() # nn.CrossEntropy assumes unnormalized logits 61 | k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len 62 | # Compute mean cross-entropy over all queries 63 | a_pred = a_pred.contiguous().view(-1, k_len) 64 | a_true = a_true.contiguous().view(-1, k_len) 65 | loss_xent += self.criterion_xent(a_pred, a_true) 66 | if self.mse_factor > 0: 67 | loss_mse += self.criterion_mse(*attns[1]) 68 | n_layers += 1 69 | else: 70 | softmax_layers.append(layer_idx) 71 | if n_layers > 0: 72 | loss_xent = loss_xent / n_layers * self.xent_factor 73 | loss_mse = loss_mse / n_layers * self.mse_factor 74 | loss = loss_xent + loss_mse 75 | if 'position_ids' in data: 76 | outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, 77 | 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, 78 | 'input_len': data['position_ids'].shape[1], 79 | 'position_ids': data['position_ids'][0].detach().cpu().numpy(), 80 | 
'mse_factor': self.mse_factor, 81 | 'xent_factor': self.xent_factor,} 82 | else: 83 | outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, 84 | 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, 85 | 'mse_factor': self.mse_factor, 86 | 'xent_factor': self.xent_factor} 87 | return loss, outputs 88 | -------------------------------------------------------------------------------- /src/trainer/finetune_seq2seq.py: -------------------------------------------------------------------------------- 1 | """ 2 | General seq2seq / input-output trainer 3 | """ 4 | from typing import Optional 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | 11 | from tqdm import tqdm 12 | 13 | from .default_lm import OurTrainer as DefaultTrainer 14 | from .utils import replace_padding_tokens 15 | 16 | 17 | def compute_scrolls_metrics(eval_preds, scrolls_metric, tokenizer): 18 | """ 19 | Function to compute metrics that are also in SCROLLS (ROUGE, F1, etc.) 20 | """ 21 | preds, labels = eval_preds 22 | if isinstance(preds, tuple): 23 | preds = preds[0] 24 | # Replace -100s used for padding as we can't decode them 25 | preds = replace_padding_tokens(preds, tokenizer.pad_token_id) 26 | labels = replace_padding_tokens(labels, tokenizer.pad_token_id) 27 | 28 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 29 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 30 | 31 | # Scrolls metric expects predictions to be [pred_1, pred_2, ...] 32 | # and references to be [[ref_1], [ref_2], ... ] 33 | decoded_labels = [[s] for s in decoded_labels] 34 | 35 | result = scrolls_metric.compute(predictions=decoded_preds, 36 | references=decoded_labels) 37 | print('----------------') 38 | print('Model generation') 39 | print(decoded_preds[:10]) 40 | print('----------------') 41 | print('True answer') 42 | print(decoded_labels[:10]) 43 | return result 44 | 45 | 46 | class OurTrainer(DefaultTrainer): 47 | """ 48 | Evaluator for seq-to-seq / generation benchmarks 49 | """ 50 | def __init__(self, model, args, # max_eval_batches: Optional[int] = 100, 51 | **kwargs: any): 52 | super().__init__(model=model, args=args, **kwargs) 53 | # Reset + determine metric for best automatically based on the dataset 54 | self.metric_for_best = None 55 | self.is_better = lambda x, y: x > y # Hardcode greater is better for now 56 | self.print_steps = getattr(args, 'print_steps', 100) 57 | print(f'self.print_steps:', self.print_steps) 58 | # ablation sweep 59 | self.max_eval_batches = 10 60 | 61 | def init_criterion_(self): 62 | pass 63 | 64 | def compute_loss(self): 65 | pass 66 | 67 | def evaluate(self, *args: any, **kwargs: any): 68 | return self.eval_step(*args, **kwargs) 69 | 70 | def eval_step(self, model: nn.Module, step: int, 71 | dataloader: DataLoader = None, 72 | max_batches: int = None, 73 | prefix: str = None, 74 | **kwargs: any): # -1): 75 | """ 76 | One evaluation step 77 | """ 78 | total = 0 79 | total_loss = 0 80 | metrics = {} 81 | max_batches = self.max_eval_batches if max_batches is None else max_batches 82 | max_batches = 10 # ablation sweep 83 | 84 | dataloader = (dataloader if dataloader is not None else self.eval_loader) 85 | 86 | scrolls_metric = dataloader.dataset.metric # Should be assigned in dataset 87 | tokenizer = dataloader.dataset.tokenizer 88 | 89 | # Save decoded predictions and references here to compute average metrics 90 | predictions, references = [], [] 91 | 92 | model.eval() 93 | 
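# Generation-based eval loop: for each batch, autoregressively generate a continuation,
# keep only the tokens produced after the prompt, and score the decoded predictions
# against the references with the dataset's SCROLLS-style metric once the loop finishes.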
94 | pbar = tqdm(dataloader, leave=False, colour='green', 95 | desc=f'Evaluating at step {step}') 96 | 97 | with torch.no_grad(): 98 | for ix, data in enumerate(pbar): 99 | inputs = {k: v.to(self.device) for k, v in data.items() 100 | if k in ['input_ids', 'attention_mask']} 101 | labels = data['labels'] 102 | outputs = model.generate(**inputs, 103 | max_new_tokens=1024, # hardcoded for now 104 | pad_token_id=tokenizer.pad_token_id, 105 | use_cache=True,).cpu() 106 | # Only save newly generated tokens 107 | pred_ids = outputs[:, data['input_ids'].shape[1]:] 108 | predictions.append(pred_ids) 109 | references.append(labels) 110 | pbar.set_description(f"Evaluating at step {step} | input_len: {data['input_ids'].shape[1]} | output_len: {labels.shape[1]}") 111 | 112 | if ix == max_batches: 113 | break 114 | 115 | if (ix + 1) % self.print_steps == 0: # 100 == 0: 116 | print(f'Model input: \n', tokenizer.batch_decode(inputs['input_ids'].detach().cpu())[0]) 117 | print(f'Model output:\n', tokenizer.batch_decode(pred_ids)[0]) 118 | print(f'True output:\n', tokenizer.batch_decode(labels)[0]) 119 | 120 | # Compute and save metrics 121 | try: 122 | predictions = torch.cat(predictions, dim=0) 123 | references = torch.cat(references, dim=0) 124 | except: 125 | pass 126 | _metric = compute_scrolls_metrics((predictions, references), 127 | scrolls_metric, tokenizer) 128 | if self.metric_for_best is None: # Hard-coded for now 129 | if 'f1' in _metric: 130 | self.metric_for_best = f'eval/f1' 131 | elif 'exact_match' in _metric: 132 | self.metric_for_best = f'eval/exact_match' 133 | elif 'rouge/geometric_mean' in _metric: 134 | self.metric_for_best = f'eval/rouge/geometric_mean' 135 | for k, v in _metric.items(): 136 | if 'display' not in k: 137 | _k = f'{prefix}/eval/{k}' if prefix is not None else f'eval/{k}' 138 | metrics[_k] = v 139 | 140 | return metrics 141 | -------------------------------------------------------------------------------- /src/trainer/optim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimizer and schedulers 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch.optim import Optimizer 7 | from torch.optim.lr_scheduler import LRScheduler 8 | 9 | 10 | def get_optimizer(optim: str, model: nn.Module, **kwargs: any) -> Optimizer: 11 | """ 12 | Return training optimizer 13 | """ 14 | if optim == 'sgd': 15 | return torch.optim.SGD(model.parameters(), **kwargs) 16 | elif optim == 'adam': 17 | return torch.optim.Adam(model.parameters(), **kwargs) 18 | elif optim in ['adamw', 'adamw_torch']: 19 | return torch.optim.AdamW(model.parameters(), **kwargs) 20 | elif optim == 'adamw_torch_fused': 21 | return torch.optim.AdamW(model.parameters(), **kwargs, fused=True) 22 | elif optim == 'adafactor': 23 | from transformers import Adafactor 24 | kwargs['relative_step'] = False # for now 25 | return Adafactor(model.parameters(), **kwargs) 26 | else: 27 | raise NotImplementedError(f"{optim} optimizer not implemented sorry.") 28 | 29 | 30 | def get_scheduler(lr_scheduler_type: str, optimizer: Optimizer, 31 | **kwargs: any) -> LRScheduler: 32 | """ 33 | Return learning rate scheduler 34 | """ 35 | if lr_scheduler_type in ['plateau', 'reduce_lr_on_plateau']: 36 | from torch.optim.lr_scheduler import ReduceLROnPlateau 37 | return ReduceLROnPlateau(optimizer=optimizer, **kwargs) 38 | 39 | elif lr_scheduler_type == 'cosine_warmup': 40 | from transformers import get_cosine_schedule_with_warmup 41 | return 
get_cosine_schedule_with_warmup(optimizer=optimizer, **kwargs) 42 | 43 | elif lr_scheduler_type in ['linear_warmup', 'linear']: 44 | from transformers import get_linear_schedule_with_warmup 45 | return get_linear_schedule_with_warmup(optimizer=optimizer, **kwargs) 46 | 47 | else: 48 | return None -------------------------------------------------------------------------------- /src/trainer/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training loop helpers 3 | """ 4 | import torch 5 | import numpy as np 6 | 7 | from transformers.tokenization_utils import PreTrainedTokenizer 8 | 9 | 10 | def replace_padding_tokens(token_ids: torch.Tensor, 11 | pad_token_id: int, 12 | ignore_token_id: int = -100) -> any: 13 | """ 14 | Replace ignore_token_id tokens with pad_token_id, 15 | e.g., for printing inputs during training 16 | """ 17 | if isinstance(token_ids, list): 18 | return [np.where(t != ignore_token_id, t, pad_token_id)[0] for t in token_ids] 19 | else: 20 | return np.where(token_ids != ignore_token_id, token_ids, pad_token_id) 21 | 22 | 23 | def decode_samples(outputs: torch.Tensor, 24 | targets: torch.Tensor, 25 | tokenizer: PreTrainedTokenizer, 26 | sample_idx: int = None) -> None: 27 | """ 28 | Print first element of samples for debugging 29 | """ 30 | print('=' * 20) 31 | print(f'*** TARGETS (sample {sample_idx})***') 32 | tokens = tokenizer.decode( 33 | replace_padding_tokens(targets[0], tokenizer.pad_token_id) 34 | ) 35 | print(tokens) 36 | print('-' * 20) 37 | print(f'*** PREDICTIONS (sample {sample_idx}) ***') 38 | pred_logits = outputs.argmax(dim=-1).cpu() 39 | pred_tokens = tokenizer.decode( 40 | replace_padding_tokens(pred_logits[0], tokenizer.pad_token_id) 41 | ) 42 | print(pred_tokens) 43 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/lolcats/375df84f58417a57b78c875fa3f8fa76e84ab12e/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logging utilities to make terminal slightly more delightful 3 | """ 4 | import rich.syntax 5 | import rich.tree 6 | 7 | from omegaconf import OmegaConf, DictConfig, ListConfig 8 | 9 | 10 | def _format_arg(arg_name: str, cutoff=2) -> str: 11 | if arg_name is None: 12 | return arg_name 13 | arg_name = str(arg_name) 14 | 15 | # Hardcode to handle backslash 16 | name_splits = arg_name.split('/') 17 | if len(name_splits) > 1: 18 | return name_splits[-1] 19 | # Abbreviate based on underscore 20 | name_splits = arg_name.split('_') 21 | if len(name_splits) > 1: 22 | return ''.join([s[0] for s in name_splits]) 23 | else: 24 | return arg_name[:cutoff] 25 | 26 | 27 | def print_header(x: str) -> None: 28 | """ 29 | Print a header with a line above and below 30 | """ 31 | print('-' * len(x)) 32 | print(x) 33 | print('-' * len(x)) 34 | 35 | 36 | def print_args(args, return_dict=False, verbose=True): 37 | """ 38 | Print the arguments passed to the script 39 | """ 40 | attributes = [a for a in dir(args) if a[0] != '_'] 41 | arg_dict = {} # switched to ewr 42 | if verbose: 43 | print('ARGPARSE ARGS') 44 | for ix, attr in enumerate(attributes): 45 | fancy = '└─' if ix == len(attributes) - 1 else '├─' 46 | if verbose: 47 | print(f'{fancy} {attr}: {getattr(args, 
attr)}') 48 | arg_dict[attr] = getattr(args, attr) 49 | if return_dict: 50 | return arg_dict 51 | 52 | 53 | def update_description_metrics(description: str, metrics: dict): 54 | """ 55 | Set the numbers that show up on progress bars 56 | """ 57 | for split in metrics: 58 | if split != 'test': # No look 59 | for metric_name, metric in metrics[split].items(): 60 | description += f' | {split}/{metric_name}: {metric:.3f}' 61 | return description 62 | 63 | 64 | # Control how tqdm progress bar looks 65 | def type_of_script(): 66 | try: 67 | ipy_str = str(type(get_ipython())) 68 | if 'zmqshell' in ipy_str: 69 | return 'jupyter' 70 | if 'terminal' in ipy_str: 71 | return 'ipython' 72 | except: 73 | return 'terminal' 74 | 75 | # Progress bar 76 | def update_pbar_display(metrics, batch_ix, pbar, prefix, batch_size, accum_iter=1): 77 | description = f'└── {prefix} batch {int(batch_ix)}/{len(pbar)} [batch size: {batch_size} - grad. accum. over {accum_iter} batch(es)]' 78 | for metric_name, metric in metrics.items(): 79 | if metric_name == 'correct': 80 | description += f' | {metric_name} (acc. %): {int(metric):>5d}/{int(metrics["total"])} = {metric / metrics["total"] * 100:.3f}%' 81 | elif metric_name == 'acc': 82 | description += f' | {metric_name}: {metric:.3f}' 83 | elif metric_name in ['perplexity']: # , 'bpc']: 84 | description += f' | {metric_name}: {Decimal(metric):.3E}' 85 | elif metric_name != 'total': 86 | description += f' | {metric_name}: {metric / metrics["total"]:.3f}' 87 | pbar.set_description(description) 88 | 89 | 90 | def print_config(config: DictConfig, 91 | resolve: bool = True, 92 | name: str = 'CONFIG') -> None: 93 | """Prints content of DictConfig using Rich library and its tree structure. 94 | Args: 95 | config (DictConfig): Configuration composed by Hydra. 96 | fields (Sequence[str], optional): Determines which main fields from config will 97 | be printed and in what order. 98 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 
99 | """ 100 | 101 | style = "bright" # "dim" 102 | tree = rich.tree.Tree(name, style=style, guide_style=style) 103 | 104 | fields = config.keys() 105 | for field in fields: 106 | branch = tree.add(field, style=style, guide_style=style) 107 | 108 | config_section = config.get(field) 109 | branch_content = str(config_section) 110 | if isinstance(config_section, DictConfig): 111 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 112 | elif isinstance(config_section, ListConfig): 113 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 114 | 115 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 116 | 117 | rich.print(tree) -------------------------------------------------------------------------------- /src/utils/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | General helper functions for setting up experiments 3 | """ 4 | import os 5 | import random 6 | 7 | from argparse import ArgumentParser 8 | from omegaconf import DictConfig 9 | 10 | import torch 11 | import numpy as np 12 | 13 | from .logging import _format_arg 14 | 15 | 16 | def init_wandb(args: ArgumentParser) -> any: 17 | """Initialize WandB""" 18 | if args.no_wandb: 19 | wandb = None 20 | else: 21 | import wandb 22 | wandb.init(config={}, 23 | entity=args.wandb_entity, 24 | name=args.run_name, 25 | project=args.project_name) 26 | return wandb 27 | 28 | 29 | def seed_everything(seed: int) -> None: 30 | """ 31 | Seed everything 32 | """ 33 | random.seed(seed) 34 | os.environ['PYTHONHASHSEED'] = str(seed) 35 | np.random.seed(seed) 36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | 42 | 43 | def get_run_name_from_checkpoint(checkpoint_path: str) -> str: 44 | """ 45 | Helper function to get a condensed run name from the checkpoint path 46 | """ 47 | name = [] 48 | for s in checkpoint_path.split('/')[-1].split('-'): 49 | if '.pt' in s: 50 | name.append(f'_{s[:-3]}') 51 | try: 52 | s = s.split('=') 53 | s = ''.join([c[0] for c in s[1].split('_')]) 54 | name.append(s) 55 | except IndexError: 56 | pass 57 | return ''.join(name) 58 | 59 | 60 | def get_run_name_from_args(args) -> str: 61 | """ 62 | Prepare a heinous identifier for the run based on args 63 | """ 64 | if args.load_distill_checkpoint is not None and args.load_distill_checkpoint != 'default': 65 | distill_name = get_run_name_from_checkpoint(args.load_distill_checkpoint) 66 | else: 67 | distill_name = args.distill_config 68 | if args.load_finetune_checkpoint is not None and args.finetune_config is None: # args.load_finetune_checkpoint != 'default': 69 | finetune_name = get_run_name_from_checkpoint(args.load_finetune_checkpoint) 70 | else: 71 | finetune_name = args.finetune_config 72 | args.run_name = f'dl-d={distill_name}-m={args.model_config}-f={finetune_name}' 73 | if args.no_peft_grad_ckpt is not None: 74 | args.run_name += f'-npgc={args.no_peft_grad_ckpt}' 75 | args.run_name += f'-s={args.seed}' 76 | if args.debug: 77 | args.run_name += f'-debug' 78 | if args.no_attention_mask is not None: 79 | args.run_name += f'-nam=1' 80 | return args.run_name.replace('True', '1').replace('False', '0') # concise hacks 81 | 82 | 83 | def flatten_config(config: dict, flattened: dict, key: str) -> dict: 84 | """ 85 | Recursive way to flatten config args for saving to WandB 86 | """ 87 | for k, v in config.items(): 88 | if isinstance(v, dict): 89 | 
flatten_config(v, flattened, f'{key}{k}_') 90 | elif isinstance(v, list): 91 | for ix, _config in enumerate(v): 92 | if isinstance(_config, dict): 93 | flatten_config(_config, flattened, f'{key}{k}_{ix}_') 94 | else: 95 | flattened[f'{key}{k}'] = v 96 | return flattened 97 | 98 | 99 | def update_config_from_args(config: DictConfig, 100 | args: ArgumentParser, 101 | ignore_args: list = None) -> DictConfig: 102 | """ 103 | Quick hacks to override default configs 104 | """ 105 | ignore_args = [] if ignore_args is None else ignore_args 106 | 107 | # Dataset 108 | if getattr(args, 'dataset', None): 109 | config.dataset.name = args.dataset 110 | args.run_name += f'-ds={args.dataset}' 111 | 112 | # Optimizer 113 | for arg in ['lr', 'weight_decay']: 114 | if arg not in ignore_args: 115 | argval = getattr(args, arg, None) 116 | if argval is not None: 117 | setattr(config.optimizer, arg, argval) 118 | args.run_name += f'-{_format_arg(arg)}={argval}' 119 | try: 120 | if getattr(args, 'optim', None): 121 | config.optimizer.optim = args.optim 122 | args.run_name += f'-o={args.optim}' 123 | except AttributeError: 124 | pass 125 | 126 | # Scheduler 127 | try: 128 | if getattr(args, 'scheduler', None): 129 | config.lr_scheduler.lr_scheduler_type = args.scheduler 130 | args.run_name += f'-sc={args.scheduler}' 131 | except AttributeError: 132 | pass 133 | 134 | # Dataset 135 | for arg in [a for a in dir(args) if 'dataset_' in a]: 136 | argval = getattr(args, arg, None) 137 | if argval is not None: 138 | setattr(config.dataset.dataset_config, arg[len('dataset_'):], argval) 139 | args.run_name += f'-{_format_arg(arg)}={argval}' 140 | 141 | # Dataloader 142 | for arg in ['batch_size']: # , 'num_workers']: 143 | argval = getattr(args, arg, None) 144 | if argval is not None: 145 | setattr(config.dataloader, arg, argval) 146 | args.run_name += f'-{_format_arg(arg)}={argval}' 147 | 148 | # Trainer 149 | for arg in ['gradient_accumulation_steps', 'num_train_epochs', 150 | 'max_steps', 'max_finetune_steps', 'eval_steps', 151 | 'seed', 'max_eval_batches']: 152 | argval = getattr(args, arg, None) 153 | if argval is not None: 154 | setattr(config.trainer, arg, argval) 155 | if arg in ['max_steps', 'max_finetune_steps', 156 | 'gradient_accumulation_steps', 'num_train_epochs', 'seed']: 157 | args.run_name += f'-{_format_arg(arg)}={argval}' 158 | 159 | # Misc 160 | for arg in ['replicate']: 161 | argval = getattr(args, arg, None) 162 | if argval is not None: 163 | args.run_name += f'-{_format_arg(arg)}={argval}' 164 | 165 | return config 166 | 167 | 168 | def update_model_config_from_args(model_config: DictConfig, 169 | args: ArgumentParser) -> DictConfig: 170 | """ 171 | Override default configs given argparse args 172 | """ 173 | # Overall attention 174 | for arg in ['attention_type', 'learned_kernel', 'tie_qk_kernels', 175 | 'train_qk', 'state_chunk_len', 'no_peft_grad_ckpt', 176 | 'window_size']: 177 | argval = getattr(args, arg, None) 178 | if argval is not None: 179 | setattr(model_config['attention'], arg, argval) 180 | args.run_name += f'-{_format_arg(arg)}={argval}' 181 | else: 182 | try: 183 | getattr(model_config['attention'], arg) 184 | except AttributeError: 185 | setattr(model_config['attention'], arg, None) 186 | 187 | # Learned kernel 188 | for arg in ['lk_skip_connection', 'lk_zero_init', 'lk_normal_init']: 189 | argval = getattr(args, arg, None) 190 | if argval is not None: 191 | setattr(model_config['attention']['learned_kernel_kwargs'], 192 | arg[len('lk_'):], argval) 193 | args.run_name += 
f'-{_format_arg(arg)}={argval}' 194 | 195 | # Pretrained model 196 | if args.pretrained_model_name_or_path is not None: # if specified 197 | pmnop = args.pretrained_model_name_or_path 198 | model_config.model.pretrained_model_name_or_path = pmnop 199 | args.run_name += f'-pmnop={pmnop.split("/")[-1]}' 200 | 201 | return model_config 202 | --------------------------------------------------------------------------------
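As a quick illustration, flatten_config above turns a nested config into a flat dict whose keys join the nesting path with underscores, which is convenient for logging hyperparameters to WandB. A small example (the config values here are made up, only the helper itself comes from the repo):

config = {
    'optimizer': {'optim': 'adamw_torch_fused', 'lr': 1e-2, 'weight_decay': 0.0},
    'lr_scheduler': {'lr_scheduler_type': 'reduce_lr_on_plateau'},
    'trainer': {'gradient_accumulation_steps': 8, 'seed': 42},
}
flat = flatten_config(config, flattened={}, key='')
# flat == {'optimizer_optim': 'adamw_torch_fused', 'optimizer_lr': 0.01,
#          'optimizer_weight_decay': 0.0,
#          'lr_scheduler_lr_scheduler_type': 'reduce_lr_on_plateau',
#          'trainer_gradient_accumulation_steps': 8, 'trainer_seed': 42}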