├── requirements.txt ├── peft_pretraining ├── megatron_dataset │ ├── Makefile │ ├── blendable_dataset.py │ ├── samplers.py │ ├── dataset.py │ └── data_utils.py ├── args_utils.py ├── dataloader.py ├── relora.py └── training_utils.py ├── setup.py ├── CITATION.cff ├── configs ├── llama_9m.json ├── llama_100m.json ├── llama_130m.json ├── llama_1b.json ├── llama_20m.json ├── llama_250m.json ├── llama_35m.json ├── llama_3b.json ├── llama_40m.json ├── llama_60m.json ├── llama_71m.json ├── llama_7b.json ├── llama_250m_50K.json ├── llama_250m_old.json ├── llama_350m.json └── pile_megatron_dataset.yaml ├── .vscode └── launch.json ├── training_configs └── 1B_v1.0.yaml ├── notebooks ├── 02_quick_debugs.ipynb ├── 12_test_relora_init.ipynb ├── 15_debug_dataloading.ipynb ├── 03_scaling_laws_plotting.ipynb ├── 01_peft_pretraining.ipynb ├── 06_svd.ipynb ├── 10_chunking.ipynb ├── 11_test_pythia.ipynb ├── 14_check_pretokenization.ipynb ├── 09_bar_plot.ipynb ├── 05_check_ranks.ipynb └── 16_quantized.ipynb ├── pretokenize.py ├── .gitignore ├── README.dev.md ├── README.md └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | tokenizers 4 | datasets 5 | peft 6 | wandb 7 | loguru 8 | nvitop 9 | matplotlib 10 | pybind11 11 | bitsandbytes 12 | scipy 13 | evaluate 14 | packaging 15 | ninja 16 | -------------------------------------------------------------------------------- /peft_pretraining/megatron_dataset/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color 2 | CPPFLAGS += $(shell python3 -m pybind11 --includes) 3 | LIBNAME = helpers 4 | LIBEXT = $(shell python3-config --extension-suffix) 5 | 6 | default: $(LIBNAME)$(LIBEXT) 7 | 8 | %$(LIBEXT): %.cpp 9 | $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("requirements.txt") as f: 4 | required = f.read().splitlines() 5 | 6 | setup( 7 | name="peft_pretraining", 8 | version="1.0", 9 | description="ReLoRA: Parameter-efficient pre-training", 10 | url="https://github.com/Guitaricet/peft_pretraining", 11 | author="Vlad Lialin", 12 | author_email="vlad.lialin@gmail.com", 13 | license="Apache 2.0", 14 | packages=["peft_pretraining"], 15 | install_requires=required, 16 | ) 17 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "Stack More Layers Differently: High-Rank Training Through Low-Rank Updates" 3 | version: 1.0.0 4 | message: "If you use this software, please cite it as below." 
5 | authors: 6 | - family-names: "Lialin" 7 | given-names: "Vladislav" 8 | - family-names: "Shivagunde" 9 | given-names: "Namrata" 10 | - family-names: "Muckatira" 11 | given-names: "Sherin" 12 | - family-names: "Rumshisky" 13 | given-names: "Anna" 14 | year: 2023 15 | repository-code: "https://github.com/Guitaricet/peft_pretraining" 16 | -------------------------------------------------------------------------------- /configs/llama_9m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 128, 9 | "intermediate_size": 352, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 4, 14 | "num_hidden_layers": 4, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_100m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 640, 9 | "intermediate_size": 1708, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 10, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_130m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 768, 9 | "intermediate_size": 2048, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 12, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_1b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 2048, 9 | "intermediate_size": 5461, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_20m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 256, 9 | "intermediate_size": 688, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 4, 14 | 
"num_hidden_layers": 4, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_250m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 768, 9 | "intermediate_size": 2560, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_35m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 384, 9 | "intermediate_size": 1024, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 6, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_3b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 2560, 9 | "intermediate_size": 6848, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 32, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_40m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 416, 9 | "intermediate_size": 1024, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 8, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_60m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 512, 9 | "intermediate_size": 1376, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 8, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } 
-------------------------------------------------------------------------------- /configs/llama_71m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 512, 9 | "intermediate_size": 1368, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 8, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 4096, 9 | "intermediate_size": 11008, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 2048, 12 | "model_type": "llama", 13 | "num_attention_heads": 32, 14 | "num_hidden_layers": 32, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /configs/llama_250m_50K.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 768, 9 | "intermediate_size": 2560, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 50257 20 | } -------------------------------------------------------------------------------- /configs/llama_250m_old.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 768, 9 | "intermediate_size": 2560, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32000 20 | } -------------------------------------------------------------------------------- /configs/llama_350m.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LLaMAForCausalLM" 4 | ], 5 | "bos_token_id": 0, 6 | "eos_token_id": 1, 7 | "hidden_act": "silu", 8 | "hidden_size": 1024, 9 | "intermediate_size": 2736, 10 | "initializer_range": 0.02, 11 | "max_sequence_length": 1024, 12 | "model_type": "llama", 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "pad_token_id": -1, 16 | "rms_norm_eps": 1e-06, 17 | "transformers_version": "4.28.1", 18 | "use_cache": true, 19 | "vocab_size": 32100 20 | } -------------------------------------------------------------------------------- /.vscode/launch.json: 
-------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Llama-S Relora", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "main.py", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "args": [ 15 | "--model_config", "configs/llama_9m.json", 16 | "--use_peft", 17 | "--lora_r", "32", 18 | "--relora", "100", 19 | "--reset_optimizer_on_relora", "False", 20 | "--device", "cuda:0", 21 | "--lr", "0.0005", 22 | "--batch_size", "240", 23 | "--dtype", "bfloat16", 24 | "--tags", "vscode_debugger", 25 | ] 26 | } 27 | ] 28 | } -------------------------------------------------------------------------------- /training_configs/1B_v1.0.yaml: -------------------------------------------------------------------------------- 1 | # dataset 2 | megatron_dataset_config: configs/pile_megatron_dataset.yaml 3 | max_length: 2048 4 | workers: 8 5 | 6 | # model 7 | model_name_or_path: EleutherAI/pythia-1b 8 | model_revision: step1000 9 | 10 | # saving 11 | save_dir: checkpoints/relora_1b_Aug5_2023_run2 12 | autoresume: true 13 | 14 | # ReLoRA 15 | use_peft: true 16 | force_keep_original: true 17 | lora_r: 128 18 | relora: 1000 19 | restart_warmup_steps: 100 20 | reset_optimizer_on_relora: false 21 | optimizer_magnitude_pruning: 0.8 22 | 23 | # Optimization 24 | optimizer: adam 25 | batch_size: 8 26 | total_batch_size: 1024 27 | lr: 4e-4 28 | adam_beta1: 0.9 29 | adam_beta2: 0.95 30 | weight_decay: 0.01 31 | scheduler: cosine_restarts 32 | warmup_steps: 500 # used to be 13_000, but reduced it to comply with the scheduler 33 | num_training_steps: 130_000 # used to be 133_000, but it's an ugly number 34 | eval_every: 500 35 | save_every: 500 36 | 37 | # Misc 38 | dtype: bfloat16 39 | distributed_type: ddp 40 | tags: relora1b_debug 41 | comment: "Checking if ReLoRA 1B loss is similar to regular training loss overnight" 42 | -------------------------------------------------------------------------------- /notebooks/02_quick_debugs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "100.11712" 12 | ] 13 | }, 14 | "execution_count": 4, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "import torch\n", 21 | "\n", 22 | "import transformers\n", 23 | "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n", 24 | "\n", 25 | "config = AutoConfig.from_pretrained(\"../configs/llama_100m.json\")\n", 26 | "model = AutoModelForCausalLM.from_config(config)\n", 27 | "\n", 28 | "# n params\n", 29 | "sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [] 38 | } 39 | ], 40 | "metadata": { 41 | "kernelspec": { 42 | "display_name": "base", 43 | "language": "python", 44 | "name": "python3" 45 | }, 46 | "language_info": { 47 | "codemirror_mode": { 48 | "name": "ipython", 49 | "version": 3 50 | }, 51 | "file_extension": ".py", 52 | "mimetype": "text/x-python", 53 | "name": "python", 54 | "nbconvert_exporter": 
"python", 55 | "pygments_lexer": "ipython3", 56 | "version": "3.10.12" 57 | }, 58 | "orig_nbformat": 4 59 | }, 60 | "nbformat": 4, 61 | "nbformat_minor": 2 62 | } 63 | -------------------------------------------------------------------------------- /peft_pretraining/megatron_dataset/blendable_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, EleutherAI 2 | # This file is based on code by the authors denoted below and has been modified from its original version. 3 | # 4 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """Blendable dataset.""" 19 | 20 | import time 21 | 22 | import numpy as np 23 | import torch 24 | import torch.distributed as dist 25 | 26 | 27 | class BlendableDataset(torch.utils.data.Dataset): 28 | def __init__(self, datasets, weights): 29 | self.datasets = datasets 30 | num_datasets = len(datasets) 31 | assert num_datasets == len(weights) 32 | 33 | self.size = 0 34 | for dataset in self.datasets: 35 | self.size += len(dataset) 36 | 37 | # Normalize weights. 38 | weights = np.array(weights, dtype=np.float64) 39 | sum_weights = np.sum(weights) 40 | assert sum_weights > 0.0 41 | weights /= sum_weights 42 | 43 | # Build indices. 44 | start_time = time.time() 45 | assert num_datasets < 255 46 | self.dataset_index = np.zeros(self.size, dtype=np.uint8) 47 | self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) 48 | 49 | from peft_pretraining.megatron_dataset import helpers 50 | 51 | helpers.build_blending_indices( 52 | self.dataset_index, 53 | self.dataset_sample_index, 54 | weights, 55 | num_datasets, 56 | self.size, 57 | False, # verbose 58 | ) 59 | 60 | rank = dist.get_rank() if dist.is_initialized() else 0 61 | _time_delta = time.time() - start_time 62 | if _time_delta > 5.0: 63 | print(f"> RANK {rank} elapsed time for building blendable dataset indices: " 64 | f"{time.time() - start_time:.2f} (sec)") 65 | 66 | def __len__(self): 67 | return self.size 68 | 69 | def __getitem__(self, idx): 70 | try: 71 | dataset_idx = self.dataset_index[idx] 72 | sample_idx = self.dataset_sample_index[idx] 73 | return self.datasets[dataset_idx][sample_idx] 74 | except IndexError: 75 | new_idx = idx % len(self) 76 | print( 77 | f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" 78 | ) 79 | return self[new_idx] 80 | -------------------------------------------------------------------------------- /configs/pile_megatron_dataset.yaml: -------------------------------------------------------------------------------- 1 | { 2 | # NOTE: this config does not support using - in the key names, 3 | # because we load it from yaml and then feed to NeoXArgs.from_dict(). 
4 | # Use _ instead of - in the key names 5 | 6 | "pipe_parallel_size": 1, 7 | "model_parallel_size": 1, 8 | 9 | # path to dataset .bin and .idx file (path should be filenames without `.bin` or `.idx`) 10 | "train_data_paths": ["/fsx/pile/pile_20B_tokenizer_text_document"], 11 | "valid_data_paths": ["/fsx/pile/pile_20B_tokenizer_text_document"], 12 | "test_data_paths": ["/fsx/pile/pile_20B_tokenizer_text_document"], 13 | 14 | "tokenizer_type": "HFTokenizer", 15 | "vocab_file": "configs/pythia_tokenizer.json", 16 | 17 | "train_micro_batch_size_per_gpu": "", 18 | "train_batch_size": "", 19 | "num_workers": 8, 20 | 21 | "seq_length": 2048, 22 | "train_iters": 143000, 23 | "data_impl": "mmap", 24 | 25 | ############################################################################################ 26 | # everything below is ignored by the training script, only needed to create neox_args object 27 | ############################################################################################ 28 | 29 | # model settings 30 | "num_layers": 12, 31 | "hidden_size": 768, 32 | "num_attention_heads": 12, 33 | "max_position_embeddings": 2048, 34 | "pos_emb": "rotary", 35 | "rotary_pct": 0.25, 36 | "no_weight_tying": true, 37 | "gpt_j_residual": true, 38 | "output_layer_parallelism": "column", 39 | 40 | "scaled_upper_triang_masked_softmax_fusion": true, 41 | "bias_gelu_fusion": true, 42 | 43 | # init methods 44 | "init_method": "small_init", 45 | "output_layer_init_method": "wang_init", 46 | 47 | "optimizer": { 48 | "type": "Adam", 49 | "params": { 50 | "lr": 0.0006, 51 | "betas": [0.9, 0.95], 52 | "eps": 1.0e-8, 53 | } 54 | }, 55 | "min_lr": 0.00006, 56 | 57 | "zero_optimization": { 58 | "stage": 1, 59 | "allgather_partitions": True, 60 | "allgather_bucket_size": 500000000, 61 | "overlap_comm": True, 62 | "reduce_scatter": True, 63 | "reduce_bucket_size": 500000000, 64 | "contiguous_gradients": True, 65 | "cpu_offload": False 66 | }, 67 | 68 | # activation checkpointing 69 | "checkpoint_activations": true, 70 | "checkpoint_num_layers": 1, 71 | "partition_activations": true, 72 | "synchronize_each_layer": true, 73 | 74 | # regularization 75 | "gradient_clipping": 1.0, 76 | "weight_decay": 0.1, 77 | "hidden_dropout": 0, 78 | "attention_dropout": 0, 79 | 80 | # precision settings 81 | "fp16": { 82 | "fp16": true, 83 | "enabled": true, 84 | "loss_scale": 0, 85 | "loss_scale_window": 1000, 86 | "initial_scale_power": 12, 87 | "hysteresis": 2, 88 | "min_loss_scale": 1, 89 | }, 90 | 91 | "lr_decay_iters": 143000, 92 | "distributed_backend": "nccl", 93 | "lr_decay_style": "cosine", 94 | "warmup": 0.01, 95 | # "save_interval": 250, 96 | # "eval_interval": 40000, 97 | # "eval_iters": 10, 98 | } -------------------------------------------------------------------------------- /notebooks/12_test_relora_init.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, sys\n", 10 | "sys.path.append(\"..\")\n", 11 | "\n", 12 | "import torch\n", 13 | "\n", 14 | "from transformers import AutoTokenizer, GPTNeoXForCausalLM, AutoModelForCausalLM, AutoConfig\n", 15 | "from peft_pretraining.relora import ReLoRaModel\n", 16 | "# from optimum.bettertransformer import BetterTransformer" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 24, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "model_name = 
\"EleutherAI/pythia-1.4b\"\n", 26 | "# model_name = \"gpt2\"\n", 27 | "# config = AutoConfig.from_pretrained(model_name)\n", 28 | "# orig_model = AutoModelForCausalLM.from_config(config)\n", 29 | "orig_model = AutoModelForCausalLM.from_pretrained(model_name)\n", 30 | "# orig_model.eval()" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 25, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 40 | "input_ids = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\").input_ids" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 30, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "relora_model = ReLoRaModel(\n", 50 | " orig_model,\n", 51 | " r=128,\n", 52 | " lora_alpha=32,\n", 53 | " lora_dropout=0.1,\n", 54 | " target_modules=[\"attn\", \"attention\", \"mlp\"],\n", 55 | " trainable_scaling=False,\n", 56 | " keep_original_weights=True,\n", 57 | " lora_only=False,\n", 58 | ")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 31, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "data": { 68 | "text/plain": [ 69 | "tensor(4.3360, grad_fn=)" 70 | ] 71 | }, 72 | "execution_count": 31, 73 | "metadata": {}, 74 | "output_type": "execute_result" 75 | } 76 | ], 77 | "source": [ 78 | "out2 = relora_model(input_ids, labels=input_ids)\n", 79 | "out2.loss" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "peft_pretraining_shala", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.10.11" 107 | }, 108 | "orig_nbformat": 4 109 | }, 110 | "nbformat": 4, 111 | "nbformat_minor": 2 112 | } 113 | -------------------------------------------------------------------------------- /notebooks/15_debug_dataloading.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "0fb0896d-c5ea-4a91-a5af-2206e057380c", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from peft_pretraining.megatron_dataset import data_utils\n", 13 | "from peft_pretraining.megatron_dataset.arguments import NeoXArgs" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "80ca8e2b-8dff-4c03-8228-3de497ff2ab3", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "args = NeoXArgs.from_ymls([\"../configs/pile_megatron_dataset.yaml\"])" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "id": "5dde67a5-19a6-4c34-b807-dd2d177f9b6f", 32 | "metadata": { 33 | "tags": [] 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "train_dataloader, _, _ = data_utils.build_train_valid_test_dataloaders(neox_args=args)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 8, 43 | "id": "f78bd8a1-32d9-469a-b83c-64cafd1920c4", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "train_iterator = iter(train_dataloader)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 9, 
53 | "id": "72139140-b9cb-45ee-98d8-ec5d94d531f7", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "batch = next(train_iterator)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "7a4f265b-3e8f-4f00-81ce-916477fa4668", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "type(batch), batch.keys()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 11, 73 | "id": "74f06004-4311-43bd-bf1f-bdb164facb99", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "torch.Size([1024, 2049])" 80 | ] 81 | }, 82 | "execution_count": 11, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "batch[\"input_ids\"].shape" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "cf16f858-dd7e-4ad2-8739-c0d830a00eb8", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "for i, batch in enumerate(train_dataloader):\n", 99 | " if i > 1: break\n", 100 | " print(batch)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "b66a7fc4-6a67-49ae-9b09-b25a8430d6e9", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "for i, batch in enumerate(train_dataloader):\n", 111 | " if i > 1: break\n", 112 | " print(batch)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "1d773d3d-9bc2-4b74-883c-48d3f4197164", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "Python 3 (ipykernel)", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.10.12" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 5 145 | } 146 | -------------------------------------------------------------------------------- /pretokenize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download and pre-tokenize a huggingface dataset. 3 | Based on: https://github.com/conceptofmind/PaLM/blob/main/palm/build_dataloaders.py 4 | 5 | Usage: 6 | python pretokenize.py --save_dir preprocessed_data --tokenizer EleutherAI/gpt-neox-20b --dataset openwebtext --text_field text --sequence_length 2048 7 | """ 8 | import os 9 | import time 10 | import json 11 | import argparse 12 | import multiprocessing 13 | 14 | from loguru import logger 15 | from datasets import load_dataset, DatasetDict, Dataset 16 | from transformers import AutoTokenizer 17 | 18 | 19 | from peft_pretraining.dataloader import tokenize_and_chunk 20 | 21 | 22 | def parse_args(args=None): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--tokenizer", type=str, required=True, help="HuggingFace tokenizer name") 25 | parser.add_argument("--dataset", type=str, required=True, help="HuggingFace dataset name. E.g., wikitext") 26 | parser.add_argument("--dataset_config", type=str, default=None, help="HuggingFace dataset config name. 
E.g., wikitext-2-v1") 27 | parser.add_argument("--text_field", type=str, default="text", help="Name of the text field in the dataset") 28 | parser.add_argument("--sequence_length", type=int, default=2048, help="Sequence length") 29 | parser.add_argument("--num_cpu", type=int, default=multiprocessing.cpu_count(), help="Number of CPU cores") 30 | parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the pre-tokenized dataset") 31 | 32 | parser.add_argument("--take", type=int, default=None, help="Number of examples to take from the dataset") 33 | args = parser.parse_args(args) 34 | 35 | return args 36 | 37 | 38 | def main(args): 39 | print("In main") 40 | logger.info("*" * 40) 41 | logger.info(f"Starting script with the arguments") 42 | for k, v in vars(args).items(): 43 | logger.info(f"{k:30} {v}") 44 | logger.info("*" * 40) 45 | 46 | _tokenizer_name_for_save = args.tokenizer.replace("/", "_") 47 | save_path = os.path.join(args.save_dir, f"{args.dataset}_{_tokenizer_name_for_save}_{args.sequence_length}") 48 | if args.dataset_config is not None: 49 | save_path = os.path.join(args.save_dir, f"{args.dataset}_{args.dataset_config}_{_tokenizer_name_for_save}_{args.sequence_length}") 50 | 51 | if os.path.exists(save_path): 52 | raise ValueError(f"Path {save_path} already exists") 53 | 54 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 55 | logger.info(f"Loading the dataset in streaming mode: {args.take is not None}") 56 | dataset = load_dataset(args.dataset, args.dataset_config, streaming=args.take is not None) 57 | 58 | if args.take is not None: 59 | logger.info(f"Taking {args.take} examples from the dataset") 60 | def take(ds, n): 61 | return Dataset.from_generator(lambda: (yield from ds.take(n))) 62 | dataset_dict = {k: take(v, args.take) for k, v in dataset.items()} 63 | dataset = DatasetDict(dataset_dict) 64 | 65 | logger.info("Tokenizing and chunking the dataset") 66 | _time = time.time() 67 | dataset = tokenize_and_chunk( 68 | tokenizer=tokenizer, 69 | dataset=dataset, 70 | text_field=args.text_field, 71 | sequence_length=args.sequence_length, 72 | num_cpu=args.num_cpu, 73 | ) 74 | _hours = (time.time() - _time) / 3600 75 | logger.info(f"Tokenization and chunking took {_hours:.2f} hours") 76 | 77 | dataset.save_to_disk(save_path) 78 | logger.info(f"Saved the dataset to {save_path}") 79 | 80 | with open(os.path.join(save_path, "args.json"), "w") as f: 81 | json.dump(vars(args), f, indent=4) 82 | print("In main") 83 | 84 | 85 | if __name__ == "__main__": 86 | print("Starting the script") 87 | args = parse_args() 88 | main(args) 89 | 90 | 
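As a concrete example, the following invocation (hypothetical, but consistent with the argument parser above) would produce the `preprocessed_data/wikitext_wikitext-2-v1_t5-base_512` directory that the commands in `README.dev.md` point at:

```bash
# Hypothetical example run; per the save_path logic above, this produces
# preprocessed_data/wikitext_wikitext-2-v1_t5-base_512
python pretokenize.py \
    --save_dir preprocessed_data \
    --tokenizer t5-base \
    --dataset wikitext \
    --dataset_config wikitext-2-v1 \
    --text_field text \
    --sequence_length 512
```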
-------------------------------------------------------------------------------- /peft_pretraining/args_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import yaml 3 | from datetime import datetime 4 | 5 | from loguru import logger 6 | 7 | 8 | def check_args_torchrun_main(args): 9 | if args.training_config is not None: 10 | logger.info(f"Yaml config provided for the run. The file {args.training_config} is used to provide all the parameters.") 11 | if len(sys.argv) > 3: 12 | logger.error(f"argv length is {len(sys.argv)}") 13 | raise RuntimeError( 14 | "You provided both a yaml config and command line arguments. " 15 | "Please use only one of the two options." 16 | ) 17 | with open(args.training_config) as f: 18 | training_config = yaml.safe_load(f) 19 | for k, v in training_config.items(): 20 | if k == "lr": v = float(v) 21 | setattr(args, k, v) 22 | 23 | if (args.dataset_path is None) == (args.megatron_dataset_config is None): 24 | raise ValueError("Either --dataset_path or --megatron_dataset_config must be specified and not both\n" 25 | f"Got {args.dataset_path=} and {args.megatron_dataset_config=}") 26 | 27 | if args.megatron_dataset_config is not None: 28 | if not os.path.exists(args.megatron_dataset_config): 29 | raise ValueError(f"{args.megatron_dataset_config=} does not exist") 30 | 31 | if args.batch_size is None: 32 | raise ValueError("batch_size must be specified") 33 | 34 | if args.tags is not None: 35 | args.tags = args.tags.split(",") 36 | 37 | if not args.use_peft: 38 | # just for more clear hparam logging to wandb 39 | args.relora = None 40 | args.lora_r = None 41 | args.force_keep_original = False 42 | 43 | if args.total_batch_size is None: 44 | args.gradient_accumulation = args.gradient_accumulation or 1 45 | args.total_batch_size = args.batch_size * args.gradient_accumulation 46 | 47 | assert args.total_batch_size % args.batch_size == 0, "total_batch_size must be divisible by batch_size" 48 | 49 | if args.max_train_tokens is not None: 50 | args.num_training_steps = args.max_train_tokens // args.total_batch_size 51 | logger.info(f"Training for {args.num_training_steps} update steps") 52 | 53 | if args.warmed_up_model is not None: 54 | assert os.path.exists(args.warmed_up_model), f"{args.warmed_up_model=} does not exist" 55 | 56 | if args.dtype in ["fp16", "float16"]: 57 | raise NotImplementedError("fp16 is not supported in torchrun_main.py. Use deepspeed_main.py instead (but it seems to have bugs)") 58 | 59 | if (int(args.reset_optimizer_on_relora) + 60 | int(bool(args.optimizer_random_pruning)) + 61 | int(bool(args.optimizer_magnitude_pruning)) 62 | ) > 1: 63 | raise ValueError("reset_optimizer_on_relora, optimizer_random_pruning, and optimizer_magnitude_pruning are mutually exclusive") 64 | 65 | if args.relora and not args.use_peft: 66 | logger.warning("--relora assumes --use_peft. 
Setting --use_peft=True") 67 | args.use_peft = True 68 | 69 | assert 0 <= args.optimizer_random_pruning < 1, "--optimizer_random_pruning must be between 0 and 1" 70 | assert 0 <= args.optimizer_magnitude_pruning < 1, "--optimizer_magnitude_pruning must be between 0 and 1" 71 | 72 | 73 | if args.distributed_type == "fsdp" and args.weight_decay > 0: 74 | raise ValueError("FSDP does not support weight decay yet.") 75 | 76 | if args.distributed_type == "fsdp" and "zero" in args.optimizer: 77 | raise ValueError("FSDP does zero-optimization by default, do not specify optimizer as zero optimizer.") 78 | 79 | if args.skip_batches is not None: 80 | args.skip_batches = map(int, args.skip_batches.split(",")) 81 | args.skip_batches = set(args.skip_batches) 82 | logger.info(f"Skipping batches {args.skip_batches}") 83 | 84 | args.skip_batches = args.skip_batches or set() 85 | 86 | return args 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | profiler_logs 2 | **/*.sbatch 3 | checkpoints 4 | **/wandb/** 5 | wandb 6 | ignore 7 | experimental_data 8 | fine_tuning_results* 9 | log 10 | notebooks/*.pdf 11 | notebooks/*.png 12 | preprocessed_data 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | cover/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | .pybuilder/ 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | # For a library or package, you might want to ignore these files since the code is 100 | # intended to run in multiple environments; otherwise, check them in: 101 | # .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 
114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/#use-with-ide 123 | .pdm.toml 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 173 | #.idea/ 174 | -------------------------------------------------------------------------------- /README.dev.md: -------------------------------------------------------------------------------- 1 | Some scripts to check that the most common training regimes work. 
2 | 3 | ``` 4 | torchrun --nproc-per-node 2 torchrun_main.py \ 5 | --dataset_path preprocessed_data/wikitext_wikitext-2-v1_EleutherAI_pythia-1.4b_512 \ 6 | --model_name_or_path EleutherAI/pythia-1.4b \ 7 | --use_peft \ 8 | --relora 10 \ 9 | --model_revision step1000 \ 10 | --batch_size 4 \ 11 | --total_batch_size 96 \ 12 | --lr 5e-4 \ 13 | --max_length 512 \ 14 | --eval_every 20 \ 15 | --save_every 20 \ 16 | --num_training_steps 40 \ 17 | --distributed_type ddp \ 18 | --optimizer adam_zero \ 19 | --tags debug 20 | 21 | 22 | torchrun --nproc-per-node 2 torchrun_main.py \ 23 | --dataset_path preprocessed_data/wikitext_wikitext-2-v1_EleutherAI_pythia-1.4b_512 \ 24 | --model_name_or_path EleutherAI/pythia-1.4b \ 25 | --model_revision step1000 \ 26 | --batch_size 6 \ 27 | --total_batch_size 96 \ 28 | --lr 5e-4 \ 29 | --max_length 512 \ 30 | --eval_every 2 \ 31 | --save_every 10 \ 32 | --num_training_steps 20 \ 33 | --distributed_type ddp \ 34 | --tags debug,fsdp_debug 35 | 36 | 37 | torchrun --nproc-per-node 2 torchrun_main.py \ 38 | --dataset_path preprocessed_data/wikitext_wikitext-2-v1_t5-base_512 \ 39 | --model_config configs/llama_250m.json \ 40 | --batch_size 24 \ 41 | --total_batch_size 96 \ 42 | --lr 5e-4 \ 43 | --max_length 512 \ 44 | --eval_every 2 \ 45 | --save_every 10 \ 46 | --num_training_steps 20 \ 47 | --distributed_type ddp \ 48 | --tags debug,fsdp_debug 49 | 50 | 51 | torchrun --nproc-per-node 2 torchrun_main.py \ 52 | --dataset_path preprocessed_data/wikitext_wikitext-2-v1_t5-base_512 \ 53 | --model_config configs/llama_250m.json \ 54 | --batch_size 24 \ 55 | --total_batch_size 96 \ 56 | --lr 5e-4 \ 57 | --max_length 512 \ 58 | --eval_every 2 \ 59 | --save_every 10 \ 60 | --num_training_steps 20 \ 61 | --distributed_type fsdp \ 62 | --tags debug,fsdp_debug 63 | 64 | 65 | torchrun --nproc-per-node 2 torchrun_main.py \ 66 | --dataset_path preprocessed_data/wikitext_wikitext-2-v1_gpt2_512 \ 67 | --model_config configs/llama_250m_50K.json \ 68 | --batch_size 24 \ 69 | --total_batch_size 96 \ 70 | --lr 5e-4 \ 71 | --max_length 512 \ 72 | --eval_every 2 \ 73 | --save_every 10 \ 74 | --num_training_steps 20 \ 75 | --distributed_type ddp \ 76 | --dtype float32 \ 77 | --tags debug,fsdp_debug 78 | 79 | 80 | torchrun --nproc-per-node 2 torchrun_main.py \ 81 | --model_config configs/llama_250m.json \ 82 | --batch_size 24 \ 83 | --total_batch_size 96 \ 84 | --lr 5e-4 \ 85 | --max_length 512 \ 86 | --eval_every 2 \ 87 | --save_every 10 \ 88 | --num_training_steps 20000 \ 89 | --distributed_type fsdp \ 90 | --tags debug,fsdp_debug 91 | 92 | 93 | torchrun --nproc-per-node 2 torchrun_main.py \ 94 | --model_config configs/llama_250m.json \ 95 | --batch_size 24 \ 96 | --total_batch_size 96 \ 97 | --lr 5e-4 \ 98 | --max_length 512 \ 99 | --eval_every 2 \ 100 | --save_every 10 \ 101 | --num_training_steps 20000 \ 102 | --distributed_type fsdp \ 103 | --tags debug,fsdp_debug 104 | 105 | 106 | torchrun --nproc-per-node 2 torchrun_main.py \ 107 | --model_config configs/llama_250m.json \ 108 | --batch_size 24 \ 109 | --total_batch_size 96 \ 110 | --lr 1e-3 \ 111 | --max_length 512 \ 112 | --use_peft \ 113 | --relora 10 \ 114 | --cycle_length 10 \ 115 | --restart_warmup_steps 5 \ 116 | --scheduler cosine_restarts \ 117 | --warmup_steps 5 \ 118 | --reset_optimizer_on_relora False \ 119 | --optimizer_magnitude_pruning 0.9 \ 120 | --num_training_steps 20000 \ 121 | --save_every 5000 \ 122 | --eval_every 5000 \ 123 | --warmed_up_model checkpoints/llama_250m-2023-06-09-11-29-56/model_5000 \ 124 | 
--distributed_type fsdp \ 125 | --tags debug,fsdp_debug 126 | 127 | 128 | torchrun --nproc-per-node 2 torchrun_main.py \ 129 | --model_config configs/llama_250m.json \ 130 | --batch_size 24 \ 131 | --total_batch_size 96 \ 132 | --lr 1e-3 \ 133 | --max_length 512 \ 134 | --use_peft \ 135 | --relora 10 \ 136 | --cycle_length 10 \ 137 | --restart_warmup_steps 5 \ 138 | --scheduler cosine_restarts \ 139 | --warmup_steps 5 \ 140 | --reset_optimizer_on_relora False \ 141 | --optimizer_magnitude_pruning 0.9 \ 142 | --num_training_steps 20000 \ 143 | --save_every 5000 \ 144 | --eval_every 5000 \ 145 | --warmed_up_model checkpoints/llama_250m-2023-06-09-11-29-56/model_5000 \ 146 | --distributed_type fsdp \ 147 | --tags debug,fsdp_debug 148 | 149 | ``` 150 | -------------------------------------------------------------------------------- /notebooks/03_scaling_laws_plotting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pandas as pd\n", 11 | "\n", 12 | "import matplotlib.pyplot as plt" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "data = pd.read_csv(\"../experimental_data/wandb_export_2023-05-08T13_46_28.554-04_00.csv\")\n", 22 | "data = data[data[\"Name\"] != \"radiant-wind-116\"] # exploded\n", 23 | "data" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "full_models = data[~data[\"use_peft\"]]\n", 33 | "peft_models = data[data[\"use_peft\"]]\n", 34 | "\n", 35 | "full_models_x = full_models[\"trainable_params_M\"]\n", 36 | "full_models_y = full_models[\"loss\"]\n", 37 | "\n", 38 | "peft_models_x = peft_models[\"trainable_params_M\"]\n", 39 | "peft_models_y = peft_models[\"loss\"]\n", 40 | "\n", 41 | "# figure out the scaling law\n", 42 | "# full_models_y = a * full_models_x ** b\n", 43 | "# peft_models_y = c * peft_models_x ** d\n", 44 | "from scipy.optimize import curve_fit\n", 45 | "\n", 46 | "def func(x, a, b):\n", 47 | " return a * x ** b\n", 48 | "\n", 49 | "full_popt, full_pcov = curve_fit(func, full_models_x, full_models_y)\n", 50 | "print(f\"Full Models: {full_popt}\")\n", 51 | "\n", 52 | "peft_popt, peft_pcov = curve_fit(func, peft_models_x, peft_models_y)\n", 53 | "print(f\"PEFT Models: {peft_popt}\")\n", 54 | "\n", 55 | "plt.figure(figsize=(5, 5), dpi=150)\n", 56 | "plt.scatter(full_models_x, full_models_y, label=\"Full Models\")\n", 57 | "plt.scatter(peft_models_x, peft_models_y, label=\"PEFT Models\")\n", 58 | "plt.xlabel(\"Trainable Parameters (M)\")\n", 59 | "plt.ylabel(\"Loss\")\n", 60 | "plt.legend()\n", 61 | "plt.xscale(\"log\")\n", 62 | "plt.yscale(\"log\")\n", 63 | "\n", 64 | "# plot the scaling law\n", 65 | "x = np.linspace(5, 150, 100)\n", 66 | "plt.plot(x, func(x, *full_popt), label=\"Full Models\")\n", 67 | "plt.plot(x, func(x, *peft_popt), label=\"PEFT Models\")\n", 68 | "plt.legend()\n", 69 | "\n", 70 | "plt.show()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# perform leave-one-out curve fitting to estimate confidence intervals\n", 80 | "\n", 81 | "full_models_coefficients = []\n", 82 | "for i in range(1, len(full_models_x)):\n", 83 | " x = full_models_x.tolist().copy()\n", 84 | " 
y = full_models_y.tolist().copy()\n", 85 | "\n", 86 | " x = x[:i] + x[i+1:]\n", 87 | " y = y[:i] + y[i+1:]\n", 88 | "\n", 89 | " assert len(x) == len(y) == 4\n", 90 | " popt, pcov = curve_fit(func, x, y)\n", 91 | " full_models_coefficients.append(popt)\n", 92 | "\n", 93 | "peft_models_coefficients = []\n", 94 | "for i in range(1, len(peft_models_x)):\n", 95 | " x = peft_models_x.tolist().copy()\n", 96 | " y = peft_models_y.tolist().copy()\n", 97 | "\n", 98 | " x = x[:i] + x[i+1:]\n", 99 | " y = y[:i] + y[i+1:]\n", 100 | "\n", 101 | " assert len(x) == len(y) == 4\n", 102 | " popt, pcov = curve_fit(func, x, y)\n", 103 | " peft_models_coefficients.append(popt)\n", 104 | "\n", 105 | "full_models_coefficients = np.array(full_models_coefficients)\n", 106 | "peft_models_coefficients = np.array(peft_models_coefficients)\n", 107 | "\n", 108 | "full_models_std = np.std(full_models_coefficients, axis=0)\n", 109 | "peft_models_stds = np.std(peft_models_coefficients, axis=0)\n", 110 | "\n", 111 | "full_models_mean = np.mean(full_models_coefficients, axis=0)\n", 112 | "peft_models_mean = np.mean(peft_models_coefficients, axis=0)\n", 113 | "\n", 114 | "print(f\"Full Models: {full_models_mean} +/- {full_models_std}\")\n", 115 | "print(f\"PEFT Models: {peft_models_mean} +/- {peft_models_stds}\")" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "base", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.10.9" 136 | }, 137 | "orig_nbformat": 4 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 2 141 | } 142 | -------------------------------------------------------------------------------- /peft_pretraining/dataloader.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import multiprocessing 3 | from itertools import chain 4 | 5 | import torch 6 | from torch.utils.data import BatchSampler, DataLoader, IterableDataset, get_worker_info 7 | from transformers import AutoTokenizer, default_data_collator 8 | from datasets import Dataset 9 | 10 | from loguru import logger 11 | 12 | 13 | class PreprocessedIterableDataset(IterableDataset): 14 | def __init__(self, data, tokenizer, batch_size, max_length): 15 | super().__init__() 16 | self.data = data 17 | self.tokenizer = tokenizer 18 | self.batch_size = batch_size 19 | self.max_length = max_length 20 | 21 | def __iter__(self): 22 | worker_info = get_worker_info() 23 | if worker_info is None: 24 | # If no worker_info is provided, we are not using DataLoader workers, so yield all data 25 | iter_data = iter(self.data) 26 | else: 27 | # If using DataLoader workers, yield a subset of the data for this worker 28 | worker_id = worker_info.id 29 | num_workers = worker_info.num_workers 30 | iter_data = itertools.islice(self.data, worker_id, None, num_workers) 31 | 32 | batch = [] 33 | for example in iter_data: 34 | tokenized_example = self.tokenizer( 35 | example["text"], 36 | max_length=self.max_length, 37 | truncation=True, 38 | padding="max_length", 39 | return_tensors="pt", 40 | ) 41 | batch.append(tokenized_example) 42 | 43 | if len(batch) == self.batch_size: 44 | yield self._format_batch(batch) 45 | batch = [] 46 | 47 | if batch: 48 | yield self._format_batch(batch) 49 | 50 | 
def _format_batch(self, batch): 51 | input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch]) 52 | attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch]) 53 | 54 | return {"input_ids": input_ids, "attention_mask": attention_mask} 55 | 56 | 57 | def tokenize_and_chunk( 58 | tokenizer: AutoTokenizer, 59 | dataset: Dataset, 60 | text_field: str, 61 | sequence_length: int, 62 | num_cpu: int = multiprocessing.cpu_count(), 63 | ): 64 | """ 65 | Tokenize a text dataset and concatenate it into fixed-length chunks. 66 | 67 | This function performs the following steps: 68 | 1. Tokenize the `text_field` column of the dataset with the provided tokenizer, 69 | adding the end-of-sentence token to each text. 70 | 2. Concatenate all tokenized texts and split them into chunks of 71 | `sequence_length` tokens, dropping the small remainder. 72 | 73 | Returns: 74 | Dataset: The processed dataset ready for training. 75 | """ 76 | extra_map_kwargs = {"num_proc": num_cpu} # iterable dataset does not support workers in map 77 | if isinstance(dataset, IterableDataset): 78 | extra_map_kwargs = {} 79 | 80 | _len_pre = len(dataset) 81 | # check that text_field is in dataset 82 | tokenized_dataset = dataset.map( 83 | lambda example: tokenizer([t + tokenizer.eos_token for t in example[text_field]]), 84 | batched=True, 85 | remove_columns=[text_field], 86 | **extra_map_kwargs, 87 | ) 88 | assert "input_ids" in tokenized_dataset["train"].features 89 | assert len(tokenized_dataset["train"]) > 0 90 | logger.info(f"Tokenization finished") 91 | logger.info(f"\n{tokenized_dataset}") 92 | assert len(tokenized_dataset) == _len_pre 93 | 94 | block_size = sequence_length 95 | 96 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 97 | def group_texts(examples): 98 | # Concatenate all texts. 99 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 100 | 101 | total_length = len(concatenated_examples["input_ids"]) 102 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 103 | # customize this part to your needs. 104 | if total_length >= block_size: 105 | total_length = (total_length // block_size) * block_size 106 | # Split by chunks of max_len. 107 | result = { 108 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 109 | for k, t in concatenated_examples.items() 110 | if k != "attention_mask" # we never pad for LM, so it's best to minimize the dataset storage 111 | } 112 | return result 113 | 114 | remove_columns = ["attention_mask"] 115 | train_dataset = tokenized_dataset.map( 116 | group_texts, 117 | batched=True, 118 | remove_columns=remove_columns, 119 | **extra_map_kwargs, 120 | ) 121 | logger.info(f"Chunking finished") 122 | logger.info(f"\n{train_dataset}") 123 | 124 | return train_dataset 125 | 126 | 
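# Usage sketch for tokenize_and_chunk (hypothetical values, mirroring how
# pretokenize.py calls it; added for illustration, not part of the original file):
#
#     from datasets import load_dataset
#     from transformers import AutoTokenizer
#
#     dataset = load_dataset("wikitext", "wikitext-2-v1")
#     tokenizer = AutoTokenizer.from_pretrained("t5-base")
#     chunked = tokenize_and_chunk(
#         tokenizer=tokenizer,
#         dataset=dataset,
#         text_field="text",
#         sequence_length=512,
#     )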
127 | # from https://github.com/huggingface/accelerate/blob/8514c35192ac9762920f1ab052e5cea4c0e46eeb/src/accelerate/data_loader.py#L816 128 | class SkipBatchSampler(BatchSampler): 129 | """ 130 | A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`. 131 | """ 132 | 133 | def __init__(self, batch_sampler, skip_batches=0): 134 | self.batch_sampler = batch_sampler 135 | self.skip_batches = skip_batches 136 | 137 | def __iter__(self): 138 | for index, samples in enumerate(self.batch_sampler): 139 | if index >= self.skip_batches: 140 | yield samples 141 | 142 | @property 143 | def total_length(self): 144 | return len(self.batch_sampler) 145 | 146 | def __len__(self): 147 | return len(self.batch_sampler) - self.skip_batches 148 | 149 | 150 | class SkipDataLoader(DataLoader): 151 | """ 152 | Subclass of a PyTorch `DataLoader` that will skip the first batches. 153 | 154 | Args: 155 | dataset (`torch.utils.data.dataset.Dataset`): 156 | The dataset to use to build this dataloader. 157 | skip_batches (`int`, *optional*, defaults to 0): 158 | The number of batches to skip at the beginning. 159 | kwargs: 160 | All other keyword arguments to pass to the regular `DataLoader` initialization. 161 | """ 162 | 163 | def __init__(self, dataset, skip_batches=0, **kwargs): 164 | super().__init__(dataset, **kwargs) 165 | self.skip_batches = skip_batches 166 | 167 | def __iter__(self): 168 | for index, batch in enumerate(super().__iter__()): 169 | if index >= self.skip_batches: 170 | yield batch 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReLoRA -- PEFT Pretraining 2 | > Official code for Stack More Layers Differently: High-Rank Training Through Low-Rank Updates https://arxiv.org/abs/2307.05695 3 | ReLoRA 4 | 5 | ## Setup 6 | 7 | Requires Python 3.10+ (due to param annotations style) and PyTorch 2.0+ (for flash attention). 8 | All requirements are listed in `requirements.txt` and kept up-to-date. 9 | 10 | ```bash 11 | pip install -e . 12 | pip install flash-attn 13 | ``` 14 | 15 | > We do not have flash attention in our requirements, because, for some reason, the flash attention installation script requires torch and some other requirements to already be installed 16 | 17 | ## 1B training script 18 | 19 | The rule of thumb I use for selecting the learning rate for now is 2X the regular training learning rate. 20 | It might require tuning on larger models. 21 | Microbatch size depends on the GPU memory and needs to be tuned to maximize the throughput. 22 | Note that ReLoRA allows using larger microbatch sizes than regular training. 23 | 24 | Number of steps is 143K (Pythia) minus 10K, because we start from the checkpoint at 10K steps. 25 | ReLoRA reset frequency is 5320 so that the number of steps would be divisible by it. 26 | 27 | ```bash 28 | torchrun --nproc-per-node 8 --nnodes 1 torchrun_main.py --training_config training_configs/1B_v1.0.yaml 29 | ``` 30 | 31 | ## Usage 32 | 33 | Pre-process data (might take some time) 34 | 35 | ```bash 36 | python pretokenize.py \ 37 | --save_dir preprocessed_data \ 38 | --tokenizer t5-base \ 39 | --dataset c4 \ 40 | --dataset_config en \ 41 | --text_field text \ 42 | --sequence_length 512 43 | ``` 44 | 45 | The script will log where the pre-processed data is saved. It should be something like `preprocessed_data/<dataset>_<dataset_config>_<tokenizer>_<sequence_length>`. 46 | 
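To sanity-check the result before training, the saved directory can be loaded back with `datasets` (a minimal sketch; the path below is what the c4 example above should produce):

```python
from datasets import load_from_disk

# pretokenize.py saves with `save_to_disk`, so `load_from_disk` restores it
dataset = load_from_disk("preprocessed_data/c4_en_t5-base_512")
print(dataset)  # splits whose `input_ids` are chunked to the requested sequence length
```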
47 | To train a model using ReLoRA, first perform a warmup through regular training. 48 | 49 | ```bash 50 | export DATA_PATH= 51 | 52 | torchrun --nproc-per-node torchrun_main.py \ 53 | --model_config configs/llama_250m.json \ 54 | --dataset_path $DATA_PATH \ 55 | --batch_size 24 \ 56 | --total_batch_size 1152 \ 57 | --lr 5e-4 \ 58 | --max_length 512 \ 59 | --save_every 1000 \ 60 | --eval_every 1000 \ 61 | --num_training_steps 20000 \ 62 | --tags warm_start_250M 63 | ``` 64 | 65 | > **Reproducibility note:** The way we ran the experiments in the paper was by specifying the full `num_training_steps`, including both the warmup and the ReLoRA training, and stopping the run after the desired number of steps was completed. Providing only the number of training steps should work too; the only difference will be the LR schedule during the warmup period. 66 | 67 | When you have a warmed-up network checkpoint, run the script with ReLoRA enabled. Note that we use a larger LR during the ReLoRA stage. 68 | 69 | Train with PEFT: 70 | ```bash 71 | torchrun --nproc-per-node torchrun_main.py \ 72 | --model_config configs/llama_250m.json \ 73 | --batch_size 24 \ 74 | --total_batch_size 1152 \ 75 | --lr 1e-3 \ 76 | --max_length 512 \ 77 | --use_peft \ 78 | --relora 5000 \ 79 | --cycle_length 5000 \ 80 | --restart_warmup_steps 100 \ 81 | --scheduler cosine_restarts \ 82 | --warmup_steps 500 \ 83 | --reset_optimizer_on_relora True \ 84 | --num_training_steps 20000 \ 85 | --save_every 5000 \ 86 | --eval_every 5000 \ 87 | --warmed_up_model checkpoints/llama_250m-2023-06-09-11-29-56/model_5000 \ 88 | --tags relora_250M 89 | ``` 90 | 91 | 92 | 93 | ## Note on batch sizes 94 | 95 | To minimize the pain of multi-GPU setups, we recommend avoiding the `--gradient_accumulation` option. Instead, specify `--total_batch_size` and allow the script to figure out the gradient accumulation steps based on `--batch_size` and the number of GPUs used. 96 | 97 | ## Relora 98 | 99 | ReLoRA merges the existing LoRA parameters into the main network and resets them. 100 | In principle, such an approach can be more flexible than LoRA, but you need to be careful with: 101 | 102 | 1. Optimizer states 103 | 2. The learning rate schedule during and right after the reset 104 | 3. How frequently you reset 105 | 106 | The reset frequency is determined by the `--relora` parameter (in the number of update steps, not global steps). 107 | The optimizer reset options are: 108 | ``` 109 | "--reset_optimizer_on_relora", default=True, type=lambda x: x.lower() == "true" 110 | "--optimizer_random_pruning", default=False, type=float 111 | "--optimizer_magnitude_pruning", default=False, type=float 112 | ``` 113 | 114 | We found that using `--optimizer_magnitude_pruning 0.9` or plain `--reset_optimizer_on_relora` usually performs well. 115 | Note that `--reset_optimizer_on_relora` is `True` by default, so you need to provide `--reset_optimizer_on_relora False --optimizer_magnitude_pruning 0.9` if you want to do magnitude pruning. 116 | 117 | ReLoRA currently only supports a cosine-decay learning rate schedule. 118 | Specifically, `cosine_restarts`, which works in a cyclical mode and repeats the warmup every `--cycle_length` update steps. 119 | 
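Conceptually, each reset merges the current low-rank update into the full-rank weights and re-initializes the LoRA factors. A minimal sketch of that step (illustrative names, not the exact `relora.py` implementation):

```python
import math
import torch

@torch.no_grad()
def merge_and_reinit(weight, lora_A, lora_B, scaling):
    # Merge the low-rank update into the full-rank weight: W <- W + scaling * (B @ A)
    weight.add_(scaling * (lora_B @ lora_A))
    # Re-initialize the factors so the next cycle starts from a zero update
    torch.nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
    torch.nn.init.zeros_(lora_B)
```

After a merge, the optimizer state for the LoRA parameters is stale, which is why the reset is paired with `--reset_optimizer_on_relora` or one of the optimizer-state pruning options above.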
120 | ## Warm starts 121 | 122 | You can start ReLoRA from a partially trained checkpoint. To do that, provide the `--warmed_up_model` option. For example: 123 | 124 | ``` 125 | torchrun torchrun_main.py ... .. --warmed_up_model checkpoints/llama_1b-2023-05-05-20-12-43/model_1000 126 | ``` 127 | 128 | ## Distributed training 129 | 130 | We support single-node distributed training using vanilla PyTorch DDP. 131 | The `main.py` script does not have all the features required for ReLoRA and will be deleted soon. We recommend using `torchrun --nproc-per-node 1` for single-GPU training. 132 | 133 | An example of using torchrun: 134 | ```bash 135 | torchrun --nproc-per-node 8 torchrun_main.py \ 136 | --model_config configs/llama_35m.json \ 137 | --use_peft \ 138 | --lora_r 128 \ 139 | --relora 500 \ 140 | --cycle_length 500 \ 141 | --warmup_steps 250 \ 142 | --reset_optimizer_on_relora False \ 143 | --lr 0.001 \ 144 | --batch_size 60 \ 145 | --total_batch_size 480 \ 146 | --num_training_steps 5000 \ 147 | --save_every 5000 \ 148 | --dtype bfloat16 \ 149 | --tags relora_debug,example 150 | ``` 151 | 152 | Where `--nproc-per-node` is the number of GPUs you are using. 153 | 154 | ## Citation 155 | 156 | ``` 157 | @misc{lialin2023stack, 158 | title={Stack More Layers Differently: High-Rank Training Through Low-Rank Updates}, 159 | author={Vladislav Lialin and Namrata Shivagunde and Sherin Muckatira and Anna Rumshisky}, 160 | year={2023}, 161 | eprint={2307.05695}, 162 | archivePrefix={arXiv}, 163 | primaryClass={cs.CL} 164 | } 165 | ``` 166 | -------------------------------------------------------------------------------- /peft_pretraining/megatron_dataset/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, EleutherAI 2 | # This file is based on code by the authors denoted below and has been modified from its original version. 3 | # 4 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """Batch samplers that work with either random or sequential data samplers.""" 19 | 20 | import torch 21 | from torch.utils import data 22 | 23 | 24 | class RandomSampler(data.sampler.Sampler): 25 | """Based off of the pytorch RandomSampler and DistributedSampler. Essentially 26 | a RandomSampler, but this class lets the user set an epoch like 27 | DistributedSampler. Samples elements randomly: if without replacement, 28 | samples are drawn from a shuffled dataset; if with replacement, the user 29 | can specify ``num_samples`` to draw. 30 | Arguments: 31 | data_source (Dataset): dataset to sample from 32 | num_samples (int): number of samples to draw, default=len(dataset) 33 | replacement (bool): samples are drawn with replacement if ``True``, 34 | default=False 35 | """ 36 | 37 | def __init__(self, data_source, replacement=False, num_samples=None): 38 | self.data_source = data_source 39 | self.replacement = replacement 40 | self._num_samples = num_samples 41 | self.epoch = -1 42 | 43 | if self._num_samples is not None and replacement is False: 44 | raise ValueError( 45 | "With replacement=False, num_samples should not " 46 | "be specified, since a random permute will be " 47 | "performed."
48 | ) 49 | 50 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 51 | raise ValueError( 52 | "num_samples should be a positive integer " 53 | "value, but got num_samples={}".format(self.num_samples) 54 | ) 55 | if not isinstance(self.replacement, bool): 56 | raise ValueError( 57 | "replacement should be a boolean value, but got " 58 | "replacement={}".format(self.replacement) 59 | ) 60 | 61 | @property 62 | def num_samples(self): 63 | # dataset size might change at runtime 64 | if self._num_samples is None: 65 | return len(self.data_source) 66 | return self._num_samples 67 | 68 | def __iter__(self): 69 | n = len(self.data_source) 70 | g = torch.Generator() 71 | if self.epoch >= 0: 72 | g.manual_seed(self.epoch) 73 | if self.replacement: 74 | return iter( 75 | torch.randint( 76 | high=n, size=(self.num_samples,), dtype=torch.int64, generator=g 77 | ).tolist() 78 | ) 79 | return iter(torch.randperm(n, generator=g).tolist()) 80 | 81 | def __len__(self): 82 | return self.num_samples 83 | 84 | def set_epoch(self, epoch): 85 | self.epoch = epoch 86 | 87 | 88 | class DistributedBatchSampler(data.sampler.BatchSampler): 89 | """Similar to normal implementation of distributed sampler, except 90 | implementation is at the batch sampler level, instead of just the 91 | sampler level. This allows wrapping of arbitrary data samplers 92 | (sequential, random, WeightedRandomSampler, etc.) with this batch 93 | sampler. 94 | 95 | The `interleave` argument specifies how to distribute a batch. A value 96 | of True combined with the above random sampler is equivalent to pytorch's 97 | torch.utils.data.distributed.DistributedSampler. 98 | 99 | For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2 100 | specifying True will result in the following samples for each gpu: 101 | GPU0: [0,2,4,6] GPU1: [1,3,5,7] 102 | specifying False will result in the following samples: 103 | GPU0: [0,1,2,3] GPU1: [4,5,6,7]""" 104 | 105 | def __init__( 106 | self, 107 | sampler, 108 | batch_size, 109 | drop_last, 110 | rank=-1, 111 | world_size=2, 112 | wrap_last=False, 113 | interleave=False, 114 | ): 115 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) 116 | if rank == -1: 117 | assert False, "should not be here" 118 | rank = torch.distributed.get_rank() 119 | self.rank = rank 120 | self.world_size = world_size 121 | self.sampler.wrap_around = 0 122 | self.wrap_around = 0 123 | self.wrap_last = wrap_last 124 | self.start_iter = 0 125 | self.interleave = interleave 126 | 127 | def __iter__(self): 128 | batch = [] 129 | i = 0 130 | for idx in self.data_iterator(self.sampler, wrap_around=False): 131 | batch.append(idx) 132 | if len(batch) == self.batch_size: 133 | tbatch = self._batch(batch) 134 | if i >= self.start_iter: 135 | yield tbatch 136 | self.start_iter = 0 137 | i += 1 138 | batch = [] 139 | batch_len = len(batch) 140 | if batch_len > 0 and not self.drop_last: 141 | if self.wrap_last: 142 | self.sampler.wrap_around -= self.batch_size 143 | self.wrap_around += len(batch) 144 | self.wrap_around %= self.batch_size 145 | yield self._batch(batch) 146 | if self.wrap_last: 147 | self.sampler.wrap_around += self.batch_size 148 | 149 | def data_iterator(self, _iter, wrap_around=False): 150 | """iterates through data and handles wrap around""" 151 | for i, idx in enumerate(_iter): 152 | if i < self.wrap_around % self.batch_size: 153 | continue 154 | if wrap_around: 155 | self.wrap_around += 1 156 | self.wrap_around %= self.batch_size 157 | yield idx 158 | 159 | def 
_batch(self, batch): 160 | """Extract the slice of the batch that belongs to this worker.""" 161 | if self.interleave: 162 | return batch[self.rank : self.batch_size : self.world_size] 163 | start = self.rank * self.batch_size // self.world_size 164 | end = (self.rank + 1) * self.batch_size // self.world_size 165 | return batch[start:end] 166 | -------------------------------------------------------------------------------- /notebooks/01_peft_pretraining.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "\n", 11 | "import transformers\n", 12 | "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n", 13 | "from peft import get_peft_model, LoraConfig, TaskType\n", 14 | "\n", 15 | "import datasets\n", 16 | "import wandb\n", 17 | "\n", 18 | "from tqdm import tqdm" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "BATCH_SIZE = 24\n", 28 | "MAX_LENGTH = 128\n", 29 | "\n", 30 | "data = datasets.load_dataset(\"c4\", \"en\", split=\"train\", streaming=True)\n", 31 | "data = data.shuffle(seed=42)\n", 32 | "\n", 33 | "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n", 34 | "\n", 35 | "def preprocess_batched(batch):\n", 36 | "    batch = tokenizer(\n", 37 | "        batch[\"text\"],\n", 38 | "        max_length=MAX_LENGTH,\n", 39 | "        truncation=True,\n", 40 | "        padding=\"max_length\",\n", 41 | "        return_tensors=\"pt\",\n", 42 | "    )\n", 43 | "    return batch\n", 44 | "\n", 45 | "data_mapped = data.map(preprocess_batched, batched=True, batch_size=1000, remove_columns=[\"text\", \"timestamp\", \"url\"])\n", 46 | "\n", 47 | "def collate_fn(batch_list):\n", 48 | "    batch = {\n", 49 | "        \"input_ids\": torch.stack([example[\"input_ids\"] for example in batch_list]),\n", 50 | "        \"attention_mask\": torch.stack([example[\"attention_mask\"] for example in batch_list]),\n", 51 | "    }\n", 52 | "    return batch\n", 53 | "\n", 54 | "def batch_fn(dataset, batch_size):\n", 55 | "    batch = []\n", 56 | "    for example in dataset:\n", 57 | "        batch.append(example)\n", 58 | "        if len(batch) == batch_size:\n", 59 | "            batch = collate_fn(batch)\n", 60 | "            yield batch\n", 61 | "            batch = []\n", 62 | "    if len(batch) > 0:\n", 63 | "        yield collate_fn(batch)  # collate the final partial batch as well\n", 64 | "\n", 65 | "data_mapped.batch = lambda batch_size: batch_fn(data_mapped, batch_size)" 66 | ] 67 | },
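{ "cell_type": "markdown", "metadata": {}, "source": [ "`get_peft_model` freezes the base model, so only the LoRA adapters are trainable by default. The next cell re-enables LayerNorm (optionally), `lm_head`, and the embedding parameters by name so they are trained alongside the adapters.\n" ] },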
\"transformer.wpe\" in name:\n", 102 | " param.requires_grad = True\n", 103 | "\n", 104 | " model.print_trainable_parameters()\n", 105 | "\n", 106 | "model = model.to(device)\n", 107 | "\n", 108 | "n_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 109 | "n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 110 | "p_trainable_params = n_trainable_params / n_total_params\n", 111 | "\n", 112 | "trainable_params = (p for p in model.parameters() if p.requires_grad)\n", 113 | "trainable_params_names = [name for name, p in model.named_parameters() if p.requires_grad]\n", 114 | "\n", 115 | "optimizer = torch.optim.Adam(trainable_params, lr=1e-4)\n", 116 | "scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1_000, num_training_steps=NUM_TRAINING_STEPS)\n", 117 | "\n", 118 | "_config = {\n", 119 | " \"using_peft\": USE_PEFT,\n", 120 | " \"layer_norm_trainable\": TRAIN_LN,\n", 121 | " \"peft_config\": peft_config.to_dict(),\n", 122 | " \"total_params\": n_total_params,\n", 123 | " \"trainable_params\": n_trainable_params,\n", 124 | " \"percent_trainable_params\": p_trainable_params,\n", 125 | " \"name_trainable_params\": trainable_params_names,\n", 126 | " \"dataset\": \"c4\",\n", 127 | " \"batch_size\": BATCH_SIZE,\n", 128 | " \"max_length\": MAX_LENGTH,\n", 129 | " \"model\": model_config.to_dict(),\n", 130 | " \"scheduler\": \"linear\",\n", 131 | " \"device\": str(device),\n", 132 | "}\n", 133 | "\n", 134 | "wandb.init(project=\"peft_pretraining\", config=_config)\n", 135 | "pbar = tqdm(total=NUM_TRAINING_STEPS)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "model.base_model.transformer.wte.weight.requires_grad" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "for epoch in range(1):\n", 154 | " data_mapped.set_epoch(epoch)\n", 155 | " for batch in data_mapped.batch(batch_size=BATCH_SIZE):\n", 156 | " pbar.update(1)\n", 157 | " optimizer.zero_grad()\n", 158 | "\n", 159 | " batch = {k: v.to(device) for k, v in batch.items()}\n", 160 | " labels = batch[\"input_ids\"].clone()\n", 161 | " labels[labels == 0] = -100\n", 162 | "\n", 163 | " loss = model(**batch, labels=labels).loss\n", 164 | " loss.backward()\n", 165 | " optimizer.step()\n", 166 | " scheduler.step()\n", 167 | "\n", 168 | " lr = scheduler.get_last_lr()[0]\n", 169 | " wandb.log({\n", 170 | " \"loss\": loss.item(),\n", 171 | " \"lr\": lr,\n", 172 | " })" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "base", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.7.4" 200 | }, 201 | "orig_nbformat": 4 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 2 205 | } 206 | -------------------------------------------------------------------------------- /notebooks/06_svd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 
3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "\n", 11 | "\n", 12 | "def svd_internal_dimensionality_reduction(tensor, num_components):\n", 13 | "    \"\"\"\n", 14 | "    Performs SVD dimensionality reduction, but returns the full tensor instead of just the reduced components.\n", 15 | "    \"\"\"\n", 16 | "    u, s, v = torch.svd(tensor)\n", 17 | "    return torch.matmul(u[:, :num_components] * s[:num_components], v[:, :num_components].T)\n" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "A = torch.randn(10, 3)\n", 27 | "B = torch.randn(3, 10)\n", 28 | "\n", 29 | "C_low_rank = torch.matmul(A, B)\n", 30 | "C_full_rank = torch.randn(10, 10)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "C_full_rank_svd = svd_internal_dimensionality_reduction(C_full_rank, 3)\n", 40 | "C_full_rank_svd.shape" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "torch.svd(C_full_rank_svd).S" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "torch.svd(C_low_rank).S" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "import torch\n", 68 | "\n", 69 | "def random_projection_dim_reduction(tensor, target_dim):\n", 70 | "    \"\"\"\n", 71 | "    Performs random projection dimensionality reduction according to the Johnson-Lindenstrauss lemma.\n", 72 | "    Only reduces the inner dimensionality, does not affect the shape of the tensor\n", 73 | "    \"\"\"\n", 74 | "    original_dtype = tensor.dtype\n", 75 | "    original_shape = tensor.shape\n", 76 | "    tensor = tensor.to(dtype=torch.float32)\n", 77 | "\n", 78 | "    # generate a random matrix with entries drawn from a normal distribution\n", 79 | "    random_matrix = torch.randn(tensor.shape[-1], target_dim, dtype=torch.float32, device=tensor.device)\n", 80 | "    random_matrix /= torch.norm(random_matrix, dim=0, keepdim=True)\n", 81 | "\n", 82 | "    # project down to target_dim and back up: the shape is preserved, but the rank drops to target_dim\n", 83 | "    new_matrix = torch.matmul(torch.matmul(tensor, random_matrix), random_matrix.T).to(dtype=original_dtype)\n", 84 | "    assert new_matrix.shape == original_shape\n", 85 | "    return new_matrix\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "A = torch.randn(100, 10)\n", 95 | "\n", 96 | "B = random_projection_dim_reduction(A, 2)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "B.shape" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "import torch\n", 115 | "import matplotlib.pyplot as plt\n", 116 | "\n", 117 | "@torch.no_grad()\n", 118 | "def random_pruning(tensor, prune_ratio):\n", 119 | "    \"\"\"\n", 120 | "    Performs random pruning: zeroes out roughly a prune_ratio fraction of the entries.\n", 121 | "    Does not affect the shape of the tensor.\n", 122 | "    \"\"\"\n", 123 | "    random_pruning_mask = torch.rand_like(tensor) > prune_ratio\n", 124 | "    tensor = tensor * random_pruning_mask\n", 125 | "    return tensor\n", 126 | "\n", 127 | "# Create a 2D tensor with random values\n", 128 | "tensor = torch.rand((10, 10))\n", 129 | "\n", 130 | "# Define a list of pruning ratios\n", 131 | "prune_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]\n", 132 | "\n", 133 | "# Initialize a figure\n", 134 | "fig, axs = plt.subplots(1, len(prune_ratios)+1, figsize=(20, 5))\n", 135 | "\n", 136 | "# Plot the original tensor\n", 137 | "axs[0].imshow(tensor.numpy(), cmap='viridis')\n", 138 | "axs[0].set_title('Original Tensor')\n", 139 | "\n", 140 | "# Apply pruning for each ratio and plot the resulting tensors\n", 141 | "for i, prune_ratio in enumerate(prune_ratios):\n", 142 | "    pruned_tensor = random_pruning(tensor.clone(), prune_ratio)\n", 143 | "    axs[i+1].imshow(pruned_tensor.numpy(), cmap='viridis')\n", 144 | "    axs[i+1].set_title(f'Pruned Tensor (ratio = {prune_ratio})')\n", 145 | "\n", 146 | "# Display the plot\n", 147 | "plt.show()\n" 148 | ] 149 | },
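{ "cell_type": "markdown", "metadata": {}, "source": [ "Magnitude pruning keeps the largest entries by absolute value: with threshold $\\tau = \\mathrm{quantile}(|W|, p)$ the mask is $|W| > \\tau$, so roughly a fraction $p$ of the entries is zeroed. This is the same idea as the `--optimizer_magnitude_pruning` option for optimizer states.\n" ] },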
150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "import torch\n", 157 | "import matplotlib.pyplot as plt\n", 158 | "\n", 159 | "@torch.no_grad()\n", 160 | "def magnitude_pruning(tensor, prune_ratio):\n", 161 | "    \"\"\"\n", 162 | "    Performs magnitude pruning: zeroes out the prune_ratio fraction of entries with the smallest magnitudes.\n", 163 | "    Does not affect the shape of the tensor.\n", 164 | "    \"\"\"\n", 165 | "    tensor_magnitude = torch.abs(tensor)\n", 166 | "    threshold = torch.quantile(tensor_magnitude.flatten(), prune_ratio)\n", 167 | "\n", 168 | "    mask = tensor_magnitude > threshold\n", 169 | "    tensor = tensor * mask.to(dtype=tensor.dtype)\n", 170 | "    return tensor\n", 171 | "\n", 172 | "# Create a 2D tensor with random values\n", 173 | "tensor = torch.rand((10, 10))\n", 174 | "\n", 175 | "# Define a list of pruning ratios\n", 176 | "prune_ratios = [0.1, 0.3, 0.5, 0.7, 0.9]\n", 177 | "\n", 178 | "# Initialize a figure\n", 179 | "fig, axs = plt.subplots(1, len(prune_ratios)+1, figsize=(20, 5))\n", 180 | "\n", 181 | "# Plot the original tensor\n", 182 | "axs[0].imshow(tensor.numpy(), cmap='viridis')\n", 183 | "axs[0].set_title('Original Tensor')\n", 184 | "\n", 185 | "# Apply pruning for each ratio and plot the resulting tensors\n", 186 | "for i, prune_ratio in enumerate(prune_ratios):\n", 187 | "    pruned_tensor = magnitude_pruning(tensor.clone(), prune_ratio)\n", 188 | "    axs[i+1].imshow(pruned_tensor.numpy(), cmap='viridis')\n", 189 | "    axs[i+1].set_title(f'Pruned Tensor (ratio = {prune_ratio})')\n", 190 | "\n", 191 | "# Display the plot\n", 192 | "plt.show()\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [] 201 | } 202 | ], 203 | "metadata": { 204 | "kernelspec": { 205 | "display_name": "base", 206 | "language": "python", 207 | "name": "python3" 208 | }, 209 | "language_info": { 210 | "codemirror_mode": { 211 | "name": "ipython", 212 | "version": 3 213 | }, 214 | "file_extension": ".py", 215 | "mimetype": "text/x-python", 216 | "name": "python", 217 | "nbconvert_exporter": "python", 218 | "pygments_lexer": "ipython3", 219 | "version": "3.10.9" 220 | }, 221 | "orig_nbformat": 4 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 2 225 | } 226 | -------------------------------------------------------------------------------- /notebooks/10_chunking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
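"cell_type": "markdown", "metadata": {}, "source": [ "Prototype of `tokenize_and_chunk` (essentially the same function as in `peft_pretraining/dataloader.py`): texts are tokenized with an EOS token appended, concatenated, and split into `sequence_length`-token chunks, dropping the remainder (e.g. 10 tokens with block size 4 give two chunks of 4 and drop 2 tokens).\n" ] }, {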
"cell_type": "code", 5 | "execution_count": 18, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import itertools\n", 10 | "import multiprocessing\n", 11 | "from itertools import chain\n", 12 | "\n", 13 | "import torch\n", 14 | "from torch.utils.data import IterableDataset, get_worker_info\n", 15 | "from transformers import AutoTokenizer, default_data_collator\n", 16 | "from datasets import Dataset\n", 17 | "\n", 18 | "\n", 19 | "def tokenize_and_chunk(\n", 20 | " tokenizer: AutoTokenizer,\n", 21 | " dataset: Dataset,\n", 22 | " text_field: str,\n", 23 | " sequence_length: int,\n", 24 | " num_cpu: int = multiprocessing.cpu_count(),\n", 25 | "):\n", 26 | " \"\"\"\n", 27 | " Build data loaders for training.\n", 28 | "\n", 29 | " This function performs the following steps:\n", 30 | " 1. Load the tokenizer from the pretrained \"EleutherAI/gpt-neox-20b\" model.\n", 31 | " 2. Load the \"openwebtext\" dataset.\n", 32 | " 3. Tokenize the dataset, adding the end-of-sentence token to each text.\n", 33 | " 4. Process the tokenized dataset into chunks of a specified block size.\n", 34 | "\n", 35 | " Returns:\n", 36 | " Dataset: The processed dataset ready for training.\n", 37 | " \"\"\"\n", 38 | " extra_map_kwargs = {\"num_proc\": num_cpu}\n", 39 | " if isinstance(dataset, IterableDataset):\n", 40 | " extra_map_kwargs = {}\n", 41 | "\n", 42 | " current_columns = dataset.column_names\n", 43 | " tokenized_dataset = dataset.map(\n", 44 | " lambda example: tokenizer([t + tokenizer.eos_token for t in example[text_field]]),\n", 45 | " batched=True,\n", 46 | " remove_columns=current_columns,\n", 47 | " **extra_map_kwargs,\n", 48 | " )\n", 49 | "\n", 50 | " block_size = sequence_length\n", 51 | "\n", 52 | " # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n", 53 | " def group_texts(examples):\n", 54 | " # Concatenate all texts.\n", 55 | " concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n", 56 | " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", 57 | " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n", 58 | " # customize this part to your needs.\n", 59 | " if total_length >= block_size:\n", 60 | " total_length = (total_length // block_size) * block_size\n", 61 | " # Split by chunks of max_len.\n", 62 | " result = {\n", 63 | " k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", 64 | " for k, t in concatenated_examples.items()\n", 65 | " }\n", 66 | " return result\n", 67 | "\n", 68 | " train_dataset = tokenized_dataset.map(\n", 69 | " group_texts,\n", 70 | " batched=True,\n", 71 | " **extra_map_kwargs,\n", 72 | " )\n", 73 | "\n", 74 | " return train_dataset\n" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 19, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "import torch.utils.data\n", 84 | "import datasets\n", 85 | "max_length = 2048\n", 86 | "\n", 87 | "data = datasets.load_dataset(\"c4\", \"en\", split=\"train\", streaming=True)\n", 88 | "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\", model_max_length=max_length)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 20, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "dataset = tokenize_and_chunk(tokenizer, data, text_field=\"text\", sequence_length=max_length, num_cpu=None)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 24, 103 | 
"metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=2, collate_fn=default_data_collator)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 25, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", 119 | "To disable this warning, you can either:\n", 120 | "\t- Avoid using `tokenizers` before the fork if possible\n", 121 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" 122 | ] 123 | }, 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", 129 | "To disable this warning, you can either:\n", 130 | "\t- Avoid using `tokenizers` before the fork if possible\n", 131 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "batch = next(iter(dataloader))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 26, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stderr", 146 | "output_type": "stream", 147 | "text": [ 148 | "/tmp/ipykernel_3513881/2323961505.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 149 | " torch.tensor(batch[\"input_ids\"])\n" 150 | ] 151 | }, 152 | { 153 | "data": { 154 | "text/plain": [ 155 | "tensor([[12847, 277, 15068, ..., 8, 414, 13],\n", 156 | " [ 336, 471, 5, ..., 28, 46, 1287],\n", 157 | " [ 6, 9445, 8424, ..., 45, 8, 814],\n", 158 | " ...,\n", 159 | " [ 21, 8, 471, ..., 979, 16, 112],\n", 160 | " [23659, 774, 5, ..., 19, 92, 46],\n", 161 | " [ 256, 11577, 412, ..., 112, 372, 28]])" 162 | ] 163 | }, 164 | "execution_count": 26, 165 | "metadata": {}, 166 | "output_type": "execute_result" 167 | } 168 | ], 169 | "source": [ 170 | "import torch\n", 171 | "\n", 172 | "torch.tensor(batch[\"input_ids\"])" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "peft_pretraining_shala", 186 | "language": "python", 187 | "name": "python3" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.10.11" 200 | }, 201 | "orig_nbformat": 4 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 2 205 | } 206 | -------------------------------------------------------------------------------- /notebooks/11_test_pythia.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | 
"/fsx/vlialin/relora/.venv/lib/python3.10/site-packages/bitsandbytes/cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n", 13 | " warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "/fsx/vlialin/relora/.venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import torch\n", 26 | "\n", 27 | "from transformers import AutoTokenizer, GPTNeoXForCausalLM as HF_GPTNeoXForCausalLM\n", 28 | "# from optimum.bettertransformer import BetterTransformer\n", 29 | "\n", 30 | "%load_ext autoreload\n", 31 | "%autoreload 2\n", 32 | "from peft_pretraining.modeling_pythia import GPTNeoXForCausalLM" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "model_name = \"EleutherAI/pythia-1b\"\n", 42 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 43 | "input_ids = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\").input_ids" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "flashattention_model = GPTNeoXForCausalLM.from_pretrained(model_name)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "flashattention_out = flashattention_model(input_ids)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 5, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "orig_model = HF_GPTNeoXForCausalLM.from_pretrained(model_name)\n", 71 | "# orig_model = BetterTransformer.transform(orig_model)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 6, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "orig_out = orig_model(input_ids)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 7, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/plain": [ 91 | "tensor([[[ 3.6091, -9.3721, 8.7746, ..., -9.2669, -9.2728, -9.3494],\n", 92 | " [ 1.8715, -10.9088, 5.3397, ..., -10.6625, -10.8846, -10.8984],\n", 93 | " [ 1.1809, -10.6320, 6.5040, ..., -10.4274, -10.7300, -10.4196],\n", 94 | " [ 5.8298, -8.4654, 12.3902, ..., -8.4833, -8.6062, -8.5475],\n", 95 | " [ 3.7810, -9.6519, 6.0415, ..., -9.6365, -9.8250, -9.5622],\n", 96 | " [ 8.1836, -8.4023, 14.3205, ..., -8.5621, -8.4343, -8.4242]]],\n", 97 | " grad_fn=)" 98 | ] 99 | }, 100 | "execution_count": 7, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "orig_out.logits" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 8, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "tensor([[[ 3.6091, -9.3721, 8.7746, ..., -9.2669, -9.2728, -9.3494],\n", 118 | " [ 1.8715, -10.9088, 5.3397, ..., -10.6625, -10.8846, -10.8984],\n", 119 | " [ 1.1809, -10.6320, 6.5040, ..., -10.4274, -10.7300, -10.4196],\n", 120 | " [ 5.8298, -8.4654, 12.3902, ..., -8.4833, -8.6062, -8.5475],\n", 121 | " [ 3.7810, -9.6519, 6.0415, ..., -9.6365, -9.8250, -9.5622],\n", 122 | " [ 8.1836, -8.4023, 14.3205, ..., -8.5621, -8.4343, -8.4242]]],\n", 123 | " grad_fn=)" 124 | ] 
125 | }, 126 | "execution_count": 8, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "flashattention_out.logits" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 10, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "text/plain": [ 143 | "True" 144 | ] 145 | }, 146 | "execution_count": 10, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "torch.allclose(orig_out.logits, flashattention_out.logits, atol=1e-5)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 22, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stderr", 169 | "output_type": "stream", 170 | "text": [ 171 | "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", 172 | "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n", 173 | "/mnt/shared_home/vlialin/miniconda3/envs/peft_pretraining_shala/lib/python3.10/site-packages/transformers/generation/utils.py:1369: UserWarning: Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n", 174 | " warnings.warn(\n" 175 | ] 176 | }, 177 | { 178 | "data": { 179 | "text/plain": [ 180 | "'Hello, my dog is cute, but he\\'s not a dog. He\\'s a cat.\"\\n'" 181 | ] 182 | }, 183 | "execution_count": 22, 184 | "metadata": {}, 185 | "output_type": "execute_result" 186 | } 187 | ], 188 | "source": [ 189 | "out_gen = orig_model.generate(input_ids)\n", 190 | "tokenizer.decode(out_gen[0].tolist())" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 18, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stderr", 200 | "output_type": "stream", 201 | "text": [ 202 | "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", 203 | "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" 204 | ] 205 | }, 206 | { 207 | "data": { 208 | "text/plain": [ 209 | "'Hello, my dog is cute, I am. I am. I am. I am. 
I'" 210 | ] 211 | }, 212 | "execution_count": 18, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "out_gen = flashattention_model.generate(input_ids)\n", 219 | "tokenizer.decode(out_gen[0].tolist())" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "peft_pretraining_shala", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.10.12" 247 | }, 248 | "orig_nbformat": 4 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 2 252 | } 253 | -------------------------------------------------------------------------------- /notebooks/14_check_pretokenization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/mnt/shared_home/vlialin/miniconda3/envs/peft_pretraining_shala/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "from datasets import load_dataset, load_from_disk" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stderr", 28 | "output_type": "stream", 29 | "text": [ 30 | "Found cached dataset c4 (/mnt/shared_home/hf_cache/datasets_cache/c4/realnewslike/0.0.0/df532b158939272d032cc63ef19cd5b83e9b4d00c922b833e4cb18b2e9869b01)\n", 31 | "100%|██████████| 2/2 [00:26<00:00, 13.03s/it]" 32 | ] 33 | }, 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "2\n" 39 | ] 40 | }, 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 44 | "text": [ 45 | "\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "data = load_dataset(\"c4\", \"realnewslike\")\n", 51 | "print(\"2\")\n", 52 | "data_preprocessed = load_from_disk(\"../preprocessed_data/c4_realnewslike_t5-base_512\")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 26, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stderr", 62 | "output_type": "stream", 63 | "text": [ 64 | "/mnt/shared_home/vlialin/miniconda3/envs/peft_pretraining_shala/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", 65 | "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", 66 | "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", 67 | "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", 68 | "- To avoid this warning, please instantiate 
this tokenizer with `model_max_length` set to your preferred value.\n", 69 | " warnings.warn(\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "from transformers import AutoTokenizer\n", 75 | "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\", use_fast=True)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "DatasetDict({\n", 87 | " train: Dataset({\n", 88 | " features: ['text', 'timestamp', 'url'],\n", 89 | " num_rows: 13799838\n", 90 | " })\n", 91 | " validation: Dataset({\n", 92 | " features: ['text', 'timestamp', 'url'],\n", 93 | " num_rows: 13863\n", 94 | " })\n", 95 | "})" 96 | ] 97 | }, 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "data" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 6, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/plain": [ 115 | "DatasetDict({\n", 116 | " train: Dataset({\n", 117 | " features: ['timestamp', 'url', 'input_ids'],\n", 118 | " num_rows: 538176\n", 119 | " })\n", 120 | " validation: Dataset({\n", 121 | " features: ['timestamp', 'url', 'input_ids'],\n", 122 | " num_rows: 528\n", 123 | " })\n", 124 | "})" 125 | ] 126 | }, 127 | "execution_count": 6, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "data_preprocessed" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 29, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "\n", 146 | "Netflix Inc said it added 8.\n", 147 | "84 million paid global streaming subscribers in the fourth quarter, while analysts had expected 9.\n", 148 | "18 million net global streaming additions.\n", 149 | " It was not immediately clear if analysts were excluding unpaid additions.\n", 150 | " Shares of the company were down 3 percent in after-hours trading.\n", 151 | "\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "def add_eol(text):\n", 157 | " return \".\\n\".join(text.split('.'))\n", 158 | "\n", 159 | "_id = -1\n", 160 | "print()\n", 161 | "print(add_eol(data[\"train\"][_id][\"text\"]))" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 28, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "DIII News All-Freshman Team and didn't miss a game after breaking her nose in the middle of the schedule.\n", 174 | " Her nose lost the battle.\n", 175 | " Temple won the war.\n", 176 | " \"After the knee injuries, yeah, that wasn't really anything,\" Temple said.\n", 177 | " \"I just wore a mask for two weeks and sat out a couple practices and I was ready to go.\n", 178 | "\" Courtney L.\n", 179 | " Blankenship, 25, Bucyrus, theft, fined $323, sentenced to 30 days in jail with all suspended.\n", 180 | " Chandler L.\n", 181 | " Lust, 18, Bucyrus, abuse of an intoxicant, fined $308, sentenced to 90 days in jail with all suspended; drug paraphernalia, sentenced to 30 days in jail with all suspended.\n", 182 | " Jessica K.\n", 183 | " Wade, 33, Bloomville, criminal damages, fined $325, sentenced to 30 days in jail with all suspended.\n", 184 | " Leslee D.\n", 185 | " Baxter, 31, Bucyrus, physical control, fined $1,000, sentenced to 90 days in jail with all suspended.\n", 186 | " Roger L.\n", 187 | " Boudinot, 65, Bucyrus, operating 
a vehicle under the influence, fined $625, sentenced to 30 days in jail with 27 suspended, driver’s license suspended for six months.\n", 188 | " Casey Taylor, 30, Galion, dog running at large, fined $180.\n", 189 | " Travis D.\n", 190 | " Lozier, 22, Galion, possession of a controlled substance, fined $155.\n", 191 | " Phillip M.\n", 192 | " Tesso Jr.\n", 193 | ", 24, Crestline, disorderly conduct, fined $155.\n", 194 | " Zachary A.\n", 195 | " Fout, 23, Galion, possession of marijuana, fined $155.\n", 196 | " Thelma J.\n", 197 | " Snyder, 78, Galion, violation of grade crossing, fined $150.\n", 198 | " Daniel E.\n", 199 | " Moore, 68, Bucyrus, violation of grade crossing, fined $150.\n", 200 | " Damon Schramek, 45, Galion, violation of grade crossing, fined $150.\n", 201 | " Natalie E.\n", 202 | " Davis, 37, Galion, driving under suspension, fined $280, sentenced to 30 days in jail with all suspended.\n", 203 | " David M.\n", 204 | " Corbett, 25, Crestline, driving under suspension, fined $230.\n", 205 | " Michael W.\n", 206 | " Conley, 39, Crestline, driving under suspension, fined $225.\n", 207 | " Daniel R.\n", 208 | " Smith, 37, Galion, driving under suspension, fined $225.\n", 209 | " Anisha\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "from tqdm.auto import tqdm\n", 215 | "\n", 216 | "# for _id in tqdm(range(538176)):\n", 217 | "_id = -1\n", 218 | "input_ids = data_preprocessed[\"train\"][_id][\"input_ids\"]\n", 219 | "# print(type(input_ids), len(input_ids))\n", 220 | "decoded = add_eol(tokenizer.decode(input_ids))\n", 221 | "print(decoded)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [] 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "peft_pretraining_shala", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.10.11" 249 | }, 250 | "orig_nbformat": 4 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 2 254 | } 255 | -------------------------------------------------------------------------------- /notebooks/09_bar_plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "import numpy as np\n", 11 | "\n", 12 | "# Data\n", 13 | "relora_delta = [5693, 5723, 6050, 4942]\n", 14 | "full_models_delta = [1173, 1203, 2641, 0]\n", 15 | "projections = ['Q Projections', 'K Projections', 'V Projections', 'Down Projections']\n", 16 | "\n", 17 | "# Custom color palette\n", 18 | "colors = ['#E69F00', '#56B4E9', '#F0E442', '#009E73']\n", 19 | "\n", 20 | "# Create the bar plot\n", 21 | "fig, ax = plt.subplots(figsize=(6, 4))\n", 22 | "bar_width = 0.35\n", 23 | "index = np.arange(len(projections))\n", 24 | "\n", 25 | "relora = ax.bar(index, relora_delta, bar_width, label='ReLoRA', color=colors[0])\n", 26 | "full_models = ax.bar(index + bar_width, full_models_delta, bar_width, label='Full Models', color=colors[1])\n", 27 | "\n", 28 | "# Add labels and titles\n", 29 | "ax.set_xlabel('Projections')\n", 30 | "ax.set_ylabel('Delta')\n", 31 | 
"ax.set_title('Delta Comparison between ReLoRA and Full Models')\n", 32 | "ax.set_xticks(index + bar_width / 2)\n", 33 | "ax.set_xticklabels(projections)\n", 34 | "ax.legend()\n", 35 | "\n", 36 | "# Adjust spacing\n", 37 | "plt.tight_layout()\n", 38 | "\n", 39 | "# Save or display the plot\n", 40 | "plt.savefig('bar_plot.png', dpi=300)\n", 41 | "plt.show()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import matplotlib.pyplot as plt\n", 51 | "\n", 52 | "# Data for the bar plot\n", 53 | "labels = ['ReLoRA', 'Full-Rank', 'LoRA']\n", 54 | "q_projections = [5693, 1173, 7680]\n", 55 | "k_projections = [5723, 1203, 7680]\n", 56 | "v_projections = [6050, 2641, 7681]\n", 57 | "down_projections = [4942, 0, 7680]\n", 58 | "\n", 59 | "# Custom color palette\n", 60 | "colors = ['#FF5C5C', '#FFB157', '#2EC4B6']\n", 61 | "\n", 62 | "# Plotting the bar plot\n", 63 | "fig, ax = plt.subplots(figsize=(8, 4))\n", 64 | "width = 0.2\n", 65 | "\n", 66 | "ax.bar(labels, q_projections, width, label='Q Projections', color=colors[0])\n", 67 | "ax.bar(labels, k_projections, width, label='K Projections', color=colors[1], bottom=q_projections)\n", 68 | "ax.bar(labels, v_projections, width, label='V Projections', color=colors[2], bottom=[q + k for q, k in zip(q_projections, k_projections)])\n", 69 | "ax.bar(labels, down_projections, width, label='Down Projections', color=colors[0], alpha=0.5, hatch='/')\n", 70 | "\n", 71 | "# Adding labels and title\n", 72 | "ax.set_ylabel('Number of Singular Values < 0.1')\n", 73 | "ax.set_title('Singular Values Comparison')\n", 74 | "ax.legend()\n", 75 | "\n", 76 | "# Adjusting layout and saving the plot\n", 77 | "plt.tight_layout()\n", 78 | "plt.savefig('singular_values_plot.png')\n", 79 | "plt.show()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "import matplotlib.pyplot as plt\n", 89 | "import numpy as np\n", 90 | "\n", 91 | "# Data\n", 92 | "methods = ['ReLoRA', 'Full-rank Training', 'LoRA']\n", 93 | "projections = ['$W_Q$', '$W_K$', '$W_V$', '$W_{up}$', '$W_{down}$']\n", 94 | "counts = [5693, 1173, 7680, 5723, 1203, 7680, 6050, 2641, 7681, 4884, 0, 7680, 4942, 0, 7680]\n", 95 | "\n", 96 | "# Custom color palette\n", 97 | "# colors = ['#E64B35', '#4DBBD5', '#00A087', '#3C5488', '#F39B7F']\n", 98 | "\n", 99 | "# Plotting\n", 100 | "fig, ax = plt.subplots(figsize=(5, 3))\n", 101 | "\n", 102 | "width = 0.2\n", 103 | "x = np.arange(len(projections))\n", 104 | "\n", 105 | "for i, method in enumerate(methods):\n", 106 | " bars = ax.bar(x + i * width, counts[i::3], width, label=method, alpha=0.7)\n", 107 | "\n", 108 | "# Customize the plot\n", 109 | "ax.set_xticks(x + width)\n", 110 | "ax.set_xticklabels(projections)\n", 111 | "ax.set_ylabel('#Singular Values < 0.1')\n", 112 | "# ax.set_xlabel('Projections')\n", 113 | "ax.legend()\n", 114 | "\n", 115 | "# Adjust layout\n", 116 | "plt.tight_layout()\n", 117 | "\n", 118 | "# Save or display the plot\n", 119 | "plt.savefig('bar_plot.png', dpi=300)\n", 120 | "plt.show()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "import matplotlib.pyplot as plt\n", 130 | "import numpy as np\n", 131 | "\n", 132 | "# Data\n", 133 | "methods = ['Full-rank\\nTraining', 'ReLoRA', 'LoRA']\n", 134 | "projections = ['$W_Q$', '$W_K$', '$W_V$', '$W_{up}$', 
'$W_{down}$']\n", 135 | "counts = [1173, 5693, 7680, 1203, 5723, 7680, 2641, 6050, 7681, 0, 4884, 7680, 0, 4942, 7680]\n", 136 | "# divide by 12\n", 137 | "counts = [c // 12 for c in counts]\n", 138 | "\n", 139 | "# font 15\n", 140 | "plt.rc('font', size=13)\n", 141 | "\n", 142 | "# Custom color palette\n", 143 | "# colors = ['#E64B35', '#4DBBD5', '#00A087', '#3C5488', '#F39B7F']\n", 144 | "\n", 145 | "# Plotting\n", 146 | "fig, ax = plt.subplots(figsize=(8, 2.5), dpi=300)\n", 147 | "\n", 148 | "width = 0.2\n", 149 | "x = np.arange(len(projections))\n", 150 | "\n", 151 | "for i, method in enumerate(methods):\n", 152 | " alpha = {\"ReLoRA\": 0.9, \"Full-rank\\nTraining\": 0.5, \"LoRA\": 0.5}\n", 153 | " bars = ax.bar(x + i * width, counts[i::3], width, label=method, alpha=alpha[method])\n", 154 | "\n", 155 | "# Customize the plot\n", 156 | "ax.set_xticks(x + width)\n", 157 | "ax.set_xticklabels(projections)\n", 158 | "# ax.set_ylabel('# Singular Values < 0.1', fontsize=15)\n", 159 | "# ax.set_xlabel('Projections')\n", 160 | "# move legend to the right\n", 161 | "ax.legend(loc='center right', bbox_to_anchor=(1.3, 0.5))\n", 162 | "\n", 163 | "# ax.spines['top'].set_visible(False)\n", 164 | "# ax.spines['bottom'].set_visible(False)\n", 165 | "# ax.spines['left'].set_visible(False)\n", 166 | "# ax.spines['right'].set_visible(False)\n", 167 | "# ax.grid(axis='y', linestyle='--')\n", 168 | "\n", 169 | "# plt.subplots_adjust(left=0.15, right=10) # Adjust the left and right spacing here\n", 170 | "\n", 171 | "# Adjust layout\n", 172 | "plt.tight_layout()\n", 173 | "\n", 174 | "# Save or display the plot\n", 175 | "plt.savefig(\"zero_signuar_values.pdf\")\n", 176 | "plt.show()" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "import matplotlib.pyplot as plt\n", 186 | "import numpy as np\n", 187 | "\n", 188 | "# Data\n", 189 | "methods = ['Full-rank\\nTraining', 'ReLoRA', 'LoRA']\n", 190 | "projections = ['$W_Q$', '$W_K$', '$W_V$', '$W_{up}$', '$W_{down}$']\n", 191 | "counts = [1173, 5693, 7680, 1203, 5723, 7680, 2641, 6050, 7681, 0, 4884, 7680, 0, 4942, 7680]\n", 192 | "# divide by 12\n", 193 | "counts = [c // 12 for c in counts]\n", 194 | "\n", 195 | "# font 15\n", 196 | "plt.rc('font', size=13)\n", 197 | "\n", 198 | "# Custom color palette\n", 199 | "# colors = ['#E64B35', '#4DBBD5', '#00A087', '#3C5488', '#F39B7F']\n", 200 | "\n", 201 | "# Plotting\n", 202 | "fig, ax = plt.subplots(figsize=(8, 2.5), dpi=300)\n", 203 | "\n", 204 | "width = 0.2\n", 205 | "x = np.arange(len(projections))\n", 206 | "\n", 207 | "for i, method in enumerate(methods):\n", 208 | " alpha = {\"ReLoRA\": 0.9, \"Full-rank\\nTraining\": 0.5, \"LoRA\": 0.5}\n", 209 | " bars = ax.bar(x + i * width, counts[i::3], width, label=method, alpha=alpha[method])\n", 210 | "\n", 211 | "# Customize the plot\n", 212 | "ax.set_xticks(x + width)\n", 213 | "ax.set_xticklabels(projections)\n", 214 | "\n", 215 | "# Remove the zero tick from the y axis (only zero)\n", 216 | "# ax.set_yticks(ax.get_yticks()[1:])\n", 217 | "\n", 218 | "# ax.set_ylabel('# Singular Values < 0.1', fontsize=15)\n", 219 | "# ax.set_xlabel('Projections')\n", 220 | "# move legend to the right\n", 221 | "ax.legend(loc='center right', bbox_to_anchor=(1.3, 0.5))\n", 222 | "\n", 223 | "# add seaborn-like grid\n", 224 | "ax.grid(axis='y', linestyle='--')\n", 225 | "\n", 226 | "# plt.subplots_adjust(left=0.15, right=10) # Adjust the left and right spacing here\n", 227 | "\n", 
228 | "# Adjust layout\n", 229 | "plt.tight_layout()\n", 230 | "\n", 231 | "# Save or display the plot\n", 232 | "plt.savefig(\"zero_signuar_values.pdf\")\n", 233 | "plt.show()\n" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "import matplotlib.pyplot as plt\n", 243 | "import numpy as np\n", 244 | "\n", 245 | "methods = ['ReLoRA', 'Full-rank Training', 'LoRA']\n", 246 | "projections = ['Q', 'K', 'V', 'Up', 'Down']\n", 247 | "\n", 248 | "counts = [\n", 249 | " 5693, 1173, 7680, # Q Projections\n", 250 | " 5723, 1203, 7680, # K Projections\n", 251 | " 6050, 2641, 7681, # V Projections\n", 252 | " 4884, 0, 7680, # Up Projections\n", 253 | " 4942, 0, 7680 # Down Projections\n", 254 | "]\n", 255 | "\n", 256 | "alpha = {'ReLoRA': 0.9, 'Full-rank Training': 0.5, 'LoRA': 0.5}\n", 257 | "\n", 258 | "width = 0.2 # the width of the bars\n", 259 | "x = np.arange(len(projections)) # the label locations\n", 260 | "\n", 261 | "fig, ax = plt.subplots()\n", 262 | "\n", 263 | "for i, method in enumerate(methods):\n", 264 | " ax.bar(x + i * width, counts[i::3], width, label=method, alpha=alpha[method])\n", 265 | "\n", 266 | "# Add some text for labels, title and custom x-axis tick labels, etc.\n", 267 | "ax.set_xlabel('Projections')\n", 268 | "ax.set_ylabel('Number of singular values < 0.1')\n", 269 | "ax.set_title('Comparison of different methods for various projections')\n", 270 | "ax.set_xticks(x)\n", 271 | "ax.set_xticklabels(projections)\n", 272 | "ax.legend()\n", 273 | "\n", 274 | "fig.tight_layout()\n", 275 | "plt.show()\n" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [] 284 | } 285 | ], 286 | "metadata": { 287 | "kernelspec": { 288 | "display_name": "base", 289 | "language": "python", 290 | "name": "python3" 291 | }, 292 | "language_info": { 293 | "codemirror_mode": { 294 | "name": "ipython", 295 | "version": 3 296 | }, 297 | "file_extension": ".py", 298 | "mimetype": "text/x-python", 299 | "name": "python", 300 | "nbconvert_exporter": "python", 301 | "pygments_lexer": "ipython3", 302 | "version": "3.10.9" 303 | }, 304 | "orig_nbformat": 4 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 2 308 | } 309 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/peft_pretraining/relora.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | from typing import List
5 | from dataclasses import dataclass
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import bitsandbytes as bnb
11 | import bitsandbytes.functional as bnbF
12 | 
13 | from transformers import AutoModelForCausalLM, AutoConfig
14 | 
15 | from loguru import logger
16 | 
17 | 
18 | @dataclass
19 | class ReLoRaConfig:
20 |     r: int
21 |     lora_alpha: int
22 |     lora_dropout: float
23 |     target_modules: List[str]
24 |     keep_original_weights: bool
25 |     lora_only: bool = False
26 |     trainable_scaling: bool = False
27 |     quantize: str = None
28 |     use_double_quant: bool = False
29 | 
30 | 
31 | def merge_and_reinit_functional(module):
32 |     if not isinstance(module, ReLoRaLinear):
33 |         return
34 | 
35 |     if module.quantize is not None:
36 |         # Look below at the merge_and_reinit method for inspiration on how to implement this
37 |         raise NotImplementedError("merge_and_reinit_functional for quantized models is not implemented yet. Use the non-functional implementation")
38 | 
39 |     _delta = module.lora_B.weight @ module.lora_A.weight
40 |     _delta = _delta * module._post_lora_scale()
41 |     module.weight.data += _delta
42 |     nn.init.kaiming_uniform_(module.lora_A.weight, a=math.sqrt(5))
43 | 
44 |     nn.init.zeros_(module.lora_B.weight)
45 |     if module.trainable_scaling:
46 |         nn.init.zeros_(module.scaling)
47 | 
48 | 
49 | class ReLoRaModel(torch.nn.Module):
50 |     def __init__(
51 |         self,
52 |         model,
53 |         *,
54 |         target_modules,
55 |         r=128,
56 |         lora_alpha=32,
57 |         lora_dropout=0.1,
58 |         keep_original_weights=True,
59 |         lora_only=False,
60 |         trainable_scaling=False,
61 |         quantize=None,
62 |         use_double_quant=False,
63 |     ):
64 |         if r <= 0:
65 |             raise ValueError("r must be positive. If you want r == 0, use the original model.")
66 | 
67 |         super().__init__()
68 |         self.wrapped_model: nn.Module = model
69 |         self.r = r
70 |         self.lora_alpha = lora_alpha
71 |         self.lora_dropout = lora_dropout
72 |         self.target_modules = target_modules
73 |         self.keep_original_weights = keep_original_weights
74 |         self.lora_only = lora_only
75 |         self.trainable_scaling = trainable_scaling
76 | 
77 |         # record every field (including lora_only and trainable_scaling) so that
78 |         # save_pretrained / from_pretrained round-trip the full configuration
79 |         self._config = ReLoRaConfig(
80 |             r=r,
81 |             lora_alpha=lora_alpha,
82 |             lora_dropout=lora_dropout,
83 |             target_modules=target_modules,
84 |             keep_original_weights=keep_original_weights,
85 |             lora_only=lora_only,
86 |             trainable_scaling=trainable_scaling,
87 |             quantize=quantize,
88 |             use_double_quant=use_double_quant,
89 |         )
90 | 
91 |         # patch methods
92 |         self.forward = self.wrapped_model.forward
93 | 
94 |         target_modules_list = target_modules
95 |         if isinstance(target_modules_list, str):
96 |             target_modules_list = [target_modules_list]
97 | 
98 |         for module_name, module in self.wrapped_model.named_modules():
99 |             if not isinstance(module, nn.Linear):
100 |                 continue
101 | 
102 |             if not any(target_key in module_name for target_key in target_modules_list):
103 |                 continue
104 | 
105 |             weight_data = module.weight.data if keep_original_weights else None
106 |             bias_data = None
107 |             if module.bias is not None:
108 |                 bias_data = module.bias.data if keep_original_weights else None
109 | 
110 |             new_module = ReLoRaLinear(
111 |                 module.in_features,
112 |                 module.out_features,
113 |                 bias=module.bias is not None,
114 |                 r=self.r,
115 |                 lora_alpha=self.lora_alpha,
116 |                 lora_dropout=self.lora_dropout,
117 |                 lora_only=self.lora_only,
118 |                 trainable_scaling=self.trainable_scaling,
119 |                 quantize=quantize,
120 |                 weight_data=weight_data,
121 |                 bias_data=bias_data,
122 |                 bnb_4bit_use_double_quant=use_double_quant,
123 |             )
124 |             if self.keep_original_weights:
125 |                 # make the lora'ed network exactly the same as the original network at initialization
126 |                 nn.init.zeros_(new_module.lora_A.weight)
127 |                 assert new_module.lora_A.bias is None
128 |                 assert new_module.lora_B.bias is None
129 | 
130 |             if self.lora_only:
131 |                 assert not self.keep_original_weights
132 |                 module.weight = None
133 | 
134 |             del module
135 | 
136 |             parent = self._get_parent(module_name)
137 |             module_suffix = module_name.split(".")[-1]
138 |             setattr(parent, module_suffix, new_module)
139 | 
140 |         torch.cuda.empty_cache()
141 | 
142 |     def _get_parent(self, module_name):
143 |         module_names_list = module_name.split(".")
144 |         parent_name = ".".join(module_names_list[:-1])
145 |         parent = self.wrapped_model.get_submodule(parent_name)
146 |         return parent
147 | 
148 |     def merge_and_reinit(self):
149 |         for module in self.modules():
150 |             if isinstance(module, ReLoRaLinear):
151 |                 module.merge_and_reinit()
152 | 
153 |     def save_pretrained(self, path):
154 |         self.wrapped_model.save_pretrained(path)
155 |         with open(os.path.join(path, "relora_config.json"), "w") as f:
156 |             json.dump(self._config.__dict__, f, indent=4)
157 | 
158 |     @classmethod
159 |     def from_pretrained(cls, path):
160 |         with open(os.path.join(path, "relora_config.json"), "r") as f:
161 |             relora_config = json.load(f)
162 | 
163 |         config = AutoConfig.from_pretrained(path)
164 | 
165 |         base_model = AutoModelForCausalLM.from_config(config)
166 |         if "keep_original" in relora_config:
167 |             print("WARNING: keep_original is deprecated. Use lora_only instead.")
168 |             print(f"keep_original: {relora_config['keep_original']}")
169 |             relora_config["lora_only"] = not relora_config.pop("keep_original")
170 |             relora_config["keep_original_weights"] = not relora_config["lora_only"]
171 | 
172 |         if "trainable_scaling" not in relora_config:
173 |             relora_config["trainable_scaling"] = False
174 | 
175 |         model = cls(base_model, **relora_config)
176 | 
177 |         with open(os.path.join(path, "pytorch_model.bin"), "rb") as f:
178 |             state_dict = torch.load(f, map_location="cpu")
179 | 
180 |         model.wrapped_model.load_state_dict(state_dict, strict=True)
181 |         return model
182 | 
183 | 
184 | # The code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
185 | class ReLoRaLinear(nn.Module):
186 |     def __init__(
187 |         self,
188 |         in_features: int,
189 |         out_features: int,
190 |         r: int,
191 |         *,
192 |         lora_alpha: int = 1,
193 |         lora_dropout: float = 0.1,
194 |         lora_only: bool = False,
195 |         weight_data=None,
196 |         bias_data=None,
197 |         trainable_scaling: bool = False,
198 |         bias=True,
199 |         device=None,
200 |         dtype=None,
201 |         quantize=None,
202 |         bnb_4bit_use_double_quant=False,
203 |         bnb_4bit_quant_type="nf4",
204 |     ):
205 |         """Wraps linear layer x W into x W + x W_a @ W_b * lora_alpha / r
206 | 
207 |         Notice that scale = lora_alpha / r.
208 |         """
209 |         nn.Module.__init__(self)
210 |         if r <= 0:
211 |             raise ValueError("r must be positive. If you want r == 0, use the original model.")
212 | 
213 |         if lora_only:
214 |             self.weight = None
215 |             self.bias = None
216 |         else:
217 |             # if full model weight + lora weight
218 |             if bias_data is None:
219 |                 bias_data = torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True) if bias else None
220 |             self.bias = nn.Parameter(bias_data) if bias else None
221 | 
222 |             if weight_data is None:
223 |                 # note that our trainable weights are W_a and W_b
224 |                 weight_data = torch.zeros(out_features, in_features, device=device, dtype=dtype, requires_grad=False)
225 | 
226 |             if quantize is None:
227 |                 self.weight = nn.Parameter(weight_data, requires_grad=False)
228 |             elif quantize == "4bit":
229 |                 self.weight = bnb.nn.Params4bit(
230 |                     weight_data,
231 |                     requires_grad=False,
232 |                     compress_statistics=bnb_4bit_use_double_quant,
233 |                     quant_type=bnb_4bit_quant_type,
234 |                 )
235 |             elif quantize == "8bit":
236 |                 logger.warning("Int8 currently does not support merge_and_reinit! It will fail")
237 |                 self.weight = bnb.nn.Int8Params(
238 |                     weight_data,
239 |                     requires_grad=False,
240 |                 )
241 |             else:
242 |                 raise ValueError(f"Unknown quantize type: {quantize}")
243 | 
244 |         self.in_features = in_features
245 |         self.out_features = out_features
246 |         self.r = r
247 |         self.lora_alpha = lora_alpha
248 |         self.lora_dropout = nn.Dropout(p=lora_dropout)
249 |         self.lora_only = lora_only
250 |         self.trainable_scaling = trainable_scaling
251 |         self.quantize = quantize
252 | 
253 |         if r > 0:
254 |             self.lora_A = nn.Linear(in_features, r, bias=False)
255 |             nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
256 |             self.lora_B = nn.Linear(r, out_features, bias=False)
257 |             nn.init.zeros_(self.lora_B.weight)
258 |             if trainable_scaling:
259 |                 self.scaling = nn.Parameter(torch.tensor([1.]), requires_grad=True)
260 |             else:
261 |                 self.scaling = self.lora_alpha / self.r
262 | 
263 |         # Freezing the pre-trained weight matrix
264 |         if not self.lora_only:
265 |             self.weight.requires_grad = False
266 | 
267 |     def _post_lora_scale(self):
268 |         if self.trainable_scaling:
269 |             return self.scaling.tanh()
270 | 
271 |         return self.scaling
272 | 
273 |     @torch.no_grad()
274 |     def merge_and_reinit(self):
275 |         if self.lora_only:
276 |             print("WARNING: Skipping merge and reinit, because only lora parameters are used")
277 |             return
278 | 
279 |         if not self.quantize:
280 |             self.weight.data += self.lora_B.weight @ self.lora_A.weight * self._post_lora_scale()
281 |         elif self.quantize == "4bit":
282 |             self.weight: bnb.nn.Params4bit
283 |             _weight_fp = torch.empty(self.weight.data.shape, dtype=self.lora_B.weight.dtype, device=self.weight.data.device)
284 |             bnbF.dequantize_4bit(self.weight.data, self.weight.quant_state, out=_weight_fp)
285 |             _weight_fp += self.lora_B.weight @ self.lora_A.weight * self._post_lora_scale()
286 |             self.weight.data, self.weight.quant_state = bnbF.quantize_4bit(
287 |                 _weight_fp,
288 |                 quant_type=self.weight.quant_type,
289 |                 compress_statistics=self.weight.compress_statistics,
290 |             )
291 |             del _weight_fp
292 |         elif self.quantize == "8bit":
293 |             self.weight: bnb.nn.Int8Params
294 |             _weight_fp = torch.empty(self.weight.data.shape, dtype=torch.bfloat16, device=self.weight.data.device)
295 |             # out is assigned in-place; note that the int8 path is flagged as unsupported in __init__
296 |             bnbF.dequantize_blockwise(self.weight.data, self.weight.quant_state, out=_weight_fp)
297 |             _weight_fp += self.lora_B.weight @ self.lora_A.weight * self._post_lora_scale()
298 |             self.weight.data, self.weight.quant_state = bnbF.quantize_blockwise(
299 |                 _weight_fp,
300 |                 self.weight.quant_state,
301 |                 out=self.weight.data,
302 |             )
303 |             del _weight_fp
304 |         else:
305 |             raise ValueError(f"Unknown quantize type: {self.quantize}")
306 | 
307 |         nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
308 | 
309 |         nn.init.zeros_(self.lora_B.weight)
310 |         if self.trainable_scaling:
311 |             nn.init.zeros_(self.scaling)
312 | 
313 |     def forward(self, x: torch.Tensor):
314 |         if self.lora_only:
315 |             # just lora
316 |             return self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()
317 | 
318 |         if self.quantize == "4bit":
319 |             result = bnb.matmul_4bit(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
320 |         elif self.quantize == "8bit":
321 |             result = bnb.matmul(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
322 |         else:
323 |             result = F.linear(x, self.weight, bias=self.bias)
324 | 
325 |         if self.r > 0:
326 |             result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()
327 |         return result
328 | 
-------------------------------------------------------------------------------- /peft_pretraining/megatron_dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, EleutherAI 2 | # This file is based on code by the authors denoted below and has been modified from its original version. 3 | # 4 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # Based on: https://github.com/EleutherAI/gpt-neox/blob/408e29d9c746a02d842917bb7447c5c4be0b42d4/megatron/data/gpt2_dataset.py 18 | 19 | """GPT2 style dataset.""" 20 | 21 | import os 22 | import time 23 | 24 | import torch 25 | import torch.distributed as dist 26 | import numpy as np 27 | from loguru import logger 28 | 29 | 30 | 31 | 32 | class GPT2Dataset(torch.utils.data.Dataset): 33 | def __init__( 34 | self, 35 | name, 36 | data_prefix, 37 | documents, 38 | indexed_dataset, 39 | num_samples, 40 | seq_length, 41 | seed, 42 | build_index_mappings=True, 43 | use_shared_fs=True, 44 | label_dataset=None, 45 | ): 46 | 47 | self.name = name 48 | self.indexed_dataset = indexed_dataset 49 | self.label_dataset = label_dataset 50 | 51 | # Checks 52 | assert np.min(documents) >= 0 53 | assert np.max(documents) < indexed_dataset.sizes.shape[0] 54 | 55 | if build_index_mappings: 56 | # Build index mappings. 57 | self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( 58 | self.name, 59 | data_prefix, 60 | documents, 61 | self.indexed_dataset.sizes, 62 | num_samples, 63 | seq_length, 64 | seed, 65 | use_shared_fs=use_shared_fs, 66 | ) 67 | self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 68 | self.sample_idx_len = self.sample_idx.shape[0] - 1 69 | 70 | if self.shuffle_idx_len != self.sample_idx_len - 1: 71 | logger.warning( 72 | f"shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" 73 | ) 74 | 75 | def __len__(self): 76 | return min(self.shuffle_idx_len, self.sample_idx_len) 77 | 78 | def __getitem__(self, idx): 79 | try: 80 | return self.get_item_unsafe(idx) 81 | except IndexError: 82 | new_idx = idx % len(self) 83 | logger.warning( 84 | f"Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" 85 | ) 86 | return self[new_idx] 87 | 88 | def get_item_unsafe(self, idx): 89 | # Get the shuffled index. 90 | idx = self.shuffle_idx[idx] 91 | # Start and end documents and offsets. 92 | doc_index_f = self.sample_idx[idx][0] 93 | doc_index_l = self.sample_idx[idx + 1][0] 94 | offset_f = self.sample_idx[idx][1] 95 | offset_l = self.sample_idx[idx + 1][1] 96 | # Labels and texts are supposed to be fully in sync. 97 | datasets = [self.indexed_dataset] if self.label_dataset is None else [self.indexed_dataset, self.label_dataset] 98 | samples = [] 99 | # If we are within the same document, just extract the chunk. 
100 | for n, dataset in enumerate(datasets): 101 | if doc_index_f == doc_index_l: 102 | samples.append(dataset.get( 103 | self.doc_idx[doc_index_f], 104 | offset=offset_f, 105 | length=offset_l - offset_f + 1, 106 | )) 107 | else: 108 | # Otherwise, get the rest of the initial document. 109 | sample_list = [ 110 | dataset.get(self.doc_idx[doc_index_f], offset=offset_f) 111 | ] 112 | # Loop over all in between documents and add the entire document. 113 | for i in range(doc_index_f + 1, doc_index_l): 114 | sample_list.append(dataset.get(self.doc_idx[i])) 115 | # And finally add the relevant portion of last document. 116 | sample_list.append( 117 | dataset.get( 118 | self.doc_idx[doc_index_l], length=offset_l + 1 119 | ) 120 | ) 121 | samples.append(np.concatenate(sample_list)) 122 | 123 | if len(datasets) == 1: 124 | return {"input_ids": np.array(samples[0], dtype=np.int64)} 125 | 126 | return {"input_ids": np.array(samples[0], dtype=np.int64), "label": np.array(samples[1], dtype=np.int64)} 127 | 128 | 129 | def _build_index_mappings( 130 | name, 131 | data_prefix, 132 | documents, 133 | sizes, 134 | num_samples, 135 | seq_length, 136 | seed, 137 | use_shared_fs=True, 138 | ): 139 | """Build doc-idx, sample-idx, and shuffle-idx. 140 | doc-idx: is an array (ordered) of documents to be used in training. 141 | sample-idx: is the start document index and document offset for each 142 | training sample. 143 | shuffle-idx: maps the sample index into a random index into sample-idx. 144 | """ 145 | # Number of tokens in each epoch and number of required epochs. 146 | tokens_per_epoch = _num_tokens(documents, sizes) 147 | num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) 148 | # rng state 149 | np_rng = np.random.RandomState(seed=seed) 150 | 151 | # Filename of the index mappings. 152 | _filename = data_prefix 153 | _filename += "_{}_indexmap".format(name) 154 | _filename += "_{}ns".format(num_samples) 155 | _filename += "_{}sl".format(seq_length) 156 | _filename += "_{}s".format(seed) 157 | doc_idx_filename = _filename + "_doc_idx.npy" 158 | sample_idx_filename = _filename + "_sample_idx.npy" 159 | shuffle_idx_filename = _filename + "_shuffle_idx.npy" 160 | 161 | if not use_shared_fs: 162 | should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0 163 | elif dist.is_initialized(): 164 | should_process_dataset = dist.get_rank() == 0 165 | else: 166 | should_process_dataset = True 167 | 168 | # Build the indexed mapping if not exist. 169 | if should_process_dataset: 170 | if ( 171 | (not os.path.isfile(doc_idx_filename)) 172 | or (not os.path.isfile(sample_idx_filename)) 173 | or (not os.path.isfile(shuffle_idx_filename)) 174 | ): 175 | logger.warning( 176 | " > WARNING: could not find index map files, building " 177 | "the indices on rank 0 ..." 178 | ) 179 | # doc-idx. 180 | start_time = time.time() 181 | doc_idx = _build_doc_idx(documents, num_epochs, np_rng) 182 | np.save(doc_idx_filename, doc_idx, allow_pickle=True) 183 | logger.info( 184 | " > elapsed time to build and save doc-idx mapping " 185 | "(seconds): {:4f}".format(time.time() - start_time) 186 | ) 187 | # sample-idx. 188 | start_time = time.time() 189 | # Use C++ implementation for speed. 
190 |             from peft_pretraining.megatron_dataset import helpers
191 | 
192 |             assert doc_idx.dtype == np.int32
193 |             assert sizes.dtype == np.int32
194 | 
195 |             num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
196 |             if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
197 |                 sample_idx = helpers.build_sample_idx_int32(
198 |                     sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
199 |                 )
200 |             else:
201 |                 sample_idx = helpers.build_sample_idx_int64(
202 |                     sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
203 |                 )
204 |             np.save(sample_idx_filename, sample_idx, allow_pickle=True)
205 |             logger.info(
206 |                 " > elapsed time to build and save sample-idx mapping "
207 |                 "(seconds): {:4f}".format(time.time() - start_time)
208 |             )
209 |             # shuffle-idx.
210 |             start_time = time.time()
211 |             # -1 is due to data structure used to retrieve the index:
212 |             #    sample i --> [sample_idx[i], sample_idx[i+1])
213 |             shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
214 |             np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
215 |             logger.info(
216 |                 " > elapsed time to build and save shuffle-idx mapping"
217 |                 " (seconds): {:4f}".format(time.time() - start_time)
218 |             )
219 | 
220 |     # This should be a barrier but nccl barrier assumes
221 |     # device_index=rank which is not the case for model
222 |     # parallel case
223 |     counts = torch.cuda.LongTensor([1])
224 |     if dist.is_initialized():
225 |         dist.all_reduce(counts)
226 | 
227 |     # Load mappings.
228 |     start_time = time.time()
229 |     logger.info(" > loading doc-idx mapping from {}".format(doc_idx_filename))
230 |     doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r")
231 |     logger.info(" > loading sample-idx mapping from {}".format(sample_idx_filename))
232 |     sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r")
233 |     logger.info(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename))
234 |     shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r")
235 |     logger.info(
236 |         "    loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
237 |     )
238 |     logger.info("    total number of samples: {}".format(sample_idx.shape[0]))
239 |     logger.info("    total number of epochs: {}".format(num_epochs))
240 | 
241 |     return doc_idx, sample_idx, shuffle_idx
242 | 
243 | 
244 | def _num_tokens(documents, sizes):
245 |     """Total number of tokens in the dataset."""
246 |     return np.sum(sizes[documents])
247 | 
248 | 
249 | def _num_epochs(tokens_per_epoch, seq_length, num_samples):
250 |     """Based on number of samples and sequence length, calculate how many
251 |     epochs will be needed."""
252 |     num_epochs = 0
253 |     total_tokens = 0
254 |     while True:
255 |         num_epochs += 1
256 |         total_tokens += tokens_per_epoch
257 |         # -1 is because we need to retrieve seq_length + 1 token each time
258 |         # but the last token will overlap with the first token of the next
259 |         # sample except for the last sample.
260 |         if ((total_tokens - 1) // seq_length) >= num_samples:
261 |             return num_epochs
262 | 
263 | 
264 | def _build_doc_idx(documents, num_epochs, np_rng):
265 |     """Build an array with length = number-of-epochs * number-of-documents.
266 | Each index is mapped to a corresponding document.""" 267 | doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] 268 | doc_idx[:] = documents 269 | doc_idx = doc_idx.reshape(-1) 270 | doc_idx = doc_idx.astype(np.int32) 271 | np_rng.shuffle(doc_idx) 272 | return doc_idx 273 | 274 | 275 | def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): 276 | """Sample index mapping is a 2D array with sizes 277 | [number-of-samples + 1, 2] where [..., 0] contains 278 | the index into `doc_idx` and [..., 1] is the 279 | starting offset in that document.""" 280 | 281 | # Total number of samples. For -1 see comments in `_num_epochs`. 282 | num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length 283 | sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64) 284 | 285 | # Index into sample_idx. 286 | sample_index = 0 287 | # Index into doc_idx. 288 | doc_idx_index = 0 289 | # Beginning offset for each document. 290 | doc_offset = 0 291 | # Start with first document and no offset. 292 | sample_idx[sample_index][0] = doc_idx_index 293 | sample_idx[sample_index][1] = doc_offset 294 | sample_index += 1 295 | while sample_index <= num_samples: 296 | # Start with a fresh sequence. 297 | remaining_seq_length = seq_length + 1 298 | while remaining_seq_length != 0: 299 | # Get the document length. 300 | doc_id = doc_idx[doc_idx_index] 301 | doc_length = sizes[doc_id] - doc_offset 302 | # And add it to the current sequence. 303 | remaining_seq_length -= doc_length 304 | # If we have more than a full sequence, adjust offset and set 305 | # remaining length to zero so we return from the while loop. 306 | # Note that -1 here is for the same reason we have -1 in 307 | # `_num_epochs` calculations. 308 | if remaining_seq_length <= 0: 309 | doc_offset += remaining_seq_length + doc_length - 1 310 | remaining_seq_length = 0 311 | else: 312 | # Otherwise, start from the beginning of the next document. 313 | doc_idx_index += 1 314 | doc_offset = 0 315 | # Record the sequence. 
316 | sample_idx[sample_index][0] = doc_idx_index 317 | sample_idx[sample_index][1] = doc_offset 318 | sample_index += 1 319 | 320 | return sample_idx 321 | 322 | 323 | def _build_shuffle_idx(size, np_rng): 324 | """Build the range [0, size) and shuffle.""" 325 | dtype_ = np.uint32 326 | if size >= (np.iinfo(np.uint32).max - 1): 327 | dtype_ = np.int64 328 | shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) 329 | np_rng.shuffle(shuffle_idx) 330 | return shuffle_idx 331 | -------------------------------------------------------------------------------- /notebooks/05_check_ranks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('../')\n", 11 | "\n", 12 | "from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM\n", 13 | "from peft_pretraining.relora import ReLoRaModel" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "!ls ../checkpoints/llama_130m-2023-05-09-19-17-11" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# llama_130m-2023-05-09-19-18-46 is major-moon-154\n", 32 | "regular_250M = AutoModelForCausalLM.from_pretrained(\"../checkpoints/llama_60m-2023-05-13-17-13-02/model_10000\")\n", 33 | "\n", 34 | "peft_250M = ReLoRaModel.from_pretrained(\"../checkpoints/llama_130m-2023-05-09-19-17-11/model_10000\")\n", 35 | "\n", 36 | "# llama_7b = AutoModelForCausalLM.from_pretrained(\"huggyllama/llama-7b\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "import torch\n", 46 | "from tqdm import tqdm\n", 47 | "\n", 48 | "# get singular values of all layers\n", 49 | "q_projs = []\n", 50 | "k_projs = []\n", 51 | "v_projs = []\n", 52 | "o_projs = []\n", 53 | "gate_projs = []\n", 54 | "down_projs = []\n", 55 | "up_projs = []\n", 56 | "\n", 57 | "for layer in tqdm(regular_250M.model.layers):\n", 58 | " q_projs_weight = layer.self_attn.q_proj.weight.detach()\n", 59 | " singular_values = torch.svd(q_projs_weight).S\n", 60 | " q_projs.append(singular_values)\n", 61 | "\n", 62 | " k_projs_weight = layer.self_attn.k_proj.weight.detach()\n", 63 | " singular_values = torch.svd(k_projs_weight).S\n", 64 | " k_projs.append(singular_values)\n", 65 | "\n", 66 | " v_projs_weight = layer.self_attn.v_proj.weight.detach()\n", 67 | " singular_values = torch.svd(v_projs_weight).S\n", 68 | " v_projs.append(singular_values)\n", 69 | "\n", 70 | " o_projs_weight = layer.self_attn.o_proj.weight.detach()\n", 71 | " singular_values = torch.svd(o_projs_weight).S\n", 72 | " o_projs.append(singular_values)\n", 73 | "\n", 74 | " gate_projs_weight = layer.mlp.gate_proj.weight.detach()\n", 75 | " singular_values = torch.svd(gate_projs_weight).S\n", 76 | " gate_projs.append(singular_values)\n", 77 | "\n", 78 | " down_projs_weight = layer.mlp.down_proj.weight.detach()\n", 79 | " singular_values = torch.svd(down_projs_weight).S\n", 80 | " down_projs.append(singular_values)\n", 81 | "\n", 82 | " up_projs_weight = layer.mlp.up_proj.weight.detach()\n", 83 | " singular_values = torch.svd(up_projs_weight).S\n", 84 | " up_projs.append(singular_values)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | 
"execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "import torch\n", 94 | "from tqdm import tqdm\n", 95 | "\n", 96 | "def get_linear_weight_from_relora(relora_layer):\n", 97 | " return relora_layer.lora_B.weight @ relora_layer.lora_A.weight * relora_layer.scaling\n", 98 | "\n", 99 | "# get singular values of all layers\n", 100 | "peft_q_projs = []\n", 101 | "peft_k_projs = []\n", 102 | "peft_v_projs = []\n", 103 | "peft_o_projs = []\n", 104 | "peft_gate_projs = []\n", 105 | "peft_down_projs = []\n", 106 | "peft_up_projs = []\n", 107 | "\n", 108 | "for layer in tqdm(peft_250M.wrapped_model.model.layers):\n", 109 | " q_projs_weight = get_linear_weight_from_relora(layer.self_attn.q_proj).detach()\n", 110 | " singular_values = torch.svd(q_projs_weight).S\n", 111 | " peft_q_projs.append(singular_values)\n", 112 | "\n", 113 | " k_projs_weight = get_linear_weight_from_relora(layer.self_attn.k_proj).detach()\n", 114 | " singular_values = torch.svd(k_projs_weight).S\n", 115 | " peft_k_projs.append(singular_values)\n", 116 | "\n", 117 | " v_projs_weight = get_linear_weight_from_relora(layer.self_attn.v_proj).detach()\n", 118 | " singular_values = torch.svd(v_projs_weight).S\n", 119 | " peft_v_projs.append(singular_values)\n", 120 | "\n", 121 | " o_projs_weight = get_linear_weight_from_relora(layer.self_attn.o_proj).detach()\n", 122 | " singular_values = torch.svd(o_projs_weight).S\n", 123 | " peft_o_projs.append(singular_values)\n", 124 | "\n", 125 | " gate_projs_weight = get_linear_weight_from_relora(layer.mlp.gate_proj).detach()\n", 126 | " singular_values = torch.svd(gate_projs_weight).S\n", 127 | " peft_gate_projs.append(singular_values)\n", 128 | "\n", 129 | " down_projs_weight = get_linear_weight_from_relora(layer.mlp.down_proj).detach()\n", 130 | " singular_values = torch.svd(down_projs_weight).S\n", 131 | " peft_down_projs.append(singular_values)\n", 132 | "\n", 133 | " up_projs_weight = get_linear_weight_from_relora(layer.mlp.up_proj).detach()\n", 134 | " singular_values = torch.svd(up_projs_weight).S\n", 135 | " peft_up_projs.append(singular_values)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "import torch\n", 145 | "from tqdm import tqdm\n", 146 | "\n", 147 | "# get singular values of all layers\n", 148 | "# noised_peft_q_projs = []\n", 149 | "noised_peft_q_projs = []\n", 150 | "noised_peft_down_projs = []\n", 151 | "\n", 152 | "for layer in tqdm(peft_250M.wrapped_model.model.layers):\n", 153 | " q_projs_weight = get_linear_weight_from_relora(layer.self_attn.q_proj).detach() + torch.randn_like(get_linear_weight_from_relora(layer.self_attn.q_proj).detach()) * 0.04\n", 154 | " singular_values = torch.svd(q_projs_weight).S\n", 155 | " noised_peft_q_projs.append(singular_values)\n", 156 | "\n", 157 | " # down proj\n", 158 | " down_projs_weight = get_linear_weight_from_relora(layer.mlp.down_proj).detach() + torch.randn_like(get_linear_weight_from_relora(layer.mlp.down_proj).detach()) * 0.04\n", 159 | " singular_values = torch.svd(down_projs_weight).S\n", 160 | " noised_peft_down_projs.append(singular_values)\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# perform magtinude pruning\n", 170 | "import numpy as np\n", 171 | "\n", 172 | "import torch\n", 173 | "from tqdm import tqdm\n", 174 | "\n", 175 | "pruned_peft_q_projs = []\n", 176 | "\n", 177 | 
"for layer in tqdm(peft_250M.wrapped_model.model.layers):\n", 178 | " q_projs_weight = get_linear_weight_from_relora(layer.self_attn.q_proj).detach()\n", 179 | "\n", 180 | " threshold_90p = np.percentile(q_projs_weight.abs().numpy(), 0.01)\n", 181 | " q_projs_weight = q_projs_weight * (q_projs_weight.abs() > threshold_90p)\n", 182 | " singular_values = torch.svd(q_projs_weight).S\n", 183 | " pruned_peft_q_projs.append(singular_values)\n", 184 | "\n", 185 | "# for regular\n", 186 | "\n", 187 | "pruned_q_projs = []\n", 188 | "\n", 189 | "for layer in tqdm(regular_250M.model.layers):\n", 190 | " q_projs_weight = layer.self_attn.q_proj.weight.detach()\n", 191 | "\n", 192 | " threshold_90p = np.percentile(q_projs_weight.abs(), 0.01)\n", 193 | " q_projs_weight = q_projs_weight * (q_projs_weight.abs() > threshold_90p)\n", 194 | " # q_projs_weight = torch.zeros_like(q_projs_weight)\n", 195 | " singular_values = torch.svd(q_projs_weight).S\n", 196 | " pruned_q_projs.append(singular_values)\n", 197 | "\n", 198 | "# random pruning of regular\n", 199 | "\n", 200 | "random_pruned_q_projs = []\n", 201 | "\n", 202 | "for layer in tqdm(regular_250M.model.layers):\n", 203 | " q_projs_weight = layer.self_attn.q_proj.weight.detach()\n", 204 | "\n", 205 | " threshold_90p = np.percentile(q_projs_weight.abs(), 0.01)\n", 206 | " q_projs_weight = q_projs_weight * (torch.rand_like(q_projs_weight) > threshold_90p)\n", 207 | " singular_values = torch.svd(q_projs_weight).S\n", 208 | " random_pruned_q_projs.append(singular_values)\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "# prune down projection\n", 218 | "\n", 219 | "pruned_down_projs = []\n", 220 | "\n", 221 | "for layer in tqdm(regular_250M.model.layers):\n", 222 | " down_projs_weight = layer.mlp.down_proj.weight.detach()\n", 223 | "\n", 224 | " threshold_90p = np.percentile(down_projs_weight.abs(), 0.01)\n", 225 | " down_projs_weight = down_projs_weight * (down_projs_weight.abs() > threshold_90p)\n", 226 | " singular_values = torch.svd(down_projs_weight).S\n", 227 | " pruned_down_projs.append(singular_values)\n" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "# plot histogram of singular values for q_projs over layers\n", 237 | "from matplotlib import pyplot as plt\n", 238 | "\n", 239 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=100)\n", 240 | "ax.set_title(\"Singular Values of Q Projections\")\n", 241 | "ax.set_xlabel(\"Singular Value\")\n", 242 | "ax.set_ylabel(\"Frequency\")\n", 243 | "# ax.hist(torch.cat(noised_peft_q_projs).numpy(), density=True, bins=100, alpha=0.3, label=\"Noised PEFT\")\n", 244 | "ax.hist(torch.cat(q_projs).numpy(), density=True, bins=100, alpha=0.8, label=\"Regular\")\n", 245 | "ax.hist(torch.cat(peft_q_projs).numpy(), density=True, bins=100, alpha=0.5, label=\"PEFT\")\n", 246 | "# ax.hist(torch.cat(pruned_peft_q_projs).numpy(), density=True, bins=100, alpha=0.3, label=\"Pruned PEFT\")\n", 247 | "ax.hist(torch.cat(pruned_q_projs).numpy(), density=True, bins=100, alpha=0.3, label=\"Pruned Regular\")\n", 248 | "# ax.hist(torch.cat(random_pruned_q_projs).numpy(), density=True, bins=100, alpha=0.3, label=\"Random Pruned Regular\")\n", 249 | "\n", 250 | "# ylim\n", 251 | "ax.set_ylim(0, 4)\n", 252 | "ax.set_xlim(0, 6)\n", 253 | "\n", 254 | "ax.legend()\n", 255 | "plt.show()" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 
260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "# plot histogram of singular values for k_projs over layers\n", 265 | "from matplotlib import pyplot as plt\n", 266 | "\n", 267 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=100)\n", 268 | "ax.set_title(\"Singular Values of K Projections\")\n", 269 | "ax.set_xlabel(\"Singular Value\")\n", 270 | "ax.set_ylabel(\"Frequency\")\n", 271 | "ax.hist(torch.cat(k_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n", 272 | "ax.hist(torch.cat(peft_k_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n", 273 | "ax.hist(torch.cat(llama7b_k_projs).numpy(), bins=100, alpha=0.5, label=\"LLAMA-7B\")\n", 274 | "plt.show()" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 284 | "ax.set_title(\"Singular Values of V Projections\")\n", 285 | "ax.set_xlabel(\"Singular Value\")\n", 286 | "ax.set_ylabel(\"Frequency\")\n", 287 | "ax.hist(torch.cat(v_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n", 288 | "ax.hist(torch.cat(peft_v_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n", 289 | "ax.hist(torch.cat(llama7b_v_projs).numpy(), bins=100, alpha=0.5, label=\"LLAMA-7B\")\n", 290 | "plt.show()" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 300 | "ax.set_title(\"Singular Values of O Projections\")\n", 301 | "ax.set_xlabel(\"Singular Value\")\n", 302 | "ax.set_ylabel(\"Frequency\")\n", 303 | "ax.hist(torch.cat(o_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n", 304 | "ax.hist(torch.cat(peft_o_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n", 305 | "plt.show()" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 315 | "ax.set_title(\"Singular Values of Up Projections\")\n", 316 | "ax.set_xlabel(\"Singular Value\")\n", 317 | "ax.set_ylabel(\"Frequency\")\n", 318 | "ax.hist(torch.cat(up_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n", 319 | "ax.hist(torch.cat(peft_up_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n", 320 | "plt.show()" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 330 | "ax.set_title(\"Singular Values of Down Projections\")\n", 331 | "ax.set_xlabel(\"Singular Value\")\n", 332 | "ax.set_ylabel(\"Frequency\")\n", 333 | "ax.hist(torch.cat(down_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n", 334 | "ax.hist(torch.cat(peft_down_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n", 335 | "plt.show()" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 345 | "ax.set_title(\"Singular Values of Gate Projections\")\n", 346 | "ax.set_xlabel(\"Singular Value\")\n", 347 | "ax.set_ylabel(\"Frequency\")\n", 348 | "ax.hist(torch.cat(gate_projs).numpy(), bins=100, alpha=0.5, label=\"Regular\")\n", 349 | "ax.hist(torch.cat(peft_gate_projs).numpy(), bins=100, alpha=0.5, label=\"PEFT\")\n", 350 | "plt.show()" 
351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [] 359 | } 360 | ], 361 | "metadata": { 362 | "kernelspec": { 363 | "display_name": "base", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.10.9" 378 | }, 379 | "orig_nbformat": 4 380 | }, 381 | "nbformat": 4, 382 | "nbformat_minor": 2 383 | } 384 | -------------------------------------------------------------------------------- /notebooks/16_quantized.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/mnt/shared_home/vlialin/miniconda3/envs/peft_pretraining_shala/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import torch\n", 19 | "import torch.nn as nn\n", 20 | "import torch.nn.functional as F\n", 21 | "\n", 22 | "import bitsandbytes as bnb\n", 23 | "import bitsandbytes.functional as bnbF\n", 24 | "\n", 25 | "from peft_pretraining.modeling_llama import LlamaForCausalLM\n", 26 | "from peft_pretraining.relora import ReLoRaModel\n", 27 | "\n", 28 | "from transformers import AutoTokenizer" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stderr", 38 | "output_type": "stream", 39 | "text": [ 40 | "/mnt/shared_home/vlialin/miniconda3/envs/peft_pretraining_shala/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", 41 | "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", 42 | "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", 43 | "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", 44 | "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", 45 | " warnings.warn(\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n", 51 | "orig_model = LlamaForCausalLM.from_pretrained(\"../checkpoints/llama_250m-2023-06-09-11-29-56_up_to_5K/model_5000\")#, load_in_8bit=True)#, load_in_4bit=True)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "input_ids = tokenizer(\"Why am I doing this?\", return_tensors=\"pt\").input_ids\n", 61 | "# orig_out = orig_model(input_ids=input_ids)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": 
[ 70 | "model = ReLoRaModel(\n", 71 | " orig_model,\n", 72 | " r=128,\n", 73 | " lora_alpha=32,\n", 74 | " lora_dropout=0.1,\n", 75 | " target_modules=[\"attn\", \"attention\", \"mlp\"],\n", 76 | " trainable_scaling=False,\n", 77 | " keep_original_weights=True,\n", 78 | " quantize4bit=True,\n", 79 | " use_double_quant=True,\n", 80 | ")\n", 81 | "model = model.to(dtype=torch.bfloat16, device=\"cuda\")" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "trainable_parameters = [p for p in model.parameters() if p.requires_grad]\n", 91 | "optimizer = torch.optim.Adam(trainable_parameters, lr=0.001)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 6, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "input_ids = input_ids.cuda()\n", 101 | "quantized_out = model(input_ids, labels=input_ids)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 7, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "loss = quantized_out.loss\n", 111 | "loss.backward()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 8, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "optimizer.step()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 13, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "weight = model.wrapped_model.model.layers[0].self_attn.q_proj.weight" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 14, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "weight_data_fp = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 15, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/plain": [ 149 | "tensor([[ 0.0106, 0.0106, 0.0656, ..., -0.0438, -0.0330, 0.0354],\n", 150 | " [ 0.0519, 0.0317, 0.0404, ..., -0.0113, 0.0270, -0.0056],\n", 151 | " [-0.0621, 0.0349, -0.0326, ..., 0.0363, 0.0104, 0.0218],\n", 152 | " ...,\n", 153 | " [ 0.0286, -0.0145, -0.0267, ..., 0.0047, 0.0199, -0.0309],\n", 154 | " [-0.0207, -0.0048, 0.0231, ..., 0.0368, 0.0368, -0.0186],\n", 155 | " [-0.0327, -0.0246, -0.0057, ..., -0.0520, 0.0293, 0.0000]],\n", 156 | " device='cuda:0', dtype=torch.float16)" 157 | ] 158 | }, 159 | "execution_count": 15, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "weight_data_fp" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 9, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "Parameter containing:\n", 177 | "Parameter(Params4bit([[153],\n", 178 | " [247],\n", 179 | " [114],\n", 180 | " ...,\n", 181 | " [198],\n", 182 | " [ 48],\n", 183 | " [215]], device='cuda:0', dtype=torch.uint8))" 184 | ] 185 | }, 186 | "execution_count": 9, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "model.wrapped_model.model.layers[0].self_attn.q_proj.weight" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 12, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "ReLoRaModel(\n", 204 | " (wrapped_model): LlamaForCausalLM(\n", 205 | " (model): LlamaModel(\n", 206 | " (embed_tokens): Embedding(32000, 768, padding_idx=31999)\n", 207 | " (layers): ModuleList(\n", 208 | " (0-23): 24 
x LlamaDecoderLayer(\n", 209 | " (self_attn): LlamaAttention(\n", 210 | " (q_proj): ReLoRaLinear(\n", 211 | " (lora_dropout): Dropout(p=0.1, inplace=False)\n", 212 | " (lora_A): Linear(in_features=768, out_features=128, bias=False)\n", 213 | " (lora_B): Linear(in_features=128, out_features=768, bias=False)\n", 214 | " )\n", 215 | " (k_proj): ReLoRaLinear(\n", 216 | " (lora_dropout): Dropout(p=0.1, inplace=False)\n", 217 | " (lora_A): Linear(in_features=768, out_features=128, bias=False)\n", 218 | " (lora_B): Linear(in_features=128, out_features=768, bias=False)\n", 219 | " )\n", 220 | " (v_proj): ReLoRaLinear(\n", 221 | " (lora_dropout): Dropout(p=0.1, inplace=False)\n", 222 | " (lora_A): Linear(in_features=768, out_features=128, bias=False)\n", 223 | " (lora_B): Linear(in_features=128, out_features=768, bias=False)\n", 224 | " )\n", 225 | " (o_proj): ReLoRaLinear(\n", 226 | " (lora_dropout): Dropout(p=0.1, inplace=False)\n", 227 | " (lora_A): Linear(in_features=768, out_features=128, bias=False)\n", 228 | " (lora_B): Linear(in_features=128, out_features=768, bias=False)\n", 229 | " )\n", 230 | " (rotary_emb): LlamaRotaryEmbedding()\n", 231 | " )\n", 232 | " (mlp): LlamaMLP(\n", 233 | " (gate_proj): ReLoRaLinear(\n", 234 | " (lora_dropout): Dropout(p=0.1, inplace=False)\n", 235 | " (lora_A): Linear(in_features=768, out_features=128, bias=False)\n", 236 | " (lora_B): Linear(in_features=128, out_features=2560, bias=False)\n", 237 | " )\n", 238 | " (down_proj): ReLoRaLinear(\n", 239 | " (lora_dropout): Dropout(p=0.1, inplace=False)\n", 240 | " (lora_A): Linear(in_features=2560, out_features=128, bias=False)\n", 241 | " (lora_B): Linear(in_features=128, out_features=768, bias=False)\n", 242 | " )\n", 243 | " (up_proj): ReLoRaLinear(\n", 244 | " (lora_dropout): Dropout(p=0.1, inplace=False)\n", 245 | " (lora_A): Linear(in_features=768, out_features=128, bias=False)\n", 246 | " (lora_B): Linear(in_features=128, out_features=2560, bias=False)\n", 247 | " )\n", 248 | " (act_fn): SiLUActivation()\n", 249 | " )\n", 250 | " (input_layernorm): LlamaRMSNorm()\n", 251 | " (post_attention_layernorm): LlamaRMSNorm()\n", 252 | " )\n", 253 | " )\n", 254 | " (norm): LlamaRMSNorm()\n", 255 | " )\n", 256 | " (lm_head): Linear(in_features=768, out_features=32000, bias=False)\n", 257 | " )\n", 258 | ")" 259 | ] 260 | }, 261 | "execution_count": 12, 262 | "metadata": {}, 263 | "output_type": "execute_result" 264 | } 265 | ], 266 | "source": [ 267 | "model" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 7, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "data": { 277 | "text/plain": [ 278 | "tensor([[[-4.1537, 1.7079, 0.5386, ..., -3.2834, -3.8020, -4.5960],\n", 279 | " [-3.6189, 3.0790, 1.9448, ..., -3.3880, -4.8104, -4.4347],\n", 280 | " [-3.0543, 2.4709, 0.8437, ..., -3.1641, -4.6998, -4.3173],\n", 281 | " ...,\n", 282 | " [-4.4516, 1.1609, 1.1042, ..., -5.4319, -4.2324, -5.0842],\n", 283 | " [-4.5135, 5.6720, 2.0341, ..., -1.5236, -4.5093, -4.6042],\n", 284 | " [-4.4559, 0.2325, 2.2894, ..., -2.6422, -5.2159, -4.2815]]],\n", 285 | " grad_fn=)" 286 | ] 287 | }, 288 | "execution_count": 7, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "orig_out.logits" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 19, 300 | "metadata": {}, 301 | "outputs": [ 302 | { 303 | "data": { 304 | "text/plain": [ 305 | "tensor(15.6308, device='cuda:0', grad_fn=)" 306 | ] 307 | }, 308 | "execution_count": 19, 309 | 
"metadata": {}, 310 | "output_type": "execute_result" 311 | } 312 | ], 313 | "source": [ 314 | "torch.dist(hf4bit_out.logits, orig_out.logits.cuda(), p=2)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 8, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/plain": [ 325 | "tensor(83.9014, grad_fn=)" 326 | ] 327 | }, 328 | "execution_count": 8, 329 | "metadata": {}, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "torch.dist(orig_out.logits.cpu(), quantized_out.logits.cpu(), p=2)" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "metadata": {}, 347 | "source": [ 348 | "### Debug/bnb" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 11, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "in_features = 128\n", 358 | "out_features = 64\n", 359 | "use_double_quant = False\n", 360 | "\n", 361 | "weight = bnb.nn.Linear4bit(\n", 362 | " in_features,\n", 363 | " out_features,\n", 364 | " bias=False,\n", 365 | " compute_dtype=torch.bfloat16,\n", 366 | " compress_statistics=use_double_quant,\n", 367 | " quant_type=\"nf4\",\n", 368 | ")\n", 369 | "bias = torch.tensor(out_features, dtype=torch.bfloat16, requires_grad=True, device=\"cuda\")\n", 370 | "weight = weight.to(\"cuda\")\n", 371 | "\n", 372 | "lora_A = nn.Linear(in_features, 1, bias=False).to(\"cuda\", dtype=torch.bfloat16)\n", 373 | "lora_B = nn.Linear(1, out_features, bias=False).to(\"cuda\", dtype=torch.bfloat16)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 13, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "x = torch.randn(2, in_features, device=\"cuda\", dtype=torch.bfloat16)\n", 383 | "y = weight(x) + bias\n", 384 | "y = y + lora_B(lora_A(x))\n", 385 | "\n", 386 | "loss = y.sum()\n", 387 | "loss.backward()" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 16, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "orig_weight = torch.randn(in_features, out_features)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 19, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "net = nn.Sequential(\n", 406 | " nn.Linear(in_features, out_features),\n", 407 | " nn.ReLU(),\n", 408 | " nn.Linear(out_features, out_features),\n", 409 | ")" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 20, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "net[0].weight = bnb.nn.Params4bit(net[0].weight.data)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 17, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "quantized_weight = bnb.nn.Params4bit(orig_weight.data, requires_grad=False, compress_statistics=False, quant_type=\"nf4\")" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 18, 433 | "metadata": {}, 434 | "outputs": [ 435 | { 436 | "data": { 437 | "text/plain": [ 438 | "Parameter containing:\n", 439 | "Parameter(Params4bit([[ 0.5485, -0.2513, 0.2402, ..., -0.7881, -0.4519, -1.0543],\n", 440 | " [-0.3215, -0.1178, -0.0623, ..., -0.2657, -0.2037, 3.4480],\n", 441 | " [ 1.4118, -1.0065, 1.5193, ..., -1.7599, 1.3230, -1.3040],\n", 442 | " ...,\n", 443 | " [ 1.5272, 1.4868, 0.7169, ..., -0.0711, -0.4521, 0.9336],\n", 444 | " [-0.0707, 
-1.3644, 1.0509, ..., 0.7394, -1.6139, -0.9520],\n", 445 | " [ 1.7725, -1.4115, 1.2637, ..., 0.4864, 1.9556, -0.5330]]))" 446 | ] 447 | }, 448 | "execution_count": 18, 449 | "metadata": {}, 450 | "output_type": "execute_result" 451 | } 452 | ], 453 | "source": [ 454 | "quantized_weight" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [] 463 | } 464 | ], 465 | "metadata": { 466 | "kernelspec": { 467 | "display_name": "peft_pretraining_shala", 468 | "language": "python", 469 | "name": "python3" 470 | }, 471 | "language_info": { 472 | "codemirror_mode": { 473 | "name": "ipython", 474 | "version": 3 475 | }, 476 | "file_extension": ".py", 477 | "mimetype": "text/x-python", 478 | "name": "python", 479 | "nbconvert_exporter": "python", 480 | "pygments_lexer": "ipython3", 481 | "version": "3.10.11" 482 | }, 483 | "orig_nbformat": 4 484 | }, 485 | "nbformat": 4, 486 | "nbformat_minor": 2 487 | } 488 | -------------------------------------------------------------------------------- /peft_pretraining/training_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | from functools import partial 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch.optim.lr_scheduler import LambdaLR 9 | from torch.distributed.optim import ZeroRedundancyOptimizer 10 | from torch.distributed.fsdp import ( 11 | FullyShardedDataParallel as FSDP, 12 | MixedPrecision, 13 | ) 14 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 15 | 16 | import transformers 17 | import wandb 18 | 19 | from loguru import logger 20 | 21 | 22 | from peft_pretraining.modeling_llama import LlamaDecoderLayer 23 | 24 | 25 | def initialize_fsdp(model, dtype): 26 | wrapping_policy = partial( 27 | transformer_auto_wrap_policy, 28 | transformer_layer_cls={ 29 | LlamaDecoderLayer, 30 | }, 31 | ) 32 | 33 | if dtype in ["bf16", "bfloat16"]: 34 | mixed_precision_policy = MixedPrecision( 35 | param_dtype=torch.bfloat16, 36 | reduce_dtype=torch.bfloat16, # Gradient communication precision 37 | buffer_dtype=torch.bfloat16, # Buffer precision 38 | ) 39 | elif dtype == "float32": 40 | mixed_precision_policy = MixedPrecision( 41 | param_dtype=torch.float32, 42 | reduce_dtype=torch.float32, # Gradient communication precision 43 | buffer_dtype=torch.float32, # Buffer precision 44 | ) 45 | else: 46 | raise ValueError(f"Dtype {dtype} not supported (only float32 and bfloat16 are)") 47 | 48 | model = FSDP( 49 | model, 50 | mixed_precision=mixed_precision_policy, 51 | auto_wrap_policy=wrapping_policy, 52 | ) 53 | return model 54 | 55 | 56 | def get_scheculer( 57 | optimizer, 58 | *, 59 | scheduler_type, 60 | num_training_steps, 61 | warmup_steps, 62 | min_lr_ratio, 63 | cycle_length=None, 64 | restart_warmup_steps=None, 65 | adjust_step=0, 66 | last_epoch=-1, 67 | ): 68 | if adjust_step != 0 and scheduler_type != "cosine_restarts": 69 | raise ValueError("adjust_step is only supported for cosine_restarts scheduler") 70 | 71 | if scheduler_type == "linear": 72 | return transformers.get_linear_schedule_with_warmup( 73 | optimizer, 74 | num_warmup_steps=warmup_steps, 75 | num_training_steps=num_training_steps, 76 | last_epoch=last_epoch, 77 | ) 78 | if scheduler_type == "cosine": 79 | return get_cyclical_cosine_schedule_with_min_lr( 80 | optimizer, 81 | num_warmup_steps=warmup_steps, 82 | num_training_steps=num_training_steps, 83 | cycle_length=cycle_length, 
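            # (added note) cycle_length may be None here; get_cyclical_cosine_schedule_with_min_lr
            # falls back to num_training_steps as the cycle length in that case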
84 | min_lr_ratio=min_lr_ratio, 85 | last_epoch=last_epoch, 86 | ) 87 | if scheduler_type == "cosine_restarts": 88 | assert restart_warmup_steps is not None, "restart_warmup_steps must be specified for cosine_restarts scheduler" 89 | return get_cosine_schedule_with_multiple_warmups( 90 | optimizer, 91 | num_training_steps=num_training_steps, 92 | first_warmup_steps=warmup_steps, 93 | restart_warmup_steps=restart_warmup_steps, 94 | restart_every=cycle_length, 95 | min_lr_ratio=min_lr_ratio, 96 | last_epoch=last_epoch, 97 | adjust_step=adjust_step, 98 | ) 99 | 100 | raise NotImplementedError(f"Scheduler {scheduler_type} is not implemented") 101 | 102 | 103 | def get_cyclical_cosine_schedule_with_min_lr(optimizer, num_warmup_steps, num_training_steps, cycle_length, min_lr_ratio=0.1, last_epoch=-1): 104 | assert cycle_length is not None or num_training_steps is not None, "You must specify either cycle_length or num_training_steps" 105 | 106 | if cycle_length is None: 107 | cycle_length = num_training_steps 108 | 109 | if num_training_steps % cycle_length != 0: 110 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by cycle_length ({cycle_length})") 111 | 112 | lr_lambda = partial( 113 | _get_cyclical_cosine_schedule_with_min_lr_lambda, 114 | num_warmup_steps=num_warmup_steps, 115 | cycle_length=cycle_length, 116 | min_lr_ratio=min_lr_ratio, 117 | ) 118 | return LambdaLR(optimizer, lr_lambda, last_epoch) 119 | 120 | 121 | def get_cosine_schedule_with_multiple_warmups( 122 | optimizer, 123 | *, 124 | num_training_steps, 125 | first_warmup_steps, 126 | restart_warmup_steps, 127 | restart_every, 128 | min_lr_ratio=0.1, 129 | adjust_step=0, 130 | last_epoch=-1, 131 | ): 132 | if restart_every is None: 133 | raise ValueError("restart_every must be specified for cosine_restarts scheduler") 134 | 135 | if num_training_steps % restart_every != 0: 136 | raise ValueError(f"num_training_steps ({num_training_steps}) must be divisible by restart_every ({restart_every})") 137 | 138 | lr_lambda = partial( 139 | _get_cosine_schedule_with_multiple_warmups_lambda, 140 | num_training_steps=num_training_steps, 141 | first_warmup_steps=first_warmup_steps, 142 | restart_warmup_steps=restart_warmup_steps, 143 | restart_every=restart_every, 144 | min_lr_ratio=min_lr_ratio, 145 | adjust_step=adjust_step, 146 | ) 147 | return LambdaLR(optimizer, lr_lambda, last_epoch) 148 | 149 | 150 | @torch.no_grad() 151 | def random_pruning_(tensor, prune_ratio): 152 | """ 153 | Performs random pruning dimensionality reduction **inplace**. 154 | Only reduces the inner dimensionality, does not affect the shape of the tensor 155 | """ 156 | random_pruning_mask = torch.rand_like(tensor) > prune_ratio 157 | tensor.mul_(random_pruning_mask) 158 | 159 | 160 | @torch.no_grad() 161 | def magnitude_pruning_(tensor, prune_ratio): 162 | """ 163 | Performs magnitude pruning dimensionality reduction **inplace**. 
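    (Added note) The threshold is the prune_ratio quantile of the tensor's absolute
    values, so e.g. prune_ratio=0.9 zeroes out the smallest-magnitude ~90% of entries.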
164 |     Only reduces the inner dimensionality, does not affect the shape of the tensor
165 |     """
166 |     tensor_magnitude = torch.abs(tensor)
167 |     threshold = torch.quantile(tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio).to(dtype=tensor.dtype)
168 | 
169 |     mask = tensor_magnitude > threshold
170 |     tensor.mul_(mask.to(dtype=tensor.dtype))
171 | 
172 | 
173 | def _get_cyclical_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, cycle_length, min_lr_ratio):
174 |     assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]"
175 | 
176 |     # compute where we are in the current cycle
177 |     cycle_step = current_step % cycle_length
178 | 
179 |     if cycle_step < num_warmup_steps:
180 |         if current_step != cycle_step:
181 |             if cycle_step < 2:
182 |                 return 1e-7
183 |         return float(cycle_step) / float(max(1, num_warmup_steps))
184 | 
185 |     progress = float(cycle_step - num_warmup_steps) / float(max(1, cycle_length - num_warmup_steps))
186 |     cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
187 | 
188 |     return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
189 | 
190 | 
191 | def _get_cosine_schedule_with_multiple_warmups_lambda(
192 |     current_step,
193 |     *,
194 |     num_training_steps,
195 |     first_warmup_steps,
196 |     restart_warmup_steps,
197 |     restart_every,
198 |     min_lr_ratio,
199 |     adjust_step,
200 | ):
201 |     """
202 |     Args:
203 |         adjust_step: useful when continuing training from a warmed-up checkpoint,
204 |             it allows you to sync the resets by reducing the number of steps
205 |             after the first warmup and before the first reset.
206 |             Thus, your ReLoRA resets can be synced with the optimizer resets.
207 |     """
208 |     assert 0 < min_lr_ratio <= 1.0, "min_lr_ratio must be in (0,1]"
209 |     assert restart_every > 0, "restart_every must be positive"
210 |     assert adjust_step + first_warmup_steps <= num_training_steps, "warmup + adjust_step is more than full training steps"
211 |     assert adjust_step + first_warmup_steps <= restart_every, "the first reset will happen before the warmup is done"
212 | 
213 |     if current_step < first_warmup_steps:
214 |         return float(current_step) / float(max(1, first_warmup_steps))
215 | 
216 |     _current_step = current_step + adjust_step
217 | 
218 |     restart_step = _current_step % restart_every
219 |     restart_number = _current_step // restart_every
220 | 
221 |     if restart_step < restart_warmup_steps and current_step >= restart_every:
222 |         # get the expected lr multiplier at the end of the warmup
223 |         end_of_warmup_progress = (
224 |             float(restart_number * restart_every + restart_warmup_steps - first_warmup_steps) /
225 |             float(max(1, num_training_steps - first_warmup_steps))
226 |         )
227 | 
228 |         _cosine_decay = 0.5 * (1.0 + math.cos(math.pi * end_of_warmup_progress))
229 |         warmup_lr_multiplier = min_lr_ratio + (1.0 - min_lr_ratio) * _cosine_decay
230 | 
231 |         return float(restart_step) / float(max(1, restart_warmup_steps)) * warmup_lr_multiplier
232 | 
233 |     progress = float(_current_step - first_warmup_steps) / float(max(1, num_training_steps - first_warmup_steps))
234 |     cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
235 | 
236 |     return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay
237 | 
238 | 
239 | def max_train_tokens_to_number(max_train_tokens):
240 |     if max_train_tokens.endswith("M"):
241 |         return int(max_train_tokens.rstrip("M")) * 1_000_000
242 |     elif max_train_tokens.endswith("B"):
243 |         return int(max_train_tokens.rstrip("B")) * 1_000_000_000
244 |     else:
245 |         return int(max_train_tokens)
246 | 
247 | 
248 | def
get_last_training_state(save_dir):
249 |     # list all directories in the save_dir
250 |     # find the model with the highest number of iterations "{args.save_dir}/model_{update_step}"
251 |     model_dirs = [d for d in os.listdir(save_dir) if d.startswith(f"model_")]
252 |     if len(model_dirs) == 0:
253 |         logger.warning(f"Save directory {save_dir} exists, but does not contain any models.")
254 |         logger.warning("Starting training from scratch.")
255 |         return None, None
256 | 
257 |     model_dirs = sorted(model_dirs, key=lambda x: int(x.split("_")[-1]))
258 |     resume_from = os.path.join(save_dir, model_dirs[-1])
259 | 
260 |     logger.info(f"Restarting training from {resume_from}")
261 |     with open(os.path.join(resume_from, "training_state.json")) as f:
262 |         training_state = json.load(f)
263 | 
264 |     return training_state, resume_from
265 | 
266 | 
267 | def optimizer_reset(
268 |     optimizer,
269 |     *,
270 |     reset_params: list[torch.nn.Parameter],
271 |     optimizer_state_keys: list[str],
272 |     reset_optimizer_on_relora: bool,
273 |     optimizer_random_pruning: float,
274 |     optimizer_magnitude_pruning: float,
275 | ):
276 |     """
277 |     optimizer_state_keys: e.g., ["exp_avg", "exp_avg_sq"]
278 |     """
279 |     n_reset_types = (
280 |         int(bool(reset_optimizer_on_relora))
281 |         + int(bool(optimizer_random_pruning))
282 |         + int(bool(optimizer_magnitude_pruning))
283 |     )
284 |     if n_reset_types != 1:
285 |         logger.warning(f"Got {reset_optimizer_on_relora=}, {optimizer_random_pruning=}, "
286 |                        f"{optimizer_magnitude_pruning=}")
287 |         raise ValueError(f"Exactly one of reset_optimizer_on_relora, "
288 |                          f"optimizer_random_pruning, optimizer_magnitude_pruning must be True")
289 | 
290 |     # pruning_fn has to be inplace to work with ZeroRedundancyOptimizer
291 |     if reset_optimizer_on_relora:
292 |         logger.info("Resetting optimizer states to zeros")
293 |         # looks like zeroing out breaks the dictionary in the optimizer,
294 |         # see the full error below
295 |         pruning_fn = partial(random_pruning_, prune_ratio=0.999)
296 |     elif optimizer_random_pruning:
297 |         logger.info(f"Performing random pruning of optimizer states "
298 |                     f"with pruning ratio {optimizer_random_pruning}")
299 |         pruning_fn = partial(random_pruning_, prune_ratio=optimizer_random_pruning)
300 |     elif optimizer_magnitude_pruning:
301 |         logger.info(f"Performing magnitude pruning of optimizer states "
302 |                     f"with pruning ratio {optimizer_magnitude_pruning}")
303 |         pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
304 |     else:
305 |         raise ValueError("Unknown pruning type")
306 | 
307 |     # ############################################################
308 |     # A reminder on how optimizer state is structured for regular optimizers:
309 |     # optimizer.state is a dict[torch.nn.Parameter, dict[str, torch.Tensor]]
310 |     # optimizer.state[p] is a dict[str, torch.Tensor] where str is
311 |     # an optimizer state key e.g., "exp_avg", "exp_avg_sq"
312 |     # Note that none of these tensors has parameter names
313 |     # and a parameter maps to a **dictionary** of opt. states, not a tensor
314 |     #
315 |     # For ZeroRedundancyOptimizer, it works differently.
316 |     # ZeroRedundancyOptimizer.state always maps to empty dicts.
317 |     # Instead, it uses optimizer.optim.state for rank-local updates.
318 |     #
319 |     # For some reason, zeroing out a tensor in ZeroRedundancyOptimizer.opt.state
320 |     # causes an error during state_dict collection.
321 |     # This is why we use a 0.999 pruning ratio for the reset_optimizer case.
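    # (added note) Concretely, for a plain Adam optimizer the layout relied on here is:
    #     optimizer.state[param]["exp_avg"]     -> torch.Tensor, same shape as param
    #     optimizer.state[param]["exp_avg_sq"]  -> torch.Tensor, same shape as param
    # so the inplace pruning_fn can mutate these tensors without replacing any
    # dict entries, which is what keeps ZeroRedundancyOptimizer's state_dict
    # collection working.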
322 |     #
323 |     # Here's an error that happens:
324 |     #
325 |     # Traceback (most recent call last):
326 |     #   File ".../peft_pretraining/torchrun_main.py", line 866, in <module>
327 |     #     main(args)
328 |     #   File ".../peft_pretraining/torchrun_main.py", line 715, in main
329 |     #     save_model(
330 |     #   File ".../peft_pretraining/torchrun_main.py", line 289, in save_model
331 |     #     save_model_ddp(model, optimizer, scheduler, training_state_checkpoint, run_config, save_dir)
332 |     #   File ".../peft_pretraining/torchrun_main.py", line 224, in save_model_ddp
333 |     #     optimizer.consolidate_state_dict()
334 |     #   File ".../python3.10/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py", line 565, in consolidate_state_dict
335 |     #     self.optim.state_dict(),
336 |     #   File ".../python3.10/site-packages/torch/optim/optimizer.py", line 364, in state_dict
337 |     #     packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
338 |     #   File ".../python3.10/site-packages/torch/optim/optimizer.py", line 364, in <dictcomp>
339 |     #     packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
340 |     # KeyError: 140580723685184
341 |     #
342 |     # On one hand, the hypothesis is that zeroing out a tensor
343 |     # is implemented by changing the memory pointer to
344 |     # an existing zero-tensor. But on the other hand, we didn't
345 |     # have issues with that when using regular Adam, without the ZeroRedundancyOptimizer wrapper.
346 |     # ############################################################
347 |     n_zeros = 0
348 |     n_total = 0
349 | 
350 |     optimizer_state = optimizer.state
351 |     if isinstance(optimizer, ZeroRedundancyOptimizer):
352 |         optimizer_state = optimizer.optim.state
353 | 
354 |     for p in reset_params:
355 |         param_state = optimizer_state[p]
356 |         if len(param_state) == 0:  # no state for this param, happens for ZeRo optimizer
357 |             continue
358 |         for key in optimizer_state_keys:
359 |             pruning_fn(param_state[key])  # pruning fn has to be inplace to keep the same keys in the dict
360 |             n_total += param_state[key].numel()
361 |             n_zeros += torch.sum(param_state[key] == 0).item()
362 | 
363 |     _zeroed = n_zeros / (1e-7 + n_total) * 100
364 |     logger.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
365 | 
366 | 
367 | def print_optimizer_state_size(optimizer):
368 |     # Count the number of floats in the first and second moments
369 |     first_moment_count = 0
370 |     second_moment_count = 0
371 | 
372 |     optimizer_state = optimizer.state
373 |     if isinstance(optimizer, ZeroRedundancyOptimizer):
374 |         optimizer_state = optimizer.optim.state
375 | 
376 |     for state in optimizer_state.values():
377 |         if len(state) == 0:  # no state for this param, happens for ZeRo optimizer
378 |             continue
379 | 
380 |         first_moment_count += torch.numel(state['exp_avg'])
381 |         second_moment_count += torch.numel(state['exp_avg_sq'])
382 | 
383 |     global_rank = 0
384 |     if dist.is_initialized():
385 |         global_rank = dist.get_rank()
386 | 
387 |     print(f"(Rank {global_rank}) Number of floats in the first moment: {first_moment_count / 1_000_000:.2f}M")
388 |     print(f"(Rank {global_rank}) Number of floats in the second moment: {second_moment_count / 1_000_000:.2f}M")
389 | 
390 | 
391 | def check_lr_and_alert(optimizer, max_lr):
392 |     global_rank = 0 if not dist.is_initialized() else dist.get_rank()
393 | 
394 |     lr = optimizer.param_groups[0]["lr"]
395 |     if lr <= max_lr: return
396 | 
397 |     alert_message = f"Optimizer lr after the reset is large. This can lead to instability.
Current lr is {lr}" 398 | logger.warning(alert_message) 399 | if global_rank == 0: 400 | wandb.alert( 401 | title="Learning rate issue", 402 | text=alert_message, 403 | level=wandb.AlertLevel.WARN, 404 | ) 405 | 406 | def delete_old_checkpoints(save_dir, keep): 407 | if keep is None: 408 | return 409 | 410 | checkpoints = [d for d in os.listdir(save_dir) if d.startswith(f"model_")] 411 | if len(checkpoints) <= keep: 412 | return 413 | 414 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("_")[-1])) 415 | for checkpoint in checkpoints[:-keep]: 416 | checkpoint_path = os.path.join(save_dir, checkpoint) 417 | logger.info(f"Deleting checkpoint {checkpoint_path}") 418 | os.system(f"rm -rf {checkpoint_path}") 419 | -------------------------------------------------------------------------------- /peft_pretraining/megatron_dataset/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, EleutherAI 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from typing import List, Tuple 17 | from itertools import zip_longest 18 | from functools import partial 19 | 20 | import torch 21 | import torch.utils.data 22 | import torch.distributed as dist 23 | 24 | import numpy as np 25 | from loguru import logger 26 | 27 | 28 | from peft_pretraining.megatron_dataset.indexed_dataset import make_dataset as make_indexed_dataset 29 | from peft_pretraining.megatron_dataset.blendable_dataset import BlendableDataset 30 | from peft_pretraining.megatron_dataset.dataset import GPT2Dataset 31 | from peft_pretraining.megatron_dataset.samplers import DistributedBatchSampler 32 | 33 | 34 | def make_data_loader(dataset, neox_args): 35 | """Build dataloader given an input dataset.""" 36 | if dataset is None: 37 | return None 38 | # Data parallel arguments. 39 | world_size = 1 40 | rank = 0 41 | if dist.is_initialized(): 42 | world_size = dist.get_world_size() 43 | rank = dist.get_rank() 44 | else: 45 | logger.warning("Not using distributed mode. Should only be used for debugging.") 46 | 47 | global_batch_size = neox_args.batch_size * world_size 48 | num_workers = neox_args.num_workers 49 | 50 | # Use a simple sampler with distributed batch sampler. 51 | sampler = torch.utils.data.SequentialSampler(dataset) 52 | batch_sampler = DistributedBatchSampler( 53 | sampler=sampler, 54 | batch_size=global_batch_size, 55 | drop_last=True, 56 | rank=rank, 57 | world_size=world_size, 58 | ) 59 | # Torch dataloader. 
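    # NOTE (added): batch_size is deliberately not passed to the DataLoader here.
    # The Megatron-style DistributedBatchSampler above receives the global batch
    # (neox_args.batch_size * world_size) and is expected to yield only the
    # current rank's shard of each batch.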
60 |     return torch.utils.data.DataLoader(
61 |         dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True
62 |     )
63 | 
64 | 
65 | def build_the_dataset(
66 |     data_prefix,
67 |     name,
68 |     data_impl,
69 |     num_samples,
70 |     seq_length,
71 |     seed,
72 |     skip_warmup,
73 |     build_index_mappings=True,
74 |     label_prefix=None,
75 | ):
76 |     """Build a single train/valid/test dataset."""
77 | 
78 |     indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
79 |     if label_prefix is None:
80 |         label_dataset = None
81 |     else:
82 |         label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
83 | 
84 |     total_num_of_documents = indexed_dataset.sizes.shape[0]
85 |     logger.info("    {}:".format(name))
86 |     logger.info("     no. of documents: {}".format(total_num_of_documents))
87 |     dataset = None
88 |     documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)
89 |     dataset = GPT2Dataset(
90 |         name,
91 |         data_prefix,
92 |         documents,
93 |         indexed_dataset,
94 |         num_samples,
95 |         seq_length,
96 |         seed,
97 |         build_index_mappings=build_index_mappings,
98 |         label_dataset=label_dataset,
99 |     )
100 |     return dataset
101 | 
102 | 
103 | def build_train_valid_test_datasets(
104 |     data_prefix,
105 |     use_shared_fs,
106 |     data_impl,
107 |     splits_string,
108 |     train_valid_test_num_samples,
109 |     seq_length,
110 |     seed,
111 |     skip_warmup,
112 | ):
113 |     """Build train, valid, and test datasets."""
114 | 
115 |     # Indexed dataset.
116 |     indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
117 | 
118 |     total_num_of_documents = indexed_dataset.sizes.shape[0]
119 |     splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
120 | 
121 |     # Print stats about the splits.
122 |     logger.info(" > dataset split:")
123 | 
124 |     def print_split_stats(name, index):
125 |         logger.info("    {}:".format(name))
126 |         logger.info(
127 |             "     document indices in [{}, {}) total of {} "
128 |             "documents".format(
129 |                 splits[index], splits[index + 1], splits[index + 1] - splits[index]
130 |             )
131 |         )
132 | 
133 |     print_split_stats("train", 0)
134 |     print_split_stats("validation", 1)
135 |     print_split_stats("test", 2)
136 | 
137 |     def build_dataset(index, name):
138 |         dataset = None
139 |         if splits[index + 1] > splits[index]:
140 |             documents = np.arange(
141 |                 start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
142 |             )
143 | 
144 |             dataset = GPT2Dataset(
145 |                 name,
146 |                 data_prefix,
147 |                 documents,
148 |                 indexed_dataset,
149 |                 train_valid_test_num_samples[index],
150 |                 seq_length,
151 |                 seed,
152 |                 use_shared_fs=use_shared_fs,
153 |             )
154 |         return dataset
155 | 
156 |     train_dataset = build_dataset(0, "train")
157 |     valid_dataset = build_dataset(1, "valid")
158 |     test_dataset = build_dataset(2, "test")
159 | 
160 |     return train_dataset, valid_dataset, test_dataset
161 | 
162 | 
163 | def get_train_valid_test_split_(splits_string, size):
164 |     """Get dataset splits from a comma- or '/'-separated string of weights."""
165 | 
166 |     splits = []
167 |     if splits_string.find(",") != -1:
168 |         splits = [float(s) for s in splits_string.split(",")]
169 |     elif splits_string.find("/") != -1:
170 |         splits = [float(s) for s in splits_string.split("/")]
171 |     else:
172 |         splits = [float(splits_string)]
173 |     while len(splits) < 3:
174 |         splits.append(0.0)
175 |     splits = splits[:3]
176 |     splits_sum = sum(splits)
177 |     assert splits_sum > 0.0
178 |     splits = [split / splits_sum for split in splits]
179 |     splits_index = [0]
180 |     for index, split in enumerate(splits):
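        # (added note) accumulate rounded per-split document counts into cumulative
        # boundaries; e.g., normalized splits [0.8, 0.1, 0.1] over 1000 documents
        # give splits_index == [0, 800, 900, 1000]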
splits_index.append(splits_index[index] + int(round(split * float(size))))
182 |     diff = splits_index[-1] - size
183 |     for index in range(1, len(splits_index)):
184 |         splits_index[index] -= diff
185 |     assert len(splits_index) == 4
186 |     assert splits_index[-1] == size
187 |     return splits_index
188 | 
189 | 
190 | def get_normalized_weights_and_num_samples(
191 |     weights: List[float], num_samples: int
192 | ) -> Tuple[List[float], List[int]]:
193 |     # Normalize weights
194 |     weight_sum = sum(weights)
195 |     assert weight_sum > 0.0
196 |     weights = [weight / weight_sum for weight in weights]
197 |     # Add 0.5% (the 1.005 factor) so that if the blended dataset does not
198 |     # distribute the number of samples perfectly uniformly, we still have
199 |     # samples left over to feed to the network.
200 |     weighted_num_samples = []
201 |     for weight in weights:
202 |         weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
203 |     return weights, weighted_num_samples
204 | 
205 | 
206 | def build_weighted_datasets(
207 |     neox_args,
208 |     train_num_samples,
209 |     valid_num_samples,
210 |     test_num_samples,
211 |     train_weights,
212 |     valid_weights,
213 |     test_weights,
214 |     build_index_mappings=True,
215 | ):
216 |     # build individual datasets
217 |     train_datasets, valid_datasets, test_datasets = [], [], []
218 |     for i, (train_path, label_path, valid_path, test_path) in enumerate(
219 |         zip_longest(
220 |             neox_args.train_data_paths,
221 |             neox_args.label_data_paths if neox_args.label_data_paths else [],
222 |             neox_args.valid_data_paths,
223 |             neox_args.test_data_paths,
224 |         )
225 |     ):
226 |         if train_path:
227 |             train_datasets.append(
228 |                 build_the_dataset(
229 |                     data_prefix=train_path,
230 |                     name=f"train_{i}",
231 |                     data_impl=neox_args.data_impl,
232 |                     num_samples=train_num_samples[i],
233 |                     seq_length=neox_args.seq_length,
234 |                     seed=neox_args.seed,
235 |                     skip_warmup=(not neox_args.mmap_warmup),
236 |                     build_index_mappings=build_index_mappings,
237 |                     label_prefix=label_path,
238 |                 )
239 |             )
240 | 
241 |         if valid_path:
242 |             valid_datasets.append(
243 |                 build_the_dataset(
244 |                     data_prefix=valid_path,
245 |                     name=f"valid_{i}",
246 |                     data_impl=neox_args.data_impl,
247 |                     num_samples=valid_num_samples[i],
248 |                     seq_length=neox_args.seq_length,
249 |                     seed=neox_args.seed,
250 |                     skip_warmup=(not neox_args.mmap_warmup),
251 |                     build_index_mappings=build_index_mappings,
252 |                 )
253 |             )
254 | 
255 |         if test_path:
256 |             test_datasets.append(
257 |                 build_the_dataset(
258 |                     data_prefix=test_path,
259 |                     name=f"test_{i}",
260 |                     data_impl=neox_args.data_impl,
261 |                     num_samples=test_num_samples[i],
262 |                     seq_length=neox_args.seq_length,
263 |                     seed=neox_args.seed,
264 |                     skip_warmup=(not neox_args.mmap_warmup),
265 |                     build_index_mappings=build_index_mappings,
266 |                 )
267 |             )
268 |     return train_datasets, valid_datasets, test_datasets
269 | 
270 | 
271 | def weights_by_num_docs(l: list, alpha=0.3):
272 |     """
273 |     Builds weights from a multinomial distribution over groups of data according to the number of
274 |     samples in each group.
275 | 
276 |     We sample from a group according to the probability p(L) ∝ |L| ** α,
277 |     where p(L) is the probability of sampling from a given group,
278 |     |L| is the number of examples in that group,
279 |     and α is a coefficient that acts to upsample data from underrepresented groups.
280 | 
281 |     Hence α (`alpha`) allows us to control how much to 'boost' the probability of training on low-resource groups.
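    Illustrative worked example (added): with document counts [900, 100] and
    alpha=0.3, the unbiased probabilities are [0.9, 0.1]; p**alpha, normalized,
    gives roughly [0.66, 0.34]; multiplying by the inverse probabilities
    [0.1, 0.9] and renormalizing yields final weights of about [0.18, 0.82],
    i.e. the smaller group is strongly upsampled.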
282 | 
283 |     See https://arxiv.org/abs/1911.02116 for more details
284 |     """
285 |     if len(l) == 1:
286 |         return [1.0]
287 | 
288 |     total_n_docs = sum(l)
289 |     unbiased_sample_probs = [i / total_n_docs for i in l]
290 | 
291 |     probs = [i**alpha for i in unbiased_sample_probs]
292 | 
293 |     # normalize
294 |     total = sum(probs)
295 |     probs = [i / total for i in probs]
296 | 
297 |     # weights should be the inverse of the number of samples
298 |     unbiased_sample_probs_inverse = [1 - p for p in unbiased_sample_probs]
299 |     weights = [p * p2 for p, p2 in zip(probs, unbiased_sample_probs_inverse)]
300 | 
301 |     # normalize
302 |     total = sum(weights)
303 |     weights = [i / total for i in weights]
304 | 
305 |     return weights
306 | 
307 | # NOTE: the original NeoX function returned iterators; this version returns dataloaders
308 | def build_train_valid_test_dataloaders(neox_args):
309 |     """Build train, validation, and test dataloaders."""
310 | 
311 |     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
312 | 
313 |     logger.info("> building train, validation, and test datasets ...")
314 | 
315 |     # Pipeline parallelism was removed from this codebase, so it must not be enabled
316 |     assert not neox_args.is_pipe_parallel, "we removed pipeline parallelism from the ReLoRA version of megatron dataloading for simplicity"
317 |     pipe_load = True
318 | 
319 |     # Data loader only on rank 0 of each model parallel group.
320 |     if pipe_load:
321 |         # Number of train/valid/test samples.
322 |         train_iters = neox_args.train_iters
323 |         eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
324 |         test_iters = neox_args.eval_iters
325 |         train_val_test_num_samples = [
326 |             train_iters * neox_args.train_batch_size,
327 |             eval_iters * neox_args.train_batch_size,
328 |             test_iters * neox_args.train_batch_size,
329 |         ]
330 | 
331 |         if neox_args.train_data_paths:
332 |             # when individual train / valid / test data paths are provided
333 |             # normalize weight values and get num samples for each dataset
334 |             train_weights, train_num_samples = get_normalized_weights_and_num_samples(
335 |                 neox_args.train_data_weights, train_val_test_num_samples[0]
336 |             )
337 |             valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
338 |                 neox_args.valid_data_weights, train_val_test_num_samples[1]
339 |             )
340 |             test_weights, test_num_samples = get_normalized_weights_and_num_samples(
341 |                 neox_args.test_data_weights, train_val_test_num_samples[2]
342 |             )
343 | 
344 |             # build individual datasets
345 |             train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
346 |                 neox_args,
347 |                 train_num_samples,
348 |                 valid_num_samples,
349 |                 test_num_samples,
350 |                 train_weights,
351 |                 valid_weights,
352 |                 test_weights,
353 |                 build_index_mappings=not neox_args.weight_by_num_documents,
354 |             )
355 | 
356 |             if neox_args.weight_by_num_documents:
357 | 
358 |                 # gets the number of documents in each datapath
359 |                 get_num_docs_list = lambda datasets: [
360 |                     dataset.indexed_dataset.sizes.shape[0] for dataset in datasets
361 |                 ]
362 |                 train_num_docs, valid_num_docs, test_num_docs = (
363 |                     get_num_docs_list(train_datasets),
364 |                     get_num_docs_list(valid_datasets),
365 |                     get_num_docs_list(test_datasets),
366 |                 )
367 | 
368 |                 # builds weights according to alpha + the number of docs
369 |                 fn = partial(
370 |                     weights_by_num_docs, alpha=neox_args.weighted_sampler_alpha
371 |                 )
372 |                 train_weights, valid_weights, test_weights = (
373 |                     fn(train_num_docs),
374 |                     fn(valid_num_docs),
375 |                     fn(test_num_docs),
376 |                 )
377 |                 (
378 |                     train_weights,
379 |                     train_num_samples,
380 |                 ) = 
get_normalized_weights_and_num_samples(
381 |                     train_weights, train_val_test_num_samples[0]
382 |                 )
383 |                 (
384 |                     valid_weights,
385 |                     valid_num_samples,
386 |                 ) = get_normalized_weights_and_num_samples(
387 |                     valid_weights, train_val_test_num_samples[1]
388 |                 )
389 |                 test_weights, test_num_samples = get_normalized_weights_and_num_samples(
390 |                     test_weights, train_val_test_num_samples[2]
391 |                 )
392 | 
393 |                 # rebuild datasets weighted according to the new weights
394 |                 train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
395 |                     neox_args,
396 |                     train_num_samples,
397 |                     valid_num_samples,
398 |                     test_num_samples,
399 |                     train_weights,
400 |                     valid_weights,
401 |                     test_weights,
402 |                 )
403 | 
404 |             if train_datasets:
405 |                 train_ds = BlendableDataset(train_datasets, train_weights)
406 |             if valid_datasets:
407 |                 valid_ds = BlendableDataset(valid_datasets, valid_weights)
408 |             if test_datasets:
409 |                 test_ds = BlendableDataset(test_datasets, test_weights)
410 |         else:
411 |             # when just data_path is provided
412 |             # split dataset into train, valid and test from data_path
413 |             train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
414 |                 data_prefix=neox_args.data_path,
415 |                 use_shared_fs=neox_args.use_shared_fs,
416 |                 data_impl=neox_args.data_impl,
417 |                 splits_string=neox_args.split,
418 |                 train_valid_test_num_samples=train_val_test_num_samples,
419 |                 seq_length=neox_args.seq_length,
420 |                 seed=neox_args.seed,
421 |                 skip_warmup=(not neox_args.mmap_warmup),
422 |             )
423 | 
424 |         # Build dataloaders.
425 |         train_dataloader = make_data_loader(train_ds, neox_args=neox_args)
426 |         valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args)
427 |         test_dataloader = make_data_loader(test_ds, neox_args=neox_args)
428 | 
429 |         # Flags to know if we need to do training/validation/testing.
430 |         do_train = train_dataloader is not None and neox_args.train_iters > 0
431 |         do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
432 |         do_test = test_dataloader is not None and neox_args.eval_iters > 0
433 |         # Need to broadcast num_tokens and num_type_tokens.
434 |         flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
435 |     else:
436 |         raise RuntimeError("We removed pipeline parallelism from the ReLoRA version of megatron dataloading for simplicity")
437 |         flags = torch.cuda.LongTensor([0, 0, 0])
438 | 
439 |     neox_args.do_train = flags[0].item()
440 |     neox_args.do_valid = flags[1].item()
441 |     neox_args.do_test = flags[2].item()
442 | 
443 |     # Shift the start iterations.
444 |     if train_dataloader is not None:
445 |         train_dataloader.batch_sampler.start_iter = (
446 |             neox_args.iteration * neox_args.gradient_accumulation_steps
447 |         ) % len(train_dataloader)
448 |         logger.info(
449 |             "setting training data start iteration to {}".format(
450 |                 train_dataloader.batch_sampler.start_iter
451 |             )
452 |         )
453 |     if valid_dataloader is not None:
454 |         start_iter_val = (
455 |             (neox_args.iteration * neox_args.gradient_accumulation_steps)
456 |             // neox_args.eval_interval
457 |         ) * neox_args.eval_iters
458 |         valid_dataloader.batch_sampler.start_iter = start_iter_val % len(
459 |             valid_dataloader
460 |         )
461 |         logger.info(
462 |             "setting validation data start iteration to {}".format(
463 |                 valid_dataloader.batch_sampler.start_iter
464 |             )
465 |         )
466 | 
467 |     return train_dataloader, valid_dataloader, test_dataloader
468 | 
469 | 
470 | def compile_helper():
471 |     """Compile helper function at runtime.
Make sure this
472 |     is invoked on a single process."""
473 |     import os
474 |     import subprocess
475 | 
476 |     path = os.path.abspath(os.path.dirname(__file__))
477 |     ret = subprocess.run(["make", "-C", path])
478 |     if ret.returncode != 0:
479 |         print("Building the C++ dataset helpers module failed, exiting.")
480 |         import sys
481 | 
482 |         sys.exit(1)
483 | 
--------------------------------------------------------------------------------
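
Editorial appendix — a minimal sketch, not part of the repository, showing how the cosine-restarts schedule defined in peft_pretraining/training_utils.py can be exercised in isolation. It assumes only that torch and this package (plus its requirements) are installed; all hyperparameter values below are illustrative.

import torch

# sic: the repository spells the factory function "get_scheculer"
from peft_pretraining.training_utils import get_scheculer

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 20k steps split into four 5k-step cycles, with a 1k-step initial warmup
# and a 100-step re-warmup after every ReLoRA reset
scheduler = get_scheculer(
    optimizer,
    scheduler_type="cosine_restarts",
    num_training_steps=20_000,
    warmup_steps=1_000,
    min_lr_ratio=0.1,
    cycle_length=5_000,        # passed through as restart_every
    restart_warmup_steps=100,
)

lrs = []
for _ in range(20_000):
    optimizer.step()  # a real run would compute a loss and call backward() first
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

# lrs now traces four cosine decays toward min_lr_ratio * lr, with a short
# linear re-warmup starting at steps 5000, 10000, and 15000.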