├── requirements.txt
├── llm-experiments
│   ├── scripts
│   │   ├── run-train-interleave-cnn.sh
│   │   ├── run-train-interleave-cnn-optimizer.sh
│   │   ├── run-train-interleave-cnn-scratch.sh
│   │   ├── run-train-interleave-cnn-optimizer-reset.sh
│   │   ├── run-train-interleave-cnn-randommask.sh
│   │   ├── run-train-interleave-cnn-randomwindow.sh
│   │   ├── run-train-interleave-frozen-blocks.sh
│   │   ├── run-train-interleave-cnn-scratch-ablation.sh
│   │   └── run-train-interleave-cnn-rev.sh
│   ├── visualization
│   │   ├── visualization_minimums.py
│   │   ├── visualize_rep.py
│   │   └── dipping_matrix.py
│   └── training
│       ├── train_interleave_scratch.py
│       ├── train_interleave_optimizer_reset.py
│       ├── train_interleave_optimizer.py
│       ├── train_interleave.py
│       ├── train_interleave_randommask.py
│       ├── train_interleave_frozen_blocks.py
│       ├── train_interleave_randomwindow.py
│       └── train_interleave_ablation.py
├── LICENSE
├── README.md
├── imagenet-experiments
│   └── exp_imagenet.py
└── igpt-experiments
    └── train_interleave_igpt.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.23.0
2 | datasets==2.14.4
3 | huggingface-hub==0.17.1
4 | matplotlib==3.8.0
5 | numpy==1.25.2
6 | peft==0.5.0
7 | Pillow==9.4.0
8 | scikit-learn==1.3.0
9 | scipy==1.11.2
10 | timm==0.9.12
11 | tokenizers==0.15.0
12 | torch==2.0.1
13 | torchvision==0.15.2
14 | tqdm==4.66.1
15 | transformers==4.35.2
16 | 
--------------------------------------------------------------------------------
/llm-experiments/scripts/run-train-interleave-cnn.sh:
--------------------------------------------------------------------------------
1 | python training/train_interleave.py --model_name_or_path "EleutherAI/pythia-1b" --revision main --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.001 --output_dir checkpoints/1b-5grad-512 --save_prefix batch1_gpu1 --block_size 512 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 5 --num-data-samples 25
2 | 
3 | 
--------------------------------------------------------------------------------
/llm-experiments/scripts/run-train-interleave-cnn-optimizer.sh:
--------------------------------------------------------------------------------
1 | python training/train_interleave_optimizer.py --model_name_or_path "EleutherAI/pythia-1b" --revision main --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.0001 --output_dir checkpoints/Adam1e-4 --save_prefix batch1_gpu1 --block_size 256 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 10 --num-data-samples 25
2 | 
3 | 
--------------------------------------------------------------------------------
/llm-experiments/scripts/run-train-interleave-cnn-scratch.sh:
--------------------------------------------------------------------------------
1 | python training/train_interleave.py --model_name_or_path "EleutherAI/pythia-2.8b" --revision step0 --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.001 --output_dir checkpoints/pythia-2.8b-10grad-25tasks-scratch --save_prefix batch1_gpu1 --block_size 256 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 10 --num-data-samples 25
2 | 
--------------------------------------------------------------------------------
/llm-experiments/scripts/run-train-interleave-cnn-optimizer-reset.sh:
-------------------------------------------------------------------------------- 1 | python training/train_interleave_optimizer_reset.py --model_name_or_path "EleutherAI/pythia-1b" --revision main --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.0001 --output_dir checkpoints/Adam1e-4-reset --save_prefix batch1_gpu1 --block_size 256 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 10 --num-data-samples 25 2 | 3 | -------------------------------------------------------------------------------- /llm-experiments/scripts/run-train-interleave-cnn-randommask.sh: -------------------------------------------------------------------------------- 1 | python training/train_interleave_randommask.py --model_name_or_path "EleutherAI/pythia-1b" --revision main --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.001 --output_dir checkpoints/randommask-0.3 --save_prefix batch1_gpu1 --block_size 256 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 10 --num-data-samples 25 --random-prob 0.3 2 | -------------------------------------------------------------------------------- /llm-experiments/scripts/run-train-interleave-cnn-randomwindow.sh: -------------------------------------------------------------------------------- 1 | python training/train_interleave_randomwindow.py --model_name_or_path "EleutherAI/pythia-1b" --revision main --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.001 --output_dir checkpoints/randomwindow-0 --save_prefix batch1_gpu1 --block_size 512 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 10 --num-data-samples 25 --random-window 0 2 | 3 | -------------------------------------------------------------------------------- /llm-experiments/scripts/run-train-interleave-frozen-blocks.sh: -------------------------------------------------------------------------------- 1 | python training/train_interleave_frozen_blocks.py --model_name_or_path "EleutherAI/pythia-1b" --revision main --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.001 --output_dir checkpoints/frozenblocksrev8 --save_prefix batch1_gpu1 --block_size 256 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 10 --num-data-samples 25 --num-frozen-blocks-rev 8 2 | 3 | -------------------------------------------------------------------------------- /llm-experiments/scripts/run-train-interleave-cnn-scratch-ablation.sh: -------------------------------------------------------------------------------- 1 | python training/train_interleave_scratch.py --model_name_or_path "EleutherAI/pythia-1b" --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.001 --output_dir checkpoints/ablation-w2048-d16-h2 --save_prefix batch1_gpu1 --block_size 256 --num_train_epochs 5 --overwrite_cache --save_freq 1 --num-grad-steps 10 --num-data-samples 25 --num_hidden_layers 16 --hidden_size 2048 --num_attention_heads 2 2 | -------------------------------------------------------------------------------- /llm-experiments/scripts/run-train-interleave-cnn-rev.sh: -------------------------------------------------------------------------------- 1 | python 
training/train_interleave_rev.py --model_name_or_path "EleutherAI/pythia-1b" --revision main --dataset_name "cnn_dailymail" --dataset_config_name 3.0.0 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --learning_rate 0.001 --output_dir checkpoints/pythia-1b-10grad-50data-ckpt144k-cnn-256-shuffle-rev-sgd-48 --save_prefix batch1_gpu1 --block_size 256 --num_train_epochs 5 --overwrite_cache --checkpointing_steps epoch --save_freq 1 --num-grad-steps 10 --num-data-samples 50 --start-shuffle-data-samples 48 2 | 3 | -------------------------------------------------------------------------------- /llm-experiments/visualization/visualization_minimums.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | import torch 5 | import numpy as np 6 | import sklearn 7 | import sklearn.decomposition 8 | from tqdm import tqdm 9 | 10 | load_dir = '.' # Change to the experiment saving directory 11 | 12 | num_tasks = 25 13 | num_epochs = 3 14 | num_total_tasks = num_tasks * num_epochs 15 | last_layer_weights = [] 16 | all_tasks = range(num_total_tasks) 17 | # all_tasks = range(10) 18 | for task_num in tqdm(all_tasks): 19 | log_dir = f'{load_dir}/task_{task_num}' 20 | model_file = os.path.join(log_dir, 'pytorch_model.bin') 21 | model_weights = torch.load(model_file) 22 | last_layer_weight = model_weights['embed_out.weight'].cpu().numpy().flatten() 23 | last_layer_weights.append(last_layer_weight) 24 | 25 | last_layer_weights = np.stack(last_layer_weights) 26 | print(last_layer_weights.shape) 27 | 28 | pca = sklearn.decomposition.PCA(n_components=3) 29 | X = pca.fit_transform(last_layer_weights) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Agentic Learning AI Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Anticipatory Recovery in Sequential Cyclic Fine-tuning 2 | 3 | Code for the paper "Reawakening knowledge: Anticipatory recovery from catastrophic interference via structured training" ([NeurIPS 2024](https://openreview.net/pdf?id=YSs1z5udBY)). 
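
Every training script in this repository implements a variant of the same cyclic fine-tuning loop: sweep repeatedly over a fixed, ordered set of documents (or images), take several gradient steps on each one, and evaluate the loss on all documents after every visit. Below is a minimal sketch of that loop; `model`, `optimizer`, `documents`, `num_train_epochs`, and `num_grad_steps` are placeholders for the objects each script actually constructs, and the real scripts add checkpointing, logging, and the optimizer/masking variants studied in the paper.

```
import torch

eval_losses = []  # one row of per-document eval losses after each task visit
for epoch in range(num_train_epochs):
    for doc in documents:                    # fixed, repeated presentation order
        model.train()
        for _ in range(num_grad_steps):      # several consecutive updates on one document
            loss = model(**doc).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():                # measure forgetting / anticipatory recovery
            eval_losses.append([model(**d).loss.item() for d in documents])
```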
4 | 5 | The `llm-experiments` folder includes the Language model ([Pythia](https://github.com/EleutherAI/pythia)) experiments; the `igpt-experiments` folder includes the [Image GPT](https://github.com/openai/image-gpt) experiments; the `imagenet-experiments` folder includes the image classification experiments. 6 | 7 | ## Installation 8 | To install requirements: 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## LLM Experiments 14 | 15 | Example commands for cyclic fine-tuning experiments can be found in the `llm-experiments/scripts` folder. 16 | 17 | Code for visualizing the pairwise recovery matrix (Figures 8b and 18), PCA in the last layer weights (Figure 9), and representations (Figure 8d) can be found in the `llm-experiments/visualization` folder. 18 | 19 | ## Image GPT Experiments 20 | 21 | Example command for cyclic fine-tuning with Image GPT: 22 | ``` 23 | python train_interleave_igpt.py \ 24 | --learning_rate 0.001 \ 25 | --model_size medium \ 26 | --output_dir ./medium-20steps \ 27 | --save_prefix batch1_gpu1 \ 28 | --num_train_epochs 5 \ 29 | --num-grad-steps 20 \ 30 | --num-data-samples 25 31 | ``` 32 | 33 | where ```num-grad-steps``` is the number of consecutive gradient update steps on each image, and ```num-data-samples``` is the number of images in the sequence. 34 | 35 | ## Acknowledgements 36 | 37 | The code is adapted from [Huggingface Transformers](https://github.com/huggingface/transformers.git) and Emin Orhan's [LLM Memory Experiments](https://github.com/eminorhan/llm-memory.git). 38 | -------------------------------------------------------------------------------- /imagenet-experiments/exp_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.models 6 | from torchvision import transforms 7 | from torchvision.datasets import ImageFolder 8 | import os 9 | from torch.utils.data import DataLoader, Subset 10 | from tqdm.auto import tqdm 11 | import random 12 | 13 | device = 'cuda' 14 | 15 | ### CONFIGS ### 16 | num_grad_steps = 10 17 | num_epochs = 5 18 | lr = 0.0001 19 | batch_size = 32 20 | num_tasks = 25 21 | 22 | # model = torchvision.models.vit_b_32(weights=torchvision.models.ViT_B_32_Weights.DEFAULT) 23 | # model = torchvision.models.vit_b_16() 24 | model = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT) 25 | transform = torchvision.models.VGG19_Weights.DEFAULT.transforms() # Correspond to the same model above 26 | 27 | print("Model Parameter Count", sum([np.prod(p.size()) for p in model.parameters()])) 28 | model = model.to(device) 29 | 30 | 31 | class_names = os.listdir('/imagenet/train') 32 | dataset = ImageFolder(root='/imagenet/train', transform=transform) 33 | random_idx = random.sample(range(len(dataset)), batch_size * num_tasks) 34 | small_dataset = Subset(dataset, random_idx) 35 | eval_dataset = Subset(dataset, random_idx) 36 | 37 | dataloader = DataLoader(small_dataset, shuffle=False, batch_size=batch_size, num_workers=2) 38 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size, num_workers=2) 39 | 40 | loss_fn = nn.CrossEntropyLoss() 41 | total_num_task_steps = num_epochs * num_tasks + 1 42 | train_losses = [] 43 | eval_losses = torch.zeros(total_num_task_steps, num_tasks) 44 | 45 | # Initial Eval 46 | model.eval() 47 | with torch.no_grad(): 48 | for eval_step, (eval_samples, eval_targets) in enumerate(eval_dataloader): 49 | 
eval_samples = eval_samples.to(device) 50 | eval_targets = eval_targets.to(device) 51 | eval_preds = model(eval_samples) 52 | eval_loss = loss_fn(eval_preds, eval_targets) 53 | eval_losses[0, eval_step] = eval_loss.detach() 54 | 55 | progress_bar = tqdm(range(num_epochs * num_tasks)) 56 | 57 | for epoch in range(num_epochs): 58 | for step, (samples, targets) in enumerate(dataloader): 59 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 60 | global_train_step = epoch * num_tasks + step + 1 61 | 62 | model.train() 63 | samples = samples.to(device) 64 | targets = targets.to(device) 65 | for _ in range(num_grad_steps): 66 | preds = model(samples) 67 | loss = loss_fn(preds, targets) 68 | train_losses.append(loss.detach().unsqueeze(0)) 69 | optimizer.zero_grad() 70 | loss.backward() 71 | optimizer.step() 72 | 73 | model.eval() 74 | with torch.no_grad(): 75 | for eval_step, (eval_samples, eval_targets) in enumerate(eval_dataloader): 76 | eval_samples = eval_samples.to(device) 77 | eval_targets = eval_targets.to(device) 78 | eval_preds = model(eval_samples) 79 | eval_loss = loss_fn(eval_preds, eval_targets) 80 | eval_losses[global_train_step, eval_step] = eval_loss.detach() 81 | 82 | progress_bar.update(1) 83 | 84 | eval_losses = eval_losses.cpu().numpy() 85 | -------------------------------------------------------------------------------- /igpt-experiments/train_interleave_igpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | 7 | import torch 8 | import numpy as np 9 | from torch.utils.data import DataLoader 10 | from tqdm.auto import tqdm 11 | import transformers 12 | from logging import getLogger 13 | from transformers import ( 14 | AutoConfig, 15 | default_data_collator, 16 | GPTNeoXForCausalLM, 17 | ) 18 | from copy import deepcopy 19 | from transformers import AutoImageProcessor, ImageGPTImageProcessor, ImageGPTForCausalImageModeling 20 | import torchvision 21 | from PIL import Image 22 | import time 23 | 24 | logger = getLogger(__name__) 25 | 26 | class CustomSamplerWithoutReplacement(torch.utils.data.Sampler): 27 | def __init__(self, length, shuffle_start=1): 28 | assert length >= shuffle_start 29 | self.length = length 30 | self.shuffle_start = shuffle_start 31 | self.indices = list(range(shuffle_start, length)) 32 | 33 | def __iter__(self): 34 | for idx in range(self.shuffle_start): 35 | yield idx 36 | random.shuffle(self.indices) 37 | for idx in self.indices: 38 | yield idx 39 | 40 | def __len__(self): 41 | return self.length 42 | 43 | class CIFAR100IGPT(torchvision.datasets.CIFAR100): 44 | def __init__(self, root, train, download, image_processor): 45 | super().__init__(root=root, train=train, download=download) 46 | self.image_processor = image_processor 47 | 48 | def __getitem__(self, index): 49 | img = self.data[index] 50 | img = Image.fromarray(img) 51 | batch = self.image_processor(img, return_tensors='pt') 52 | batch["labels"] = batch["input_ids"].detach().clone() 53 | return batch 54 | 55 | 56 | def parse_args(): 57 | 58 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 59 | parser.add_argument("--dataset_name", type=str, default='CIFAR100', help="The name of the dataset to use (via the datasets library).") 60 | parser.add_argument("--model_size", type=str, default='small', choices=['small', 'medium', 'large'], help="pretrained image GPT model size.") 61 | 
parser.add_argument("--learning_rate", type=float, default=0.001, help="Initial learning rate (after the potential warmup period) to use.") 62 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 63 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 64 | parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="If the training should continue from a checkpoint folder.") 65 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 66 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 67 | parser.set_defaults(use_pretrained_weights=True) 68 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 69 | parser.set_defaults(eval_every_step=True) 70 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 71 | parser.set_defaults(use_validation=False) 72 | parser.add_argument("--store_state", action=argparse.BooleanOptionalAction, help="Whether to store model weights.") 73 | parser.set_defaults(use_validation=False) 74 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 75 | parser.add_argument("--batch_size", type=int, default=1, help="Number of images in each task (batch).") 76 | parser.add_argument("--num-data-samples", type=int, default=50, help="Number of tasks to interleave.") 77 | 78 | args = parser.parse_args() 79 | 80 | return args 81 | 82 | 83 | def main(): 84 | device = 'cuda' 85 | args = parse_args() 86 | print(args) 87 | 88 | logging.basicConfig( 89 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 90 | datefmt="%m/%d/%Y %H:%M:%S", 91 | level=logging.INFO, 92 | ) 93 | 94 | transformers.utils.logging.set_verbosity_error() 95 | 96 | if args.output_dir is not None: 97 | os.makedirs(args.output_dir, exist_ok=True) 98 | 99 | # Load pretrained model and Image Processor 100 | image_processor = AutoImageProcessor.from_pretrained(f"openai/imagegpt-{args.model_size}") 101 | model = ImageGPTForCausalImageModeling.from_pretrained(f"openai/imagegpt-{args.model_size}").to(device) 102 | 103 | # Load Datasets 104 | data_transform = torchvision.transforms.ToTensor() 105 | 106 | if args.dataset_name == 'CIFAR100': 107 | dataset = CIFAR100IGPT(root=str(os.environ.get('DATA')), train=True, download=True, image_processor=image_processor) 108 | else: 109 | raise NotImplementedError 110 | 111 | subset_task_index = random.sample(range(len(dataset)), args.num_data_samples * args.batch_size) 112 | train_dataset = torch.utils.data.Subset(dataset, subset_task_index) 113 | eval_dataset = torch.utils.data.Subset(dataset, subset_task_index) 114 | 115 | # model.resize_token_embeddings(len(tokenizer)) 116 | 117 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.batch_size) 118 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.batch_size) 119 | 120 | # optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) 121 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 122 | 123 | # Scheduler and math around the number of training steps. 
124 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 125 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 126 | 127 | logger.info("***** Running training *****") 128 | logger.info(f" Num examples = {len(train_dataset)}") 129 | logger.info(f" Num Epochs = {args.num_train_epochs}") 130 | logger.info(f" Total optimization steps = {args.max_train_steps}") 131 | 132 | # Only show the progress bar once on each machine. 133 | progress_bar = tqdm(range(args.max_train_steps)) 134 | completed_steps = 0 135 | starting_epoch = 0 136 | 137 | # Potentially load in the weights and states from a previous save 138 | if args.resume_from_checkpoint: 139 | raise NotImplementedError() 140 | 141 | # update the progress_bar if load from checkpoint 142 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 143 | completed_steps = starting_epoch * num_update_steps_per_epoch 144 | 145 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 146 | 147 | train_losses = [] 148 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 149 | 150 | # Initial Eval 151 | model.eval() 152 | with torch.no_grad(): 153 | for eval_step, (batch) in enumerate(eval_dataloader): 154 | batch['input_ids'] = batch['input_ids'].to(device) 155 | batch['labels'] = batch['labels'].to(device) 156 | outputs = model(**batch) 157 | loss = outputs.loss 158 | eval_losses_all[0, eval_step] = loss.detach() 159 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 160 | 161 | for epoch in range(starting_epoch, args.num_train_epochs): 162 | 163 | if args.resume_from_checkpoint and epoch == starting_epoch: 164 | if resume_step is not None and step < resume_step: 165 | progress_bar.update(1) 166 | completed_steps += 1 167 | continue 168 | 169 | for step, (batch) in enumerate(train_dataloader): 170 | global_train_step = epoch * args.num_data_samples + step + 1 171 | 172 | batch['input_ids'] = batch['input_ids'].to(device) 173 | batch['labels'] = batch['labels'].to(device) 174 | for _ in range(args.num_grad_steps): 175 | outputs = model(**batch) 176 | loss = outputs.loss 177 | train_losses.append(loss.detach().unsqueeze(0)) 178 | optimizer.zero_grad() 179 | loss.backward() 180 | optimizer.step() 181 | progress_bar.update(1) 182 | completed_steps += 1 183 | 184 | if args.eval_every_step: 185 | with torch.no_grad(): 186 | for eval_step, (eval_batch) in enumerate(eval_dataloader): 187 | eval_batch['input_ids'] = eval_batch['input_ids'].to(device) 188 | eval_batch['labels'] = eval_batch['labels'].to(device) 189 | eval_outputs = model(**eval_batch) 190 | eval_loss = eval_outputs.loss 191 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 192 | 193 | if args.store_state: 194 | save_dir = f"task_{epoch * args.num_data_samples + step}.pth" 195 | save_dir = os.path.join(args.output_dir, save_dir) 196 | torch.save(model.state_dict(), save_dir) 197 | 198 | output_dir = f"epoch_{epoch}" 199 | if args.output_dir is not None: 200 | output_dir = os.path.join(args.output_dir, output_dir) 201 | os.makedirs(output_dir, exist_ok=True) 202 | 203 | # save train_losses 204 | train_losses_ckpt = torch.cat(train_losses) 205 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 206 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 207 | 208 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 209 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 210 | 211 | 
if args.output_dir is not None: 212 | output_dir = os.path.join(args.output_dir, f'final') 213 | os.makedirs(output_dir, exist_ok=True) 214 | 215 | # save train_losses 216 | train_losses_ckpt = torch.cat(train_losses) 217 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 218 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 219 | 220 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 221 | 222 | # save results 223 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 224 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, completed_steps=completed_steps) 225 | 226 | 227 | if __name__ == "__main__": 228 | main() 229 | -------------------------------------------------------------------------------- /llm-experiments/visualization/visualize_rep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 29 | logger = get_logger(__name__) 30 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 31 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 32 | 33 | def parse_args(): 34 | 35 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 36 | parser.add_argument("--dataset_name", type=str, default=None, help="The name of the dataset to use (via the datasets library).") 37 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 38 | parser.add_argument("--train_file", type=str, default=None, help="A csv or a json file containing the training data.") 39 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 40 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 41 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 42 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 43 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 44 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 45 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 46 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 47 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 48 | parser.add_argument("--load_dir", 
type=str, default=None, help="Directory to experiment for loading.") 49 | parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.") 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--checkpointing_steps", type=str, default=None, help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.") 56 | parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="If the training should continue from a checkpoint folder.") 57 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 58 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 59 | parser.set_defaults(use_pretrained_weights=True) 60 | 61 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 62 | parser.set_defaults(eval_every_step=True) 63 | 64 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 65 | parser.set_defaults(use_validation=False) 66 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 67 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 68 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every moodel and optimizer save.") 69 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 70 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of tasks to interleave.") 71 | 72 | args = parser.parse_args() 73 | 74 | return args 75 | 76 | 77 | def main(): 78 | args = parse_args() 79 | print(args) 80 | 81 | eval_sample = [12] 82 | # eval_sample = range(args.num_data_samples) 83 | 84 | accelerator = Accelerator() 85 | 86 | logging.basicConfig( 87 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 88 | datefmt="%m/%d/%Y %H:%M:%S", 89 | level=logging.INFO, 90 | ) 91 | logger.info(accelerator.state, main_process_only=False) 92 | if accelerator.is_local_main_process: 93 | transformers.utils.logging.set_verbosity_info() 94 | else: 95 | transformers.utils.logging.set_verbosity_error() 96 | 97 | if args.seed is not None: 98 | set_seed(args.seed) 99 | 100 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 101 | 102 | if 'test' in raw_datasets.keys(): 103 | raw_datasets.pop('test') 104 | print("Length of Training Set", raw_datasets['train']) 105 | subset_task_index = range(args.num_data_samples) 106 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 107 | if 
args.use_validation: 108 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 109 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 110 | elif 'validation' in raw_datasets.keys(): 111 | raw_datasets.pop('validation') 112 | 113 | eval_datasets = deepcopy(raw_datasets) 114 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 115 | 116 | # Load pretrained model and tokenizer 117 | if args.config_name: 118 | config = AutoConfig.from_pretrained(args.config_name) 119 | elif args.model_name_or_path: 120 | config = AutoConfig.from_pretrained(args.model_name_or_path) 121 | else: 122 | config = CONFIG_MAPPING[args.model_type]() 123 | logger.warning("You are instantiating a new config instance from scratch.") 124 | 125 | if args.tokenizer_name: 126 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 127 | elif args.model_name_or_path: 128 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 129 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 130 | tokenizer.pad_token = tokenizer.eos_token 131 | else: 132 | raise ValueError() 133 | 134 | if args.model_name_or_path and args.use_pretrained_weights: 135 | if 'pythia' in args.model_name_or_path: 136 | model_author, model_name = args.model_name_or_path.split('/') 137 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 138 | else: 139 | model = AutoModelForCausalLM.from_predtrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 140 | else: 141 | logger.info("Training new model from scratch") 142 | if 'pythia' in args.model_name_or_path: 143 | model = GPTNeoXForCausalLM(config) 144 | else: 145 | model = AutoModelForCausalLM.from_config(config) 146 | 147 | model.resize_token_embeddings(len(tokenizer)) 148 | 149 | column_names = raw_datasets["train"].column_names 150 | eval_column_names = eval_datasets["train"].column_names 151 | 152 | test_text_column_name = 'highlights' 153 | text_column_name = eval_text_column_name = "article" 154 | 155 | if args.block_size is None: 156 | block_size = tokenizer.model_max_length 157 | else: 158 | block_size = args.block_size 159 | if args.block_size > tokenizer.model_max_length: 160 | logger.warning( 161 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 162 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 
163 | ) 164 | block_size = tokenizer.model_max_length 165 | print('Block size:', block_size) 166 | 167 | def tokenize_function_eval(examples): 168 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 169 | 170 | with accelerator.main_process_first(): 171 | eval_tokenized_datasets = eval_datasets.map( 172 | tokenize_function_eval, 173 | batched=True, 174 | num_proc=args.preprocessing_num_workers, 175 | remove_columns=eval_column_names, 176 | load_from_cache_file=not args.overwrite_cache, 177 | desc="Running tokenizer on dataset", 178 | ) 179 | 180 | def preprocess_function(examples): 181 | examples["labels"] = examples["input_ids"].copy() 182 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 183 | return examples 184 | 185 | with accelerator.main_process_first(): 186 | eval_lm_datasets = eval_tokenized_datasets.map( 187 | preprocess_function, 188 | batched=True, 189 | num_proc=args.preprocessing_num_workers, 190 | load_from_cache_file=not args.overwrite_cache, 191 | desc=f"Not grouping text.", 192 | ) 193 | 194 | eval_dataset = eval_lm_datasets["train"] 195 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 196 | 197 | for index in random.sample(range(len(eval_dataset)), 1): 198 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 199 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}") 200 | 201 | # Prepare everything with our `accelerator`. 202 | model, eval_dataloader = accelerator.prepare(model, eval_dataloader) 203 | 204 | num_steps = 125 205 | reps_all = np.zeros((num_steps, 262144)) 206 | 207 | # Initial Eval 208 | for task_idx in tqdm(range(0, num_steps)): 209 | model_weights = torch.load(f'{args.load_dir}/task_{task_idx}/pytorch_model.bin') 210 | model.load_state_dict(model_weights) 211 | model.eval() 212 | 213 | with torch.no_grad(): 214 | for step, batch in enumerate(eval_dataloader): 215 | rep = model.gpt_neox(batch['input_ids'])['last_hidden_state'].flatten() 216 | reps_all[task_idx] = rep.detach().cpu().numpy() 217 | 218 | del model_weights 219 | 220 | all_norms = np.zeros((num_steps, 1)) 221 | for i in range(num_steps): 222 | all_norms[i] = np.linalg.norm(reps_all[i]) 223 | norm_outer = all_norms @ all_norms.T 224 | all_corr = reps_all @ reps_all.T 225 | print(all_corr.shape) 226 | print(norm_outer.shape) 227 | all_corr = all_corr / norm_outer 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /llm-experiments/visualization/dipping_matrix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 
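# This script computes the pairwise recovery ("dipping") matrix (cf. Figures 8b and 18):
# for every saved task checkpoint it records the baseline eval loss on each document,
# then fine-tunes on each training document in turn and measures how much the eval loss
# of every other document drops relative to that baseline (eval_losses_all below).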
29 | logger = get_logger(__name__) 30 | 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def parse_args(): 36 | 37 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 38 | parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).") 39 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 40 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 41 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 42 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 43 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 44 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 45 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 46 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 47 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 48 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 49 | parser.add_argument("--load_dir", type=str, default=None, help="Directory to experiment for loading.") 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=True) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 63 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 64 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every moodel and optimizer save.") 65 | 66 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of 
gradient updates for each data point.") 67 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 68 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of tasks to interleave.") 69 | 70 | args = parser.parse_args() 71 | 72 | return args 73 | 74 | 75 | def main(): 76 | args = parse_args() 77 | print(args) 78 | 79 | # eval_sample = [0] 80 | eval_sample = range(args.num_data_samples) 81 | 82 | accelerator = Accelerator() 83 | 84 | logging.basicConfig( 85 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 86 | datefmt="%m/%d/%Y %H:%M:%S", 87 | level=logging.INFO, 88 | ) 89 | logger.info(accelerator.state, main_process_only=False) 90 | if accelerator.is_local_main_process: 91 | transformers.utils.logging.set_verbosity_info() 92 | else: 93 | transformers.utils.logging.set_verbosity_error() 94 | 95 | if args.seed is not None: 96 | set_seed(args.seed) 97 | 98 | accelerator.wait_for_everyone() 99 | 100 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 101 | 102 | if 'test' in raw_datasets.keys(): 103 | raw_datasets.pop('test') 104 | print("Length of Training Set", raw_datasets['train']) 105 | 106 | # subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 107 | # subset_task_index = range(args.num_data_samples) 108 | subset_task_index = np.load(f'{args.load_dir}/sampled_indices.npy') 109 | 110 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 111 | raw_datasets.pop('validation') 112 | 113 | eval_datasets = deepcopy(raw_datasets) 114 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 115 | 116 | # Load pretrained model and tokenizer 117 | if args.config_name: 118 | config = AutoConfig.from_pretrained(args.config_name) 119 | elif args.model_name_or_path: 120 | config = AutoConfig.from_pretrained(args.model_name_or_path) 121 | else: 122 | config = CONFIG_MAPPING[args.model_type]() 123 | logger.warning("You are instantiating a new config instance from scratch.") 124 | 125 | if args.tokenizer_name: 126 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 127 | elif args.model_name_or_path: 128 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 129 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 130 | tokenizer.pad_token = tokenizer.eos_token 131 | else: 132 | raise ValueError() 133 | 134 | if args.model_name_or_path and args.use_pretrained_weights: 135 | if 'pythia' in args.model_name_or_path: 136 | model_author, model_name = args.model_name_or_path.split('/') 137 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 138 | else: 139 | model = AutoModelForCausalLM.from_predtrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 140 | else: 141 | logger.info("Training new model from scratch") 142 | if 'pythia' in args.model_name_or_path: 143 | model = GPTNeoXForCausalLM(config) 144 | else: 145 | model = AutoModelForCausalLM.from_config(config) 146 | 147 | model.resize_token_embeddings(len(tokenizer)) 148 | 149 | if not args.use_pretrained_weights: 150 | with torch.no_grad(): 151 | for name, param in 
model.named_parameters(): 152 | if 'norm' not in name and 'bias' not in name: 153 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 154 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 155 | param *= math.sqrt(2 / (5 * 2048)) / 0.02 156 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 157 | param *= 2 / 16 / math.sqrt(2048) / 0.02 158 | 159 | for name, param in model.named_parameters(): 160 | if 'norm' not in name and 'bias' not in name: 161 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 162 | 163 | column_names = raw_datasets["train"].column_names 164 | eval_column_names = eval_datasets["train"].column_names 165 | text_column_name = eval_text_column_name = "article" 166 | 167 | if args.block_size is None: 168 | block_size = tokenizer.model_max_length 169 | else: 170 | block_size = args.block_size 171 | if args.block_size > tokenizer.model_max_length: 172 | logger.warning( 173 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 174 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 175 | ) 176 | block_size = tokenizer.model_max_length 177 | print('Block size:', block_size) 178 | 179 | def tokenize_function(examples): 180 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 181 | 182 | def tokenize_function_eval(examples): 183 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 184 | 185 | with accelerator.main_process_first(): 186 | tokenized_datasets = raw_datasets.map( 187 | tokenize_function, 188 | batched=True, 189 | num_proc=args.preprocessing_num_workers, 190 | remove_columns=column_names, 191 | load_from_cache_file=not args.overwrite_cache, 192 | desc="Running tokenizer on dataset", 193 | ) 194 | 195 | with accelerator.main_process_first(): 196 | eval_tokenized_datasets = eval_datasets.map( 197 | tokenize_function_eval, 198 | batched=True, 199 | num_proc=args.preprocessing_num_workers, 200 | remove_columns=eval_column_names, 201 | load_from_cache_file=not args.overwrite_cache, 202 | desc="Running tokenizer on dataset", 203 | ) 204 | 205 | def preprocess_function(examples): 206 | examples["labels"] = examples["input_ids"].copy() 207 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 208 | return examples 209 | 210 | with accelerator.main_process_first(): 211 | lm_datasets = tokenized_datasets.map( 212 | preprocess_function, 213 | batched=True, 214 | num_proc=args.preprocessing_num_workers, 215 | load_from_cache_file=not args.overwrite_cache, 216 | desc=f"Not grouping text.", 217 | ) 218 | 219 | with accelerator.main_process_first(): 220 | eval_lm_datasets = eval_tokenized_datasets.map( 221 | preprocess_function, 222 | batched=True, 223 | num_proc=args.preprocessing_num_workers, 224 | load_from_cache_file=not args.overwrite_cache, 225 | desc=f"Not grouping text.", 226 | ) 227 | 228 | train_dataset = lm_datasets["train"] 229 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 230 | 231 | eval_dataset = eval_lm_datasets["train"] 232 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 233 | 234 | for index in random.sample(range(len(train_dataset)), 1): 235 | index = 0 236 | # 
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 237 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], skip_special_tokens=True)}.") 238 | for index in random.sample(range(len(eval_dataset)), 1): 239 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 240 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 241 | 242 | no_decay = ["bias", "layer_norm.weight"] 243 | optimizer_grouped_parameters = [ 244 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 245 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 246 | ] 247 | 248 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 249 | 250 | # Prepare everything with our `accelerator`. 251 | model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader) 252 | 253 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 254 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 255 | 256 | # Train! 257 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 258 | 259 | logger.info("***** Running training *****") 260 | logger.info(f" Num examples = {len(train_dataset)}") 261 | logger.info(f" Num Epochs = {args.num_train_epochs}") 262 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 263 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}")
264 | 
265 |     num_total_steps = args.num_train_epochs * len(eval_dataloader)
266 |     eval_losses_ckpt = torch.zeros(num_total_steps, args.num_data_samples)
267 |     eval_losses_all = torch.zeros(num_total_steps, args.num_data_samples, args.num_data_samples)
268 | 
269 |     for step_id in tqdm(range(num_total_steps)):
270 | 
271 |         model_ckpt = torch.load(f'{args.load_dir}/task_{step_id}/pytorch_model.bin')
272 |         model.load_state_dict(model_ckpt)
273 |         model.eval()
274 |         with torch.no_grad():
275 |             for eval_step, eval_batch in enumerate(eval_dataloader):
276 |                 eval_outputs = model(**eval_batch)
277 |                 eval_loss = eval_outputs.loss
278 |                 eval_losses_ckpt[step_id, eval_step] = eval_loss.detach()
279 | 
280 |         for train_step, batch in enumerate(train_dataloader):
281 |             model.load_state_dict(model_ckpt)
282 |             model.train()
283 |             for grad_step in range(args.num_grad_steps):
284 |                 outputs = model(**batch)
285 |                 loss = outputs.loss
286 |                 optimizer.zero_grad()
287 |                 accelerator.backward(loss)
288 |                 optimizer.step()
289 | 
290 |             model.eval()
291 |             with torch.no_grad():
292 |                 for eval_step, eval_batch in enumerate(eval_dataloader):
293 |                     eval_outputs = model(**eval_batch)
294 |                     eval_loss = eval_outputs.loss
295 |                     eval_losses_all[step_id, train_step, eval_step] = eval_losses_ckpt[step_id, eval_step] - eval_loss.detach()
296 | 
297 |     eval_losses_all = eval_losses_all.cpu().numpy()
298 | 
299 | if __name__ == "__main__":
300 |     main()
301 | 
--------------------------------------------------------------------------------
/llm-experiments/training/train_interleave_scratch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import math
4 | import os
5 | import random
6 | from itertools import chain
7 | 
8 | import torch
9 | import numpy as np
10 | from datasets import load_dataset
11 | from torch.utils.data import DataLoader
12 | from tqdm.auto import tqdm
13 | 
14 | import transformers
15 | from accelerate import Accelerator
16 | from accelerate.logging import get_logger
17 | from accelerate.utils import set_seed
18 | from transformers import (
19 |     CONFIG_MAPPING,
20 |     MODEL_MAPPING,
21 |     AutoConfig,
22 |     AutoModelForCausalLM,
23 |     AutoTokenizer,
24 |     default_data_collator,
25 |     GPTNeoXForCausalLM,
26 | )
27 | from copy import deepcopy
28 | 
29 | logger = get_logger(__name__)
30 | 
31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
33 | 
34 | 
35 | def parse_args():
36 | 
37 |     parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks")
38 |     parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).")
39 |     parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).")
40 |     parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False)
41 |     parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name")
42 |     parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name")
43 |     parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer
(not backed by the Tokenizers library).") 44 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 45 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 46 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 47 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 48 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 49 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 50 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 51 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 52 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 53 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 54 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 55 | 56 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 57 | parser.set_defaults(eval_every_step=True) 58 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 59 | parser.set_defaults(use_validation=False) 60 | 61 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 62 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 63 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every moodel and optimizer save.") 64 | 65 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 66 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 67 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of tasks to interleave.") 68 | 69 | parser.add_argument("--num_hidden_layers", type=int, default=16, help="Number of hidden layers.") 70 | parser.add_argument("--num_attention_heads", type=int, default=8, help="Number of attention heads.") 71 | parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size.") 72 | 73 | args = parser.parse_args() 74 | 75 | return args 76 | 77 | 78 | def main(): 79 | args = parse_args() 80 | print(args) 81 | 82 | # eval_sample = [0] 83 | eval_sample = range(args.num_data_samples) 84 | 85 | accelerator = Accelerator() 86 | 87 | logging.basicConfig( 88 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 89 | datefmt="%m/%d/%Y %H:%M:%S", 90 | level=logging.INFO, 91 | ) 92 | logger.info(accelerator.state, main_process_only=False) 93 | if accelerator.is_local_main_process: 94 | transformers.utils.logging.set_verbosity_info() 95 | else: 96 | transformers.utils.logging.set_verbosity_error() 97 | 98 | if args.seed is not None: 99 | set_seed(args.seed) 100 | 101 | if 
accelerator.is_main_process: 102 | if args.output_dir is not None: 103 | os.makedirs(args.output_dir, exist_ok=True) 104 | accelerator.wait_for_everyone() 105 | 106 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 107 | 108 | if 'test' in raw_datasets.keys(): 109 | raw_datasets.pop('test') 110 | print("Length of Training Set", raw_datasets['train']) 111 | 112 | subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 113 | # subset_task_index = range(args.num_data_samples) 114 | 115 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 116 | if args.use_validation: 117 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 118 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 119 | elif 'validation' in raw_datasets.keys(): 120 | raw_datasets.pop('validation') 121 | 122 | eval_datasets = deepcopy(raw_datasets) 123 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 124 | 125 | # Load pretrained model and tokenizer 126 | if args.config_name: 127 | config = AutoConfig.from_pretrained(args.config_name) 128 | elif args.model_name_or_path: 129 | config = AutoConfig.from_pretrained(args.model_name_or_path) 130 | else: 131 | config = CONFIG_MAPPING[args.model_type]() 132 | logger.warning("You are instantiating a new config instance from scratch.") 133 | 134 | if args.tokenizer_name: 135 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 136 | elif args.model_name_or_path: 137 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 138 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 139 | tokenizer.pad_token = tokenizer.eos_token 140 | else: 141 | raise ValueError() 142 | 143 | config.num_hidden_layers = args.num_hidden_layers # Default 16 144 | config.num_attention_heads = args.num_attention_heads # Default 8 145 | config.hidden_size = args.hidden_size # Default 2048 146 | 147 | logger.info("Training new model from scratch") 148 | if 'pythia' in args.model_name_or_path: 149 | model = GPTNeoXForCausalLM(config) 150 | else: 151 | model = AutoModelForCausalLM.from_config(config) 152 | 153 | model.resize_token_embeddings(len(tokenizer)) 154 | 155 | with torch.no_grad(): 156 | for name, param in model.named_parameters(): 157 | if 'norm' not in name and 'bias' not in name: 158 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 159 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 160 | param *= math.sqrt(2 / (5 * args.hidden_size)) / 0.02 161 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 162 | param *= 2 / args.num_hidden_layers / math.sqrt(args.hidden_size) / 0.02 163 | 164 | for name, param in model.named_parameters(): 165 | if 'norm' not in name and 'bias' not in name: 166 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 167 | 168 | column_names = raw_datasets["train"].column_names 169 | eval_column_names = eval_datasets["train"].column_names 170 | 171 | test_text_column_name = 'highlights' 172 | text_column_name = eval_text_column_name = "article" 173 | 174 | if args.block_size is None: 175 | block_size = tokenizer.model_max_length 176 | else: 177 | block_size = args.block_size 178 | if args.block_size > 
tokenizer.model_max_length: 179 | logger.warning( 180 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 181 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 182 | ) 183 | block_size = tokenizer.model_max_length 184 | print('Block size:', block_size) 185 | 186 | def tokenize_function(examples): 187 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 188 | 189 | def tokenize_function_eval(examples): 190 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 191 | 192 | def tokenize_function_test(examples): 193 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 194 | 195 | with accelerator.main_process_first(): 196 | tokenized_datasets = raw_datasets.map( 197 | tokenize_function, 198 | batched=True, 199 | num_proc=args.preprocessing_num_workers, 200 | remove_columns=column_names, 201 | load_from_cache_file=not args.overwrite_cache, 202 | desc="Running tokenizer on dataset", 203 | ) 204 | 205 | with accelerator.main_process_first(): 206 | eval_tokenized_datasets = eval_datasets.map( 207 | tokenize_function_eval, 208 | batched=True, 209 | num_proc=args.preprocessing_num_workers, 210 | remove_columns=eval_column_names, 211 | load_from_cache_file=not args.overwrite_cache, 212 | desc="Running tokenizer on dataset", 213 | ) 214 | 215 | with accelerator.main_process_first(): 216 | test_tokenized_datasets = eval_datasets.map( 217 | tokenize_function_test, 218 | batched=True, 219 | num_proc=args.preprocessing_num_workers, 220 | remove_columns=eval_column_names, 221 | load_from_cache_file=not args.overwrite_cache, 222 | desc="Running tokenizer on dataset", 223 | ) 224 | 225 | def preprocess_function(examples): 226 | examples["labels"] = examples["input_ids"].copy() 227 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 228 | return examples 229 | 230 | with accelerator.main_process_first(): 231 | lm_datasets = tokenized_datasets.map( 232 | preprocess_function, 233 | batched=True, 234 | num_proc=args.preprocessing_num_workers, 235 | load_from_cache_file=not args.overwrite_cache, 236 | desc=f"Not grouping text.", 237 | ) 238 | 239 | with accelerator.main_process_first(): 240 | eval_lm_datasets = eval_tokenized_datasets.map( 241 | preprocess_function, 242 | batched=True, 243 | num_proc=args.preprocessing_num_workers, 244 | load_from_cache_file=not args.overwrite_cache, 245 | desc=f"Not grouping text.", 246 | ) 247 | 248 | with accelerator.main_process_first(): 249 | test_lm_datasets = test_tokenized_datasets.map( 250 | preprocess_function, 251 | batched=True, 252 | num_proc=args.preprocessing_num_workers, 253 | load_from_cache_file=not args.overwrite_cache, 254 | desc=f"Not grouping text.", 255 | ) 256 | 257 | train_dataset = lm_datasets["train"] 258 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 259 | 260 | eval_dataset = eval_lm_datasets["train"] 261 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 262 | 263 | test_dataset = test_lm_datasets['train'] 264 | test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 265 | 266 | 267 | for index in 
random.sample(range(len(train_dataset)), 1): 268 | index = 0 269 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 270 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], skip_special_tokens=True)}.") 271 | for index in random.sample(range(len(eval_dataset)), 1): 272 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 273 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 274 | 275 | no_decay = ["bias", "layer_norm.weight"] 276 | optimizer_grouped_parameters = [ 277 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 278 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 279 | ] 280 | 281 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 282 | 283 | # Prepare everything with our `accelerator`. 284 | model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, test_dataloader) 285 | 286 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 287 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 288 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 289 | 290 | # Train! 291 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 292 | 293 | logger.info("***** Running training *****") 294 | logger.info(f" Num examples = {len(train_dataset)}") 295 | logger.info(f" Num Epochs = {args.num_train_epochs}") 296 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 297 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 298 | logger.info(f" Total optimization steps = {args.max_train_steps}") 299 | 300 | # Only show the progress bar once on each machine. 
301 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 302 | completed_steps = 0 303 | starting_epoch = 0 304 | 305 | # update the progress_bar if load from checkpoint 306 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 307 | completed_steps = starting_epoch * num_update_steps_per_epoch 308 | 309 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 310 | 311 | train_losses = [] 312 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 313 | test_losses_all = torch.zeros(total_num_task_steps, len(test_dataloader)) 314 | 315 | # Initial Eval 316 | model.eval() 317 | with torch.no_grad(): 318 | for eval_step, batch in enumerate(eval_dataloader): 319 | outputs = model(**batch) 320 | loss = outputs.loss 321 | eval_losses_all[0, eval_step] = loss.detach() 322 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 323 | 324 | for test_step, batch in enumerate(test_dataloader): 325 | outputs = model(**batch) 326 | loss = outputs.loss 327 | test_losses_all[0, test_step] = loss.detach() 328 | logger.info(f"Mean test loss: {torch.mean(test_losses_all[0, :])}") 329 | 330 | for epoch in range(starting_epoch, args.num_train_epochs): 331 | 332 | for step, batch in enumerate(train_dataloader): 333 | 334 | global_train_step = epoch * args.num_data_samples + step + 1 335 | 336 | model.train() 337 | 338 | for grad_step in range(args.num_grad_steps): 339 | 340 | assert model.training 341 | outputs = model(**batch) 342 | loss = outputs.loss 343 | # keep track of the loss at each epoch 344 | train_losses.append(loss.detach().unsqueeze(0)) 345 | optimizer.zero_grad() 346 | accelerator.backward(loss) 347 | optimizer.step() 348 | 349 | # Checks if the accelerator has performed an optimization step behind the scenes 350 | if accelerator.sync_gradients: 351 | progress_bar.update(1) 352 | completed_steps += 1 353 | 354 | if args.eval_every_step: 355 | model.eval() 356 | with torch.no_grad(): 357 | for eval_step, eval_batch in enumerate(eval_dataloader): 358 | eval_outputs = model(**eval_batch) 359 | eval_loss = eval_outputs.loss 360 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 361 | 362 | for test_step, test_batch in enumerate(test_dataloader): 363 | test_outputs = model(**test_batch) 364 | test_loss = test_outputs.loss 365 | test_losses_all[global_train_step, test_step] = test_loss.detach() 366 | 367 | # Logging 368 | output_dir = f"epoch_{epoch}" 369 | if args.output_dir is not None: 370 | output_dir = os.path.join(args.output_dir, output_dir) 371 | 372 | os.makedirs(output_dir, exist_ok=True) 373 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 374 | # accelerator.save_state(output_dir) 375 | 376 | # save train_losses 377 | train_losses_ckpt = torch.cat(train_losses) 378 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 379 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 380 | 381 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 382 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 383 | 384 | 385 | if args.output_dir is not None: 386 | output_dir = os.path.join(args.output_dir, f'final') 387 | # save model and tokenizer 388 | accelerator.wait_for_everyone() 389 | unwrapped_model = accelerator.unwrap_model(model) 390 | unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 391 | if accelerator.is_main_process: 
392 | tokenizer.save_pretrained(output_dir) 393 | 394 | # save train_losses 395 | train_losses_ckpt = torch.cat(train_losses) 396 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 397 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 398 | 399 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 400 | test_losses_all_ckpt = test_losses_all.cpu().numpy() 401 | 402 | # save results 403 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 404 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 405 | 406 | 407 | if __name__ == "__main__": 408 | main() 409 | -------------------------------------------------------------------------------- /llm-experiments/training/train_interleave_optimizer_reset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 29 | logger = get_logger(__name__) 30 | 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def parse_args(): 36 | 37 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 38 | parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).") 39 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 40 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 41 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 42 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 43 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 44 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 45 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 46 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 47 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 48 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 49 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 50 | 
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=True) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 63 | parser.set_defaults(use_validation=False) 64 | 65 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 66 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 67 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every model and optimizer save.") 68 | 69 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 70 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 71 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of evaluation data samples.") 72 | 73 | args = parser.parse_args() 74 | 75 | return args 76 | 77 | 78 | def main(): 79 | args = parse_args() 80 | print(args) 81 | 82 | # eval_sample = [0] 83 | eval_sample = range(args.num_data_samples) 84 | 85 | accelerator = Accelerator() 86 | 87 | logging.basicConfig( 88 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 89 | datefmt="%m/%d/%Y %H:%M:%S", 90 | level=logging.INFO, 91 | ) 92 | logger.info(accelerator.state, main_process_only=False) 93 | if accelerator.is_local_main_process: 94 | transformers.utils.logging.set_verbosity_info() 95 | else: 96 | transformers.utils.logging.set_verbosity_error() 97 | 98 | if args.seed is not None: 99 | set_seed(args.seed) 100 | 101 | if accelerator.is_main_process: 102 | if args.output_dir is not None: 103 | os.makedirs(args.output_dir, exist_ok=True) 104 | accelerator.wait_for_everyone() 105 | 106 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 107 | 108 | if 'test' in raw_datasets.keys(): 109 | raw_datasets.pop('test') 110 | print("Length of Training Set", raw_datasets['train']) 111 | 112 | subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 113 | # subset_task_index = range(args.num_data_samples) 114 | 115 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 116 | if args.use_validation: 117 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 118 |
raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 119 | elif 'validation' in raw_datasets.keys(): 120 | raw_datasets.pop('validation') 121 | 122 | eval_datasets = deepcopy(raw_datasets) 123 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 124 | 125 | # Load pretrained model and tokenizer 126 | if args.config_name: 127 | config = AutoConfig.from_pretrained(args.config_name) 128 | elif args.model_name_or_path: 129 | config = AutoConfig.from_pretrained(args.model_name_or_path) 130 | else: 131 | config = CONFIG_MAPPING[args.model_type]() 132 | logger.warning("You are instantiating a new config instance from scratch.") 133 | 134 | if args.tokenizer_name: 135 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 136 | elif args.model_name_or_path: 137 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 138 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 139 | tokenizer.pad_token = tokenizer.eos_token 140 | else: 141 | raise ValueError() 142 | 143 | if args.model_name_or_path and args.use_pretrained_weights: 144 | if 'pythia' in args.model_name_or_path: 145 | model_author, model_name = args.model_name_or_path.split('/') 146 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 147 | else: 148 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 149 | else: 150 | logger.info("Training new model from scratch") 151 | if 'pythia' in args.model_name_or_path: 152 | model = GPTNeoXForCausalLM(config) 153 | else: 154 | model = AutoModelForCausalLM.from_config(config) 155 | 156 | model.resize_token_embeddings(len(tokenizer)) 157 | 158 | if not args.use_pretrained_weights: 159 | with torch.no_grad(): 160 | for name, param in model.named_parameters(): 161 | if 'norm' not in name and 'bias' not in name: 162 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 163 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 164 | param *= math.sqrt(2 / (5 * 2048)) / 0.02 165 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 166 | param *= 2 / 16 / math.sqrt(2048) / 0.02 167 | 168 | for name, param in model.named_parameters(): 169 | if 'norm' not in name and 'bias' not in name: 170 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 171 | 172 | column_names = raw_datasets["train"].column_names 173 | eval_column_names = eval_datasets["train"].column_names 174 | 175 | test_text_column_name = 'highlights' 176 | text_column_name = eval_text_column_name = "article" 177 | 178 | if args.block_size is None: 179 | block_size = tokenizer.model_max_length 180 | else: 181 | block_size = args.block_size 182 | if args.block_size > tokenizer.model_max_length: 183 | logger.warning( 184 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 185 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
186 | ) 187 | block_size = tokenizer.model_max_length 188 | print('Block size:', block_size) 189 | 190 | def tokenize_function(examples): 191 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 192 | 193 | def tokenize_function_eval(examples): 194 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 195 | 196 | def tokenize_function_test(examples): 197 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 198 | 199 | with accelerator.main_process_first(): 200 | tokenized_datasets = raw_datasets.map( 201 | tokenize_function, 202 | batched=True, 203 | num_proc=args.preprocessing_num_workers, 204 | remove_columns=column_names, 205 | load_from_cache_file=not args.overwrite_cache, 206 | desc="Running tokenizer on dataset", 207 | ) 208 | 209 | with accelerator.main_process_first(): 210 | eval_tokenized_datasets = eval_datasets.map( 211 | tokenize_function_eval, 212 | batched=True, 213 | num_proc=args.preprocessing_num_workers, 214 | remove_columns=eval_column_names, 215 | load_from_cache_file=not args.overwrite_cache, 216 | desc="Running tokenizer on dataset", 217 | ) 218 | 219 | with accelerator.main_process_first(): 220 | test_tokenized_datasets = eval_datasets.map( 221 | tokenize_function_test, 222 | batched=True, 223 | num_proc=args.preprocessing_num_workers, 224 | remove_columns=eval_column_names, 225 | load_from_cache_file=not args.overwrite_cache, 226 | desc="Running tokenizer on dataset", 227 | ) 228 | 229 | def preprocess_function(examples): 230 | examples["labels"] = examples["input_ids"].copy() 231 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 232 | return examples 233 | 234 | with accelerator.main_process_first(): 235 | lm_datasets = tokenized_datasets.map( 236 | preprocess_function, 237 | batched=True, 238 | num_proc=args.preprocessing_num_workers, 239 | load_from_cache_file=not args.overwrite_cache, 240 | desc=f"Not grouping text.", 241 | ) 242 | 243 | with accelerator.main_process_first(): 244 | eval_lm_datasets = eval_tokenized_datasets.map( 245 | preprocess_function, 246 | batched=True, 247 | num_proc=args.preprocessing_num_workers, 248 | load_from_cache_file=not args.overwrite_cache, 249 | desc=f"Not grouping text.", 250 | ) 251 | 252 | with accelerator.main_process_first(): 253 | test_lm_datasets = test_tokenized_datasets.map( 254 | preprocess_function, 255 | batched=True, 256 | num_proc=args.preprocessing_num_workers, 257 | load_from_cache_file=not args.overwrite_cache, 258 | desc=f"Not grouping text.", 259 | ) 260 | 261 | train_dataset = lm_datasets["train"] 262 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 263 | 264 | eval_dataset = eval_lm_datasets["train"] 265 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 266 | 267 | test_dataset = test_lm_datasets['train'] 268 | test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 269 | 270 | 271 | for index in random.sample(range(len(train_dataset)), 1): 272 | index = 0 273 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 274 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], 
skip_special_tokens=True)}.") 275 | for index in random.sample(range(len(eval_dataset)), 1): 276 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 277 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 278 | 279 | # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 280 | 281 | # Prepare everything with our `accelerator`. 282 | model, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, train_dataloader, eval_dataloader, test_dataloader) 283 | 284 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 285 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 286 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 287 | 288 | # Train! 289 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 290 | 291 | logger.info("***** Running training *****") 292 | logger.info(f" Num examples = {len(train_dataset)}") 293 | logger.info(f" Num Epochs = {args.num_train_epochs}") 294 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 295 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 296 | logger.info(f" Total optimization steps = {args.max_train_steps}") 297 | 298 | # Only show the progress bar once on each machine. 299 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 300 | completed_steps = 0 301 | starting_epoch = 0 302 | 303 | # update the progress_bar if load from checkpoint 304 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 305 | completed_steps = starting_epoch * num_update_steps_per_epoch 306 | 307 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 308 | 309 | train_losses = [] 310 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 311 | test_losses_all = torch.zeros(total_num_task_steps, len(test_dataloader)) 312 | 313 | # Initial Eval 314 | model.eval() 315 | with torch.no_grad(): 316 | for eval_step, batch in enumerate(eval_dataloader): 317 | outputs = model(**batch) 318 | loss = outputs.loss 319 | eval_losses_all[0, eval_step] = loss.detach() 320 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 321 | 322 | for test_step, batch in enumerate(test_dataloader): 323 | outputs = model(**batch) 324 | loss = outputs.loss 325 | test_losses_all[0, test_step] = loss.detach() 326 | logger.info(f"Mean test loss: {torch.mean(test_losses_all[0, :])}") 327 | 328 | for epoch in range(starting_epoch, args.num_train_epochs): 329 | 330 | for step, batch in enumerate(train_dataloader): 331 | 332 | if step > 0: 333 | del optimizer 334 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 335 | accelerator._optimizers = [] 336 | optimizer = accelerator.prepare(optimizer) 337 | 338 | global_train_step = epoch * args.num_data_samples + step + 1 339 | 340 | model.train() 341 | 342 | for grad_step in range(args.num_grad_steps): 343 | 344 | assert model.training 345 | outputs = model(**batch) 346 | loss = outputs.loss 347 | # keep track of the loss at each epoch 348 | train_losses.append(loss.detach().unsqueeze(0)) 349 | optimizer.zero_grad() 350 | accelerator.backward(loss) 351 | optimizer.step() 352 | 353 | # Checks if the accelerator has performed an 
optimization step behind the scenes 354 | if accelerator.sync_gradients: 355 | progress_bar.update(1) 356 | completed_steps += 1 357 | 358 | if args.eval_every_step: 359 | model.eval() 360 | with torch.no_grad(): 361 | for eval_step, eval_batch in enumerate(eval_dataloader): 362 | eval_outputs = model(**eval_batch) 363 | eval_loss = eval_outputs.loss 364 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 365 | 366 | for test_step, test_batch in enumerate(test_dataloader): 367 | test_outputs = model(**test_batch) 368 | test_loss = test_outputs.loss 369 | test_losses_all[global_train_step, test_step] = test_loss.detach() 370 | 371 | # Logging 372 | output_dir = f"epoch_{epoch}" 373 | if args.output_dir is not None: 374 | output_dir = os.path.join(args.output_dir, output_dir) 375 | 376 | os.makedirs(output_dir, exist_ok=True) 377 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 378 | # accelerator.save_state(output_dir) 379 | 380 | # save train_losses 381 | train_losses_ckpt = torch.cat(train_losses) 382 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 383 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 384 | 385 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 386 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 387 | 388 | 389 | if args.output_dir is not None: 390 | output_dir = os.path.join(args.output_dir, f'final') 391 | # save model and tokenizer 392 | accelerator.wait_for_everyone() 393 | # unwrapped_model = accelerator.unwrap_model(model) 394 | # unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 395 | if accelerator.is_main_process: 396 | tokenizer.save_pretrained(output_dir) 397 | 398 | # save train_losses 399 | train_losses_ckpt = torch.cat(train_losses) 400 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 401 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 402 | 403 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 404 | test_losses_all_ckpt = test_losses_all.cpu().numpy() 405 | 406 | # save results 407 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 408 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 409 | 410 | 411 | if __name__ == "__main__": 412 | main() 413 | -------------------------------------------------------------------------------- /llm-experiments/training/train_interleave_optimizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 29 | logger = get_logger(__name__) 30 | 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def 
parse_args(): 36 | 37 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 38 | parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).") 39 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 40 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 41 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 42 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 43 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 44 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 45 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 46 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 47 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 48 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 49 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=True) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 63 | parser.set_defaults(use_validation=False) 64 | 65 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 66 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 67 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every model and optimizer save.") 68 | 69 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates
for each data point.") 70 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 71 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of evaluation data samples.") 72 | 73 | args = parser.parse_args() 74 | 75 | return args 76 | 77 | 78 | def main(): 79 | args = parse_args() 80 | print(args) 81 | 82 | # eval_sample = [0] 83 | eval_sample = range(args.num_data_samples) 84 | 85 | accelerator = Accelerator() 86 | 87 | logging.basicConfig( 88 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 89 | datefmt="%m/%d/%Y %H:%M:%S", 90 | level=logging.INFO, 91 | ) 92 | logger.info(accelerator.state, main_process_only=False) 93 | if accelerator.is_local_main_process: 94 | transformers.utils.logging.set_verbosity_info() 95 | else: 96 | transformers.utils.logging.set_verbosity_error() 97 | 98 | if args.seed is not None: 99 | set_seed(args.seed) 100 | 101 | if accelerator.is_main_process: 102 | if args.output_dir is not None: 103 | os.makedirs(args.output_dir, exist_ok=True) 104 | accelerator.wait_for_everyone() 105 | 106 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 107 | 108 | if 'test' in raw_datasets.keys(): 109 | raw_datasets.pop('test') 110 | print("Length of Training Set", raw_datasets['train']) 111 | 112 | subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 113 | # subset_task_index = range(args.num_data_samples) 114 | 115 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 116 | if args.use_validation: 117 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 118 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 119 | elif 'validation' in raw_datasets.keys(): 120 | raw_datasets.pop('validation') 121 | 122 | eval_datasets = deepcopy(raw_datasets) 123 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 124 | 125 | # Load pretrained model and tokenizer 126 | if args.config_name: 127 | config = AutoConfig.from_pretrained(args.config_name) 128 | elif args.model_name_or_path: 129 | config = AutoConfig.from_pretrained(args.model_name_or_path) 130 | else: 131 | config = CONFIG_MAPPING[args.model_type]() 132 | logger.warning("You are instantiating a new config instance from scratch.") 133 | 134 | if args.tokenizer_name: 135 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 136 | elif args.model_name_or_path: 137 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 138 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 139 | tokenizer.pad_token = tokenizer.eos_token 140 | else: 141 | raise ValueError() 142 | 143 | if args.model_name_or_path and args.use_pretrained_weights: 144 | if 'pythia' in args.model_name_or_path: 145 | model_author, model_name = args.model_name_or_path.split('/') 146 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 147 | else: 148 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 149 | else: 150 | logger.info("Training new model from scratch") 151 | if
'pythia' in args.model_name_or_path: 152 | model = GPTNeoXForCausalLM(config) 153 | else: 154 | model = AutoModelForCausalLM.from_config(config) 155 | 156 | model.resize_token_embeddings(len(tokenizer)) 157 | 158 | if not args.use_pretrained_weights: 159 | with torch.no_grad(): 160 | for name, param in model.named_parameters(): 161 | if 'norm' not in name and 'bias' not in name: 162 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 163 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 164 | param *= math.sqrt(2 / (5 * 2048)) / 0.02 165 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 166 | param *= 2 / 16 / math.sqrt(2048) / 0.02 167 | 168 | for name, param in model.named_parameters(): 169 | if 'norm' not in name and 'bias' not in name: 170 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 171 | 172 | column_names = raw_datasets["train"].column_names 173 | eval_column_names = eval_datasets["train"].column_names 174 | 175 | test_text_column_name = 'highlights' 176 | text_column_name = eval_text_column_name = "article" 177 | 178 | if args.block_size is None: 179 | block_size = tokenizer.model_max_length 180 | else: 181 | block_size = args.block_size 182 | if args.block_size > tokenizer.model_max_length: 183 | logger.warning( 184 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 185 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 186 | ) 187 | block_size = tokenizer.model_max_length 188 | print('Block size:', block_size) 189 | 190 | def tokenize_function(examples): 191 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 192 | 193 | def tokenize_function_eval(examples): 194 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 195 | 196 | def tokenize_function_test(examples): 197 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 198 | 199 | with accelerator.main_process_first(): 200 | tokenized_datasets = raw_datasets.map( 201 | tokenize_function, 202 | batched=True, 203 | num_proc=args.preprocessing_num_workers, 204 | remove_columns=column_names, 205 | load_from_cache_file=not args.overwrite_cache, 206 | desc="Running tokenizer on dataset", 207 | ) 208 | 209 | with accelerator.main_process_first(): 210 | eval_tokenized_datasets = eval_datasets.map( 211 | tokenize_function_eval, 212 | batched=True, 213 | num_proc=args.preprocessing_num_workers, 214 | remove_columns=eval_column_names, 215 | load_from_cache_file=not args.overwrite_cache, 216 | desc="Running tokenizer on dataset", 217 | ) 218 | 219 | with accelerator.main_process_first(): 220 | test_tokenized_datasets = eval_datasets.map( 221 | tokenize_function_test, 222 | batched=True, 223 | num_proc=args.preprocessing_num_workers, 224 | remove_columns=eval_column_names, 225 | load_from_cache_file=not args.overwrite_cache, 226 | desc="Running tokenizer on dataset", 227 | ) 228 | 229 | def preprocess_function(examples): 230 | examples["labels"] = examples["input_ids"].copy() 231 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 232 | return examples 233 | 234 | with accelerator.main_process_first(): 235 | lm_datasets = tokenized_datasets.map( 236 | preprocess_function, 237 | batched=True, 238 | num_proc=args.preprocessing_num_workers, 239 | load_from_cache_file=not 
args.overwrite_cache, 240 | desc=f"Not grouping text.", 241 | ) 242 | 243 | with accelerator.main_process_first(): 244 | eval_lm_datasets = eval_tokenized_datasets.map( 245 | preprocess_function, 246 | batched=True, 247 | num_proc=args.preprocessing_num_workers, 248 | load_from_cache_file=not args.overwrite_cache, 249 | desc=f"Not grouping text.", 250 | ) 251 | 252 | with accelerator.main_process_first(): 253 | test_lm_datasets = test_tokenized_datasets.map( 254 | preprocess_function, 255 | batched=True, 256 | num_proc=args.preprocessing_num_workers, 257 | load_from_cache_file=not args.overwrite_cache, 258 | desc=f"Not grouping text.", 259 | ) 260 | 261 | train_dataset = lm_datasets["train"] 262 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 263 | 264 | eval_dataset = eval_lm_datasets["train"] 265 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 266 | 267 | test_dataset = test_lm_datasets['train'] 268 | test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 269 | 270 | 271 | for index in random.sample(range(len(train_dataset)), 1): 272 | index = 0 273 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 274 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], skip_special_tokens=True)}.") 275 | for index in random.sample(range(len(eval_dataset)), 1): 276 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 277 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 278 | 279 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 280 | 281 | # Prepare everything with our `accelerator`. 282 | model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, test_dataloader) 283 | 284 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 285 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 286 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 287 | 288 | # Train! 289 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 290 | 291 | logger.info("***** Running training *****") 292 | logger.info(f" Num examples = {len(train_dataset)}") 293 | logger.info(f" Num Epochs = {args.num_train_epochs}") 294 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 295 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 296 | logger.info(f" Total optimization steps = {args.max_train_steps}") 297 | 298 | # Only show the progress bar once on each machine. 
299 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 300 | completed_steps = 0 301 | starting_epoch = 0 302 | 303 | # update the progress_bar if load from checkpoint 304 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 305 | completed_steps = starting_epoch * num_update_steps_per_epoch 306 | 307 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 308 | 309 | train_losses = [] 310 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 311 | test_losses_all = torch.zeros(total_num_task_steps, len(test_dataloader)) 312 | 313 | # Initial Eval 314 | model.eval() 315 | with torch.no_grad(): 316 | for eval_step, batch in enumerate(eval_dataloader): 317 | outputs = model(**batch) 318 | loss = outputs.loss 319 | eval_losses_all[0, eval_step] = loss.detach() 320 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 321 | 322 | for test_step, batch in enumerate(test_dataloader): 323 | outputs = model(**batch) 324 | loss = outputs.loss 325 | test_losses_all[0, test_step] = loss.detach() 326 | logger.info(f"Mean test loss: {torch.mean(test_losses_all[0, :])}") 327 | 328 | for epoch in range(starting_epoch, args.num_train_epochs): 329 | 330 | for step, batch in enumerate(train_dataloader): 331 | 332 | # if step > 0: 333 | # del optimizer 334 | # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 335 | # accelerator._optimizers = [] 336 | # optimizer = accelerator.prepare(optimizer) 337 | 338 | global_train_step = epoch * args.num_data_samples + step + 1 339 | 340 | model.train() 341 | 342 | for grad_step in range(args.num_grad_steps): 343 | 344 | assert model.training 345 | outputs = model(**batch) 346 | loss = outputs.loss 347 | # keep track of the loss at each epoch 348 | train_losses.append(loss.detach().unsqueeze(0)) 349 | optimizer.zero_grad() 350 | accelerator.backward(loss) 351 | optimizer.step() 352 | 353 | # Checks if the accelerator has performed an optimization step behind the scenes 354 | if accelerator.sync_gradients: 355 | progress_bar.update(1) 356 | completed_steps += 1 357 | 358 | if args.eval_every_step: 359 | model.eval() 360 | with torch.no_grad(): 361 | for eval_step, eval_batch in enumerate(eval_dataloader): 362 | eval_outputs = model(**eval_batch) 363 | eval_loss = eval_outputs.loss 364 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 365 | 366 | for test_step, test_batch in enumerate(test_dataloader): 367 | test_outputs = model(**test_batch) 368 | test_loss = test_outputs.loss 369 | test_losses_all[global_train_step, test_step] = test_loss.detach() 370 | 371 | # Logging 372 | output_dir = f"epoch_{epoch}" 373 | if args.output_dir is not None: 374 | output_dir = os.path.join(args.output_dir, output_dir) 375 | 376 | os.makedirs(output_dir, exist_ok=True) 377 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 378 | # accelerator.save_state(output_dir) 379 | 380 | # save train_losses 381 | train_losses_ckpt = torch.cat(train_losses) 382 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 383 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 384 | 385 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 386 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 387 | 388 | 389 | if args.output_dir is not None: 390 | output_dir = os.path.join(args.output_dir, f'final') 391 | # save model and tokenizer 392 | accelerator.wait_for_everyone() 393 | 
# unwrapped_model = accelerator.unwrap_model(model) 394 | # unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 395 | if accelerator.is_main_process: 396 | tokenizer.save_pretrained(output_dir) 397 | 398 | # save train_losses 399 | train_losses_ckpt = torch.cat(train_losses) 400 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 401 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 402 | 403 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 404 | test_losses_all_ckpt = test_losses_all.cpu().numpy() 405 | 406 | # save results 407 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 408 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 409 | 410 | 411 | if __name__ == "__main__": 412 | main() 413 | -------------------------------------------------------------------------------- /llm-experiments/training/train_interleave.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 29 | logger = get_logger(__name__) 30 | 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def parse_args(): 36 | 37 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 38 | parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).") 39 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 40 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 41 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 42 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 43 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 44 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 45 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 46 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 47 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 48 | 
parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 49 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=True) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 63 | parser.set_defaults(use_validation=False) 64 | 65 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 66 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 67 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every model and optimizer save.") 68 | 69 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 70 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 71 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of evaluation data samples.") 72 | 73 | args = parser.parse_args() 74 | 75 | return args 76 | 77 | 78 | def main(): 79 | args = parse_args() 80 | print(args) 81 | 82 | # eval_sample = [0] 83 | eval_sample = range(args.num_data_samples) 84 | 85 | accelerator = Accelerator() 86 | 87 | logging.basicConfig( 88 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 89 | datefmt="%m/%d/%Y %H:%M:%S", 90 | level=logging.INFO, 91 | ) 92 | logger.info(accelerator.state, main_process_only=False) 93 | if accelerator.is_local_main_process: 94 | transformers.utils.logging.set_verbosity_info() 95 | else: 96 | transformers.utils.logging.set_verbosity_error() 97 | 98 | if args.seed is not None: 99 | set_seed(args.seed) 100 | 101 | if accelerator.is_main_process: 102 | if args.output_dir is not None: 103 | os.makedirs(args.output_dir, exist_ok=True) 104 | accelerator.wait_for_everyone() 105 | 106 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 107 | 108 | if 'test' in raw_datasets.keys(): 109 | raw_datasets.pop('test') 110 | print("Length of Training Set", raw_datasets['train']) 111 | 112 | subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 113 | # subset_task_index = range(args.num_data_samples) 114 | 115 |
raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 116 | if args.use_validation: 117 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 118 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 119 | elif 'validation' in raw_datasets.keys(): 120 | raw_datasets.pop('validation') 121 | 122 | eval_datasets = deepcopy(raw_datasets) 123 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 124 | 125 | # Load pretrained model and tokenizer 126 | if args.config_name: 127 | config = AutoConfig.from_pretrained(args.config_name) 128 | elif args.model_name_or_path: 129 | config = AutoConfig.from_pretrained(args.model_name_or_path) 130 | else: 131 | config = CONFIG_MAPPING[args.model_type]() 132 | logger.warning("You are instantiating a new config instance from scratch.") 133 | 134 | if args.tokenizer_name: 135 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 136 | elif args.model_name_or_path: 137 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 138 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 139 | tokenizer.pad_token = tokenizer.eos_token 140 | else: 141 | raise ValueError() 142 | 143 | if args.model_name_or_path and args.use_pretrained_weights: 144 | if 'pythia' in args.model_name_or_path: 145 | model_author, model_name = args.model_name_or_path.split('/') 146 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 147 | else: 148 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 149 | else: 150 | logger.info("Training new model from scratch") 151 | if 'pythia' in args.model_name_or_path: 152 | model = GPTNeoXForCausalLM(config) 153 | else: 154 | model = AutoModelForCausalLM.from_config(config) 155 | 156 | model.resize_token_embeddings(len(tokenizer)) 157 | 158 | if not args.use_pretrained_weights: 159 | with torch.no_grad(): 160 | for name, param in model.named_parameters(): 161 | if 'norm' not in name and 'bias' not in name: 162 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 163 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 164 | param *= math.sqrt(2 / (5 * 2048)) / 0.02 165 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 166 | param *= 2 / 16 / math.sqrt(2048) / 0.02 167 | 168 | for name, param in model.named_parameters(): 169 | if 'norm' not in name and 'bias' not in name: 170 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 171 | 172 | column_names = raw_datasets["train"].column_names 173 | eval_column_names = eval_datasets["train"].column_names 174 | 175 | test_text_column_name = 'highlights' 176 | text_column_name = eval_text_column_name = "article" 177 | 178 | if args.block_size is None: 179 | block_size = tokenizer.model_max_length 180 | else: 181 | block_size = args.block_size 182 | if args.block_size > tokenizer.model_max_length: 183 | logger.warning( 184 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model " 185 | f"({tokenizer.model_max_length}). 
Using block_size={tokenizer.model_max_length}." 186 | ) 187 | block_size = tokenizer.model_max_length 188 | print('Block size:', block_size) 189 | 190 | def tokenize_function(examples): 191 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 192 | 193 | def tokenize_function_eval(examples): 194 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 195 | 196 | def tokenize_function_test(examples): 197 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 198 | 199 | with accelerator.main_process_first(): 200 | tokenized_datasets = raw_datasets.map( 201 | tokenize_function, 202 | batched=True, 203 | num_proc=args.preprocessing_num_workers, 204 | remove_columns=column_names, 205 | load_from_cache_file=not args.overwrite_cache, 206 | desc="Running tokenizer on dataset", 207 | ) 208 | 209 | with accelerator.main_process_first(): 210 | eval_tokenized_datasets = eval_datasets.map( 211 | tokenize_function_eval, 212 | batched=True, 213 | num_proc=args.preprocessing_num_workers, 214 | remove_columns=eval_column_names, 215 | load_from_cache_file=not args.overwrite_cache, 216 | desc="Running tokenizer on dataset", 217 | ) 218 | 219 | with accelerator.main_process_first(): 220 | test_tokenized_datasets = eval_datasets.map( 221 | tokenize_function_test, 222 | batched=True, 223 | num_proc=args.preprocessing_num_workers, 224 | remove_columns=eval_column_names, 225 | load_from_cache_file=not args.overwrite_cache, 226 | desc="Running tokenizer on dataset", 227 | ) 228 | 229 | def preprocess_function(examples): 230 | examples["labels"] = examples["input_ids"].copy() 231 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 232 | return examples 233 | 234 | with accelerator.main_process_first(): 235 | lm_datasets = tokenized_datasets.map( 236 | preprocess_function, 237 | batched=True, 238 | num_proc=args.preprocessing_num_workers, 239 | load_from_cache_file=not args.overwrite_cache, 240 | desc=f"Not grouping text.", 241 | ) 242 | 243 | with accelerator.main_process_first(): 244 | eval_lm_datasets = eval_tokenized_datasets.map( 245 | preprocess_function, 246 | batched=True, 247 | num_proc=args.preprocessing_num_workers, 248 | load_from_cache_file=not args.overwrite_cache, 249 | desc=f"Not grouping text.", 250 | ) 251 | 252 | with accelerator.main_process_first(): 253 | test_lm_datasets = test_tokenized_datasets.map( 254 | preprocess_function, 255 | batched=True, 256 | num_proc=args.preprocessing_num_workers, 257 | load_from_cache_file=not args.overwrite_cache, 258 | desc=f"Not grouping text.", 259 | ) 260 | 261 | train_dataset = lm_datasets["train"] 262 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 263 | 264 | eval_dataset = eval_lm_datasets["train"] 265 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 266 | 267 | test_dataset = test_lm_datasets['train'] 268 | test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 269 | 270 | 271 | for index in random.sample(range(len(train_dataset)), 1): 272 | index = 0 273 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 274 | logger.info(f"Sample {index} of the training set (decoded): 
{tokenizer.decode(train_dataset[index]['input_ids'], skip_special_tokens=True)}.") 275 | for index in random.sample(range(len(eval_dataset)), 1): 276 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 277 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 278 | 279 | no_decay = ["bias", "layer_norm.weight"] 280 | optimizer_grouped_parameters = [ 281 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 282 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 283 | ] 284 | 285 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 286 | 287 | # Prepare everything with our `accelerator`. 288 | model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, test_dataloader) 289 | 290 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 291 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 292 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 293 | 294 | # Train! 295 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 296 | 297 | logger.info("***** Running training *****") 298 | logger.info(f" Num examples = {len(train_dataset)}") 299 | logger.info(f" Num Epochs = {args.num_train_epochs}") 300 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 301 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 302 | logger.info(f" Total optimization steps = {args.max_train_steps}") 303 | 304 | # Only show the progress bar once on each machine. 
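# Loss bookkeeping for the run below (assuming batch size 1, as in the provided
# run scripts, so one dataloader batch corresponds to one task): train_losses
# collects one scalar per gradient update, while eval_losses_all and
# test_losses_all are (num_train_epochs * num_data_samples + 1, num_tasks)
# matrices. Row 0 stores the initial evaluation on the articles ('article') and
# summaries ('highlights'); row k > 0 stores the per-task losses recorded after
# the k-th task visit.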
305 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 306 | completed_steps = 0 307 | starting_epoch = 0 308 | 309 | # update the progress_bar if load from checkpoint 310 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 311 | completed_steps = starting_epoch * num_update_steps_per_epoch 312 | 313 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 314 | 315 | train_losses = [] 316 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 317 | test_losses_all = torch.zeros(total_num_task_steps, len(test_dataloader)) 318 | 319 | # Initial Eval 320 | model.eval() 321 | with torch.no_grad(): 322 | for eval_step, batch in enumerate(eval_dataloader): 323 | outputs = model(**batch) 324 | loss = outputs.loss 325 | eval_losses_all[0, eval_step] = loss.detach() 326 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 327 | 328 | for test_step, batch in enumerate(test_dataloader): 329 | outputs = model(**batch) 330 | loss = outputs.loss 331 | test_losses_all[0, test_step] = loss.detach() 332 | logger.info(f"Mean test loss: {torch.mean(test_losses_all[0, :])}") 333 | 334 | for epoch in range(starting_epoch, args.num_train_epochs): 335 | 336 | for step, batch in enumerate(train_dataloader): 337 | 338 | global_train_step = epoch * args.num_data_samples + step + 1 339 | 340 | model.train() 341 | 342 | for grad_step in range(args.num_grad_steps): 343 | 344 | assert model.training 345 | outputs = model(**batch) 346 | loss = outputs.loss 347 | # keep track of the loss at each epoch 348 | train_losses.append(loss.detach().unsqueeze(0)) 349 | optimizer.zero_grad() 350 | accelerator.backward(loss) 351 | optimizer.step() 352 | 353 | # Checks if the accelerator has performed an optimization step behind the scenes 354 | if accelerator.sync_gradients: 355 | progress_bar.update(1) 356 | completed_steps += 1 357 | 358 | if args.eval_every_step: 359 | model.eval() 360 | with torch.no_grad(): 361 | for eval_step, eval_batch in enumerate(eval_dataloader): 362 | eval_outputs = model(**eval_batch) 363 | eval_loss = eval_outputs.loss 364 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 365 | 366 | for test_step, test_batch in enumerate(test_dataloader): 367 | test_outputs = model(**test_batch) 368 | test_loss = test_outputs.loss 369 | test_losses_all[global_train_step, test_step] = test_loss.detach() 370 | 371 | # Logging 372 | output_dir = f"epoch_{epoch}" 373 | if args.output_dir is not None: 374 | output_dir = os.path.join(args.output_dir, output_dir) 375 | 376 | os.makedirs(output_dir, exist_ok=True) 377 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 378 | # accelerator.save_state(output_dir) 379 | 380 | # save train_losses 381 | train_losses_ckpt = torch.cat(train_losses) 382 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 383 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 384 | 385 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 386 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 387 | 388 | 389 | if args.output_dir is not None: 390 | output_dir = os.path.join(args.output_dir, f'final') 391 | # save model and tokenizer 392 | accelerator.wait_for_everyone() 393 | # unwrapped_model = accelerator.unwrap_model(model) 394 | # unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 395 | if 
accelerator.is_main_process: 396 | tokenizer.save_pretrained(output_dir) 397 | 398 | # save train_losses 399 | train_losses_ckpt = torch.cat(train_losses) 400 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 401 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 402 | 403 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 404 | test_losses_all_ckpt = test_losses_all.cpu().numpy() 405 | 406 | # save results 407 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 408 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 409 | 410 | 411 | if __name__ == "__main__": 412 | main() 413 | -------------------------------------------------------------------------------- /llm-experiments/training/train_interleave_randommask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 29 | logger = get_logger(__name__) 30 | 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def parse_args(): 36 | 37 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 38 | parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).") 39 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 40 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 41 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 42 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 43 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 44 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 45 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 46 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 47 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 48 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 49 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final 
model.") 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=True) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 63 | parser.set_defaults(use_validation=False) 64 | 65 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 66 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 67 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every moodel and optimizer save.") 68 | 69 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 70 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 71 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of tasks to interleave.") 72 | parser.add_argument("--random-prob", type=float, default=0.0, help="Range of random window sampling") 73 | 74 | args = parser.parse_args() 75 | 76 | return args 77 | 78 | 79 | def main(): 80 | args = parse_args() 81 | print(args) 82 | 83 | # eval_sample = [0] 84 | eval_sample = range(args.num_data_samples) 85 | 86 | accelerator = Accelerator() 87 | 88 | logging.basicConfig( 89 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 90 | datefmt="%m/%d/%Y %H:%M:%S", 91 | level=logging.INFO, 92 | ) 93 | logger.info(accelerator.state, main_process_only=False) 94 | if accelerator.is_local_main_process: 95 | transformers.utils.logging.set_verbosity_info() 96 | else: 97 | transformers.utils.logging.set_verbosity_error() 98 | 99 | if args.seed is not None: 100 | set_seed(args.seed) 101 | 102 | if accelerator.is_main_process: 103 | if args.output_dir is not None: 104 | os.makedirs(args.output_dir, exist_ok=True) 105 | accelerator.wait_for_everyone() 106 | 107 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 108 | 109 | if 'test' in raw_datasets.keys(): 110 | raw_datasets.pop('test') 111 | print("Length of Training Set", raw_datasets['train']) 112 | 113 | subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 114 | # subset_task_index = range(args.num_data_samples) 115 | 116 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 117 | if args.use_validation: 
118 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 119 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 120 | elif 'validation' in raw_datasets.keys(): 121 | raw_datasets.pop('validation') 122 | 123 | eval_datasets = deepcopy(raw_datasets) 124 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 125 | 126 | # Load pretrained model and tokenizer 127 | if args.config_name: 128 | config = AutoConfig.from_pretrained(args.config_name) 129 | elif args.model_name_or_path: 130 | config = AutoConfig.from_pretrained(args.model_name_or_path) 131 | else: 132 | config = CONFIG_MAPPING[args.model_type]() 133 | logger.warning("You are instantiating a new config instance from scratch.") 134 | 135 | if args.tokenizer_name: 136 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 137 | elif args.model_name_or_path: 138 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 139 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 140 | tokenizer.pad_token = tokenizer.eos_token 141 | else: 142 | raise ValueError() 143 | 144 | if args.model_name_or_path and args.use_pretrained_weights: 145 | if 'pythia' in args.model_name_or_path: 146 | model_author, model_name = args.model_name_or_path.split('/') 147 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 148 | else: 149 | model = AutoModelForCausalLM.from_predtrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 150 | else: 151 | logger.info("Training new model from scratch") 152 | if 'pythia' in args.model_name_or_path: 153 | model = GPTNeoXForCausalLM(config) 154 | else: 155 | model = AutoModelForCausalLM.from_config(config) 156 | 157 | model.resize_token_embeddings(len(tokenizer)) 158 | 159 | if not args.use_pretrained_weights: 160 | with torch.no_grad(): 161 | for name, param in model.named_parameters(): 162 | if 'norm' not in name and 'bias' not in name: 163 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 164 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 165 | param *= math.sqrt(2 / (5 * 2048)) / 0.02 166 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 167 | param *= 2 / 16 / math.sqrt(2048) / 0.02 168 | 169 | for name, param in model.named_parameters(): 170 | if 'norm' not in name and 'bias' not in name: 171 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 172 | 173 | column_names = raw_datasets["train"].column_names 174 | eval_column_names = eval_datasets["train"].column_names 175 | 176 | test_text_column_name = 'highlights' 177 | text_column_name = eval_text_column_name = "article" 178 | 179 | if args.block_size is None: 180 | block_size = tokenizer.model_max_length 181 | else: 182 | block_size = args.block_size 183 | if args.block_size > tokenizer.model_max_length: 184 | logger.warning( 185 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 186 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 
187 | ) 188 | block_size = tokenizer.model_max_length 189 | print('Block size:', block_size) 190 | 191 | def tokenize_function(examples): 192 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 193 | 194 | def tokenize_function_eval(examples): 195 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 196 | 197 | def tokenize_function_test(examples): 198 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 199 | 200 | with accelerator.main_process_first(): 201 | tokenized_datasets = raw_datasets.map( 202 | tokenize_function, 203 | batched=True, 204 | num_proc=args.preprocessing_num_workers, 205 | remove_columns=column_names, 206 | load_from_cache_file=not args.overwrite_cache, 207 | desc="Running tokenizer on dataset", 208 | ) 209 | 210 | with accelerator.main_process_first(): 211 | eval_tokenized_datasets = eval_datasets.map( 212 | tokenize_function_eval, 213 | batched=True, 214 | num_proc=args.preprocessing_num_workers, 215 | remove_columns=eval_column_names, 216 | load_from_cache_file=not args.overwrite_cache, 217 | desc="Running tokenizer on dataset", 218 | ) 219 | 220 | with accelerator.main_process_first(): 221 | test_tokenized_datasets = eval_datasets.map( 222 | tokenize_function_test, 223 | batched=True, 224 | num_proc=args.preprocessing_num_workers, 225 | remove_columns=eval_column_names, 226 | load_from_cache_file=not args.overwrite_cache, 227 | desc="Running tokenizer on dataset", 228 | ) 229 | 230 | def preprocess_function(examples): 231 | examples["labels"] = examples["input_ids"].copy() 232 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 233 | return examples 234 | 235 | with accelerator.main_process_first(): 236 | lm_datasets = tokenized_datasets.map( 237 | preprocess_function, 238 | batched=True, 239 | num_proc=args.preprocessing_num_workers, 240 | load_from_cache_file=not args.overwrite_cache, 241 | desc=f"Not grouping text.", 242 | ) 243 | 244 | with accelerator.main_process_first(): 245 | eval_lm_datasets = eval_tokenized_datasets.map( 246 | preprocess_function, 247 | batched=True, 248 | num_proc=args.preprocessing_num_workers, 249 | load_from_cache_file=not args.overwrite_cache, 250 | desc=f"Not grouping text.", 251 | ) 252 | 253 | with accelerator.main_process_first(): 254 | test_lm_datasets = test_tokenized_datasets.map( 255 | preprocess_function, 256 | batched=True, 257 | num_proc=args.preprocessing_num_workers, 258 | load_from_cache_file=not args.overwrite_cache, 259 | desc=f"Not grouping text.", 260 | ) 261 | 262 | train_dataset = lm_datasets["train"] 263 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 264 | 265 | eval_dataset = eval_lm_datasets["train"] 266 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 267 | 268 | test_dataset = test_lm_datasets['train'] 269 | test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 270 | 271 | 272 | for index in random.sample(range(len(train_dataset)), 1): 273 | index = 0 274 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 275 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], 
skip_special_tokens=True)}.") 276 | for index in random.sample(range(len(eval_dataset)), 1): 277 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 278 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 279 | 280 | no_decay = ["bias", "layer_norm.weight"] 281 | optimizer_grouped_parameters = [ 282 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 283 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 284 | ] 285 | 286 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 287 | 288 | # Prepare everything with our `accelerator`. 289 | model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, test_dataloader) 290 | 291 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 292 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 293 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 294 | 295 | # Train! 296 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 297 | 298 | logger.info("***** Running training *****") 299 | logger.info(f" Num examples = {len(train_dataset)}") 300 | logger.info(f" Num Epochs = {args.num_train_epochs}") 301 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 302 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 303 | logger.info(f" Total optimization steps = {args.max_train_steps}") 304 | 305 | # Only show the progress bar once on each machine. 
306 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 307 | completed_steps = 0 308 | starting_epoch = 0 309 | 310 | # update the progress_bar if load from checkpoint 311 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 312 | completed_steps = starting_epoch * num_update_steps_per_epoch 313 | 314 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 315 | 316 | train_losses = [] 317 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 318 | test_losses_all = torch.zeros(total_num_task_steps, len(test_dataloader)) 319 | 320 | # Initial Eval 321 | model.eval() 322 | with torch.no_grad(): 323 | for eval_step, batch in enumerate(eval_dataloader): 324 | outputs = model(**batch) 325 | loss = outputs.loss 326 | eval_losses_all[0, eval_step] = loss.detach() 327 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 328 | 329 | for test_step, batch in enumerate(test_dataloader): 330 | outputs = model(**batch) 331 | loss = outputs.loss 332 | test_losses_all[0, test_step] = loss.detach() 333 | logger.info(f"Mean test loss: {torch.mean(test_losses_all[0, :])}") 334 | 335 | for epoch in range(starting_epoch, args.num_train_epochs): 336 | 337 | for step, batch in enumerate(train_dataloader): 338 | 339 | global_train_step = epoch * args.num_data_samples + step + 1 340 | 341 | model.train() 342 | 343 | orig_batch = batch.copy() 344 | 345 | for grad_step in range(args.num_grad_steps): 346 | 347 | assert model.training 348 | 349 | mask = (torch.rand(orig_batch['input_ids'].shape) > args.random_prob).to(orig_batch['input_ids'].device) 350 | batch['input_ids'] = orig_batch['input_ids'] * mask 351 | 352 | outputs = model(**batch) 353 | loss = outputs.loss 354 | # keep track of the loss at each epoch 355 | train_losses.append(loss.detach().unsqueeze(0)) 356 | optimizer.zero_grad() 357 | accelerator.backward(loss) 358 | optimizer.step() 359 | 360 | # Checks if the accelerator has performed an optimization step behind the scenes 361 | if accelerator.sync_gradients: 362 | progress_bar.update(1) 363 | completed_steps += 1 364 | 365 | if args.eval_every_step: 366 | model.eval() 367 | with torch.no_grad(): 368 | for eval_step, eval_batch in enumerate(eval_dataloader): 369 | eval_outputs = model(**eval_batch) 370 | eval_loss = eval_outputs.loss 371 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 372 | 373 | for test_step, test_batch in enumerate(test_dataloader): 374 | test_outputs = model(**test_batch) 375 | test_loss = test_outputs.loss 376 | test_losses_all[global_train_step, test_step] = test_loss.detach() 377 | 378 | # Logging 379 | output_dir = f"epoch_{epoch}" 380 | if args.output_dir is not None: 381 | output_dir = os.path.join(args.output_dir, output_dir) 382 | 383 | os.makedirs(output_dir, exist_ok=True) 384 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 385 | # accelerator.save_state(output_dir) 386 | 387 | # save train_losses 388 | train_losses_ckpt = torch.cat(train_losses) 389 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 390 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 391 | 392 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 393 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 394 | 395 | 396 | if args.output_dir is not None: 397 | output_dir = os.path.join(args.output_dir, f'final') 398 | # save model and tokenizer 399 | accelerator.wait_for_everyone() 400 | # unwrapped_model = 
accelerator.unwrap_model(model) 401 | # unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 402 | if accelerator.is_main_process: 403 | tokenizer.save_pretrained(output_dir) 404 | 405 | # save train_losses 406 | train_losses_ckpt = torch.cat(train_losses) 407 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 408 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 409 | 410 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 411 | test_losses_all_ckpt = test_losses_all.cpu().numpy() 412 | 413 | # save results 414 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 415 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 416 | 417 | 418 | if __name__ == "__main__": 419 | main() 420 | -------------------------------------------------------------------------------- /llm-experiments/training/train_interleave_frozen_blocks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 29 | logger = get_logger(__name__) 30 | 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def parse_args(): 36 | 37 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 38 | parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).") 39 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 40 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 41 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 42 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 43 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 44 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 45 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 46 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 47 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 48 | parser.add_argument("--num_train_epochs", 
type=int, default=1, help="Total number of training epochs to perform.") 49 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=True) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 63 | parser.set_defaults(use_validation=False) 64 | 65 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 66 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 67 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every moodel and optimizer save.") 68 | 69 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 70 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 71 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of tasks to interleave.") 72 | 73 | parser.add_argument("--num-frozen-blocks", type=int, default=0, help="Number of transformer blocks to freeze.") 74 | parser.add_argument("--num-frozen-blocks-rev", type=int, default=0, help="Number of transformer blocks to freeze (from the end)") 75 | 76 | args = parser.parse_args() 77 | 78 | return args 79 | 80 | 81 | def main(): 82 | args = parse_args() 83 | print(args) 84 | 85 | # eval_sample = [0] 86 | eval_sample = range(args.num_data_samples) 87 | 88 | accelerator = Accelerator() 89 | 90 | logging.basicConfig( 91 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 92 | datefmt="%m/%d/%Y %H:%M:%S", 93 | level=logging.INFO, 94 | ) 95 | logger.info(accelerator.state, main_process_only=False) 96 | if accelerator.is_local_main_process: 97 | transformers.utils.logging.set_verbosity_info() 98 | else: 99 | transformers.utils.logging.set_verbosity_error() 100 | 101 | if args.seed is not None: 102 | set_seed(args.seed) 103 | 104 | if accelerator.is_main_process: 105 | if args.output_dir is not None: 106 | os.makedirs(args.output_dir, exist_ok=True) 107 | accelerator.wait_for_everyone() 108 | 109 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 110 | 111 | if 'test' in raw_datasets.keys(): 112 | raw_datasets.pop('test') 113 | 
print("Length of Training Set", raw_datasets['train']) 114 | 115 | subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 116 | # subset_task_index = range(args.num_data_samples) 117 | 118 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 119 | if args.use_validation: 120 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 121 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 122 | elif 'validation' in raw_datasets.keys(): 123 | raw_datasets.pop('validation') 124 | 125 | eval_datasets = deepcopy(raw_datasets) 126 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 127 | 128 | # Load pretrained model and tokenizer 129 | if args.config_name: 130 | config = AutoConfig.from_pretrained(args.config_name) 131 | elif args.model_name_or_path: 132 | config = AutoConfig.from_pretrained(args.model_name_or_path) 133 | else: 134 | config = CONFIG_MAPPING[args.model_type]() 135 | logger.warning("You are instantiating a new config instance from scratch.") 136 | 137 | if args.tokenizer_name: 138 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 139 | elif args.model_name_or_path: 140 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 141 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 142 | tokenizer.pad_token = tokenizer.eos_token 143 | else: 144 | raise ValueError() 145 | 146 | if args.model_name_or_path and args.use_pretrained_weights: 147 | if 'pythia' in args.model_name_or_path: 148 | model_author, model_name = args.model_name_or_path.split('/') 149 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 150 | else: 151 | model = AutoModelForCausalLM.from_predtrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 152 | else: 153 | logger.info("Training new model from scratch") 154 | if 'pythia' in args.model_name_or_path: 155 | model = GPTNeoXForCausalLM(config) 156 | else: 157 | model = AutoModelForCausalLM.from_config(config) 158 | 159 | model.resize_token_embeddings(len(tokenizer)) 160 | 161 | if not args.use_pretrained_weights: 162 | with torch.no_grad(): 163 | for name, param in model.named_parameters(): 164 | if 'norm' not in name and 'bias' not in name: 165 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 166 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 167 | param *= math.sqrt(2 / (5 * 2048)) / 0.02 168 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 169 | param *= 2 / 16 / math.sqrt(2048) / 0.02 170 | 171 | for name, param in model.named_parameters(): 172 | if 'norm' not in name and 'bias' not in name: 173 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 174 | 175 | column_names = raw_datasets["train"].column_names 176 | eval_column_names = eval_datasets["train"].column_names 177 | 178 | test_text_column_name = 'highlights' 179 | text_column_name = eval_text_column_name = "article" 180 | 181 | if args.block_size is None: 182 | block_size = tokenizer.model_max_length 183 | else: 184 | block_size = args.block_size 185 | if 
args.block_size > tokenizer.model_max_length: 186 | logger.warning( 187 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 188 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 189 | ) 190 | block_size = tokenizer.model_max_length 191 | print('Block size:', block_size) 192 | 193 | def tokenize_function(examples): 194 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 195 | 196 | def tokenize_function_eval(examples): 197 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 198 | 199 | def tokenize_function_test(examples): 200 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 201 | 202 | with accelerator.main_process_first(): 203 | tokenized_datasets = raw_datasets.map( 204 | tokenize_function, 205 | batched=True, 206 | num_proc=args.preprocessing_num_workers, 207 | remove_columns=column_names, 208 | load_from_cache_file=not args.overwrite_cache, 209 | desc="Running tokenizer on dataset", 210 | ) 211 | 212 | with accelerator.main_process_first(): 213 | eval_tokenized_datasets = eval_datasets.map( 214 | tokenize_function_eval, 215 | batched=True, 216 | num_proc=args.preprocessing_num_workers, 217 | remove_columns=eval_column_names, 218 | load_from_cache_file=not args.overwrite_cache, 219 | desc="Running tokenizer on dataset", 220 | ) 221 | 222 | with accelerator.main_process_first(): 223 | test_tokenized_datasets = eval_datasets.map( 224 | tokenize_function_test, 225 | batched=True, 226 | num_proc=args.preprocessing_num_workers, 227 | remove_columns=eval_column_names, 228 | load_from_cache_file=not args.overwrite_cache, 229 | desc="Running tokenizer on dataset", 230 | ) 231 | 232 | def preprocess_function(examples): 233 | examples["labels"] = examples["input_ids"].copy() 234 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 235 | return examples 236 | 237 | with accelerator.main_process_first(): 238 | lm_datasets = tokenized_datasets.map( 239 | preprocess_function, 240 | batched=True, 241 | num_proc=args.preprocessing_num_workers, 242 | load_from_cache_file=not args.overwrite_cache, 243 | desc=f"Not grouping text.", 244 | ) 245 | 246 | with accelerator.main_process_first(): 247 | eval_lm_datasets = eval_tokenized_datasets.map( 248 | preprocess_function, 249 | batched=True, 250 | num_proc=args.preprocessing_num_workers, 251 | load_from_cache_file=not args.overwrite_cache, 252 | desc=f"Not grouping text.", 253 | ) 254 | 255 | with accelerator.main_process_first(): 256 | test_lm_datasets = test_tokenized_datasets.map( 257 | preprocess_function, 258 | batched=True, 259 | num_proc=args.preprocessing_num_workers, 260 | load_from_cache_file=not args.overwrite_cache, 261 | desc=f"Not grouping text.", 262 | ) 263 | 264 | train_dataset = lm_datasets["train"] 265 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 266 | 267 | eval_dataset = eval_lm_datasets["train"] 268 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 269 | 270 | test_dataset = test_lm_datasets['train'] 271 | test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 272 | 273 | 274 | for index in 
random.sample(range(len(train_dataset)), 1): 275 | index = 0 276 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 277 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], skip_special_tokens=True)}.") 278 | for index in random.sample(range(len(eval_dataset)), 1): 279 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 280 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 281 | 282 | no_decay = ["bias", "layer_norm.weight"] 283 | optimizer_grouped_parameters = [ 284 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 285 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 286 | ] 287 | 288 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 289 | 290 | # Prepare everything with our `accelerator`. 291 | model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, test_dataloader) 292 | 293 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 294 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 295 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 296 | 297 | # Train! 298 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 299 | 300 | logger.info("***** Running training *****") 301 | logger.info(f" Num examples = {len(train_dataset)}") 302 | logger.info(f" Num Epochs = {args.num_train_epochs}") 303 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 304 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 305 | logger.info(f" Total optimization steps = {args.max_train_steps}") 306 | 307 | # Only show the progress bar once on each machine. 
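# Block-freezing variant (see the start of the task loop below): requires_grad is
# switched off by parameter index, with thresholds hard-coded for the pythia-1b /
# GPT-NeoX layout used in the run scripts (12 parameter tensors per transformer
# block, 16 blocks). --num-frozen-blocks freezes the input embedding plus the
# first k blocks, while --num-frozen-blocks-rev freezes the last k blocks and
# leaves the embeddings, final layer norm, and output head trainable.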
308 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 309 | completed_steps = 0 310 | starting_epoch = 0 311 | 312 | # update the progress_bar if load from checkpoint 313 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 314 | completed_steps = starting_epoch * num_update_steps_per_epoch 315 | 316 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 317 | 318 | train_losses = [] 319 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 320 | test_losses_all = torch.zeros(total_num_task_steps, len(test_dataloader)) 321 | 322 | # Initial Eval 323 | model.eval() 324 | with torch.no_grad(): 325 | for eval_step, batch in enumerate(eval_dataloader): 326 | outputs = model(**batch) 327 | loss = outputs.loss 328 | eval_losses_all[0, eval_step] = loss.detach() 329 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 330 | 331 | for test_step, batch in enumerate(test_dataloader): 332 | outputs = model(**batch) 333 | loss = outputs.loss 334 | test_losses_all[0, test_step] = loss.detach() 335 | logger.info(f"Mean test loss: {torch.mean(test_losses_all[0, :])}") 336 | 337 | for epoch in range(starting_epoch, args.num_train_epochs): 338 | 339 | for step, batch in enumerate(train_dataloader): 340 | 341 | global_train_step = epoch * args.num_data_samples + step + 1 342 | 343 | model.train() 344 | 345 | if args.num_frozen_blocks > 0: 346 | for layer_idx, (name, param) in enumerate(model.named_parameters()): 347 | if layer_idx <= 12 * args.num_frozen_blocks: 348 | param.requires_grad = False 349 | 350 | if args.num_frozen_blocks_rev > 0: 351 | for layer_idx, (name, param) in enumerate(model.named_parameters()): 352 | if layer_idx > 12 * (16 - args.num_frozen_blocks_rev) and layer_idx <= 12 * 16: 353 | param.requires_grad = False 354 | 355 | 356 | for grad_step in range(args.num_grad_steps): 357 | 358 | assert model.training 359 | outputs = model(**batch) 360 | loss = outputs.loss 361 | # keep track of the loss at each epoch 362 | train_losses.append(loss.detach().unsqueeze(0)) 363 | optimizer.zero_grad() 364 | accelerator.backward(loss) 365 | optimizer.step() 366 | 367 | # Checks if the accelerator has performed an optimization step behind the scenes 368 | if accelerator.sync_gradients: 369 | progress_bar.update(1) 370 | completed_steps += 1 371 | 372 | if args.eval_every_step: 373 | model.eval() 374 | with torch.no_grad(): 375 | for eval_step, eval_batch in enumerate(eval_dataloader): 376 | eval_outputs = model(**eval_batch) 377 | eval_loss = eval_outputs.loss 378 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 379 | 380 | for test_step, test_batch in enumerate(test_dataloader): 381 | test_outputs = model(**test_batch) 382 | test_loss = test_outputs.loss 383 | test_losses_all[global_train_step, test_step] = test_loss.detach() 384 | 385 | # Logging 386 | output_dir = f"epoch_{epoch}" 387 | if args.output_dir is not None: 388 | output_dir = os.path.join(args.output_dir, output_dir) 389 | 390 | os.makedirs(output_dir, exist_ok=True) 391 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 392 | # accelerator.save_state(output_dir) 393 | 394 | # save train_losses 395 | train_losses_ckpt = torch.cat(train_losses) 396 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 397 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 398 | 399 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 400 | np.savez(save_path, 
train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 401 | 402 | 403 | if args.output_dir is not None: 404 | output_dir = os.path.join(args.output_dir, f'final') 405 | # save model and tokenizer 406 | accelerator.wait_for_everyone() 407 | unwrapped_model = accelerator.unwrap_model(model) 408 | unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 409 | if accelerator.is_main_process: 410 | tokenizer.save_pretrained(output_dir) 411 | 412 | # save train_losses 413 | train_losses_ckpt = torch.cat(train_losses) 414 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 415 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 416 | 417 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 418 | test_losses_all_ckpt = test_losses_all.cpu().numpy() 419 | 420 | # save results 421 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 422 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 423 | 424 | 425 | if __name__ == "__main__": 426 | main() 427 | -------------------------------------------------------------------------------- /llm-experiments/training/train_interleave_randomwindow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_MAPPING, 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | GPTNeoXForCausalLM, 26 | ) 27 | from copy import deepcopy 28 | 29 | logger = get_logger(__name__) 30 | 31 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 32 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 33 | 34 | 35 | def parse_args(): 36 | 37 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 38 | parser.add_argument("--dataset_name", type=str, default=None, required=True, help="The name of the dataset to use (via the datasets library).") 39 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 40 | parser.add_argument("--model_name_or_path", type=str, help="Path to pretrained model or model identifier from huggingface.co/models.", required=False) 41 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 42 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 43 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 44 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 45 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 46 | 
parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 47 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 48 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 49 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 50 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 51 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 52 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 53 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 54 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=True) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 63 | parser.set_defaults(use_validation=False) 64 | 65 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 66 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 67 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every moodel and optimizer save.") 68 | 69 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 70 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 71 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of tasks to interleave.") 72 | parser.add_argument("--random-window", type=int, default=0, help="Range of random window sampling") 73 | 74 | args = parser.parse_args() 75 | 76 | return args 77 | 78 | 79 | def main(): 80 | args = parse_args() 81 | print(args) 82 | 83 | # eval_sample = [0] 84 | eval_sample = range(args.num_data_samples) 85 | 86 | accelerator = Accelerator() 87 | 88 | logging.basicConfig( 89 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 90 | datefmt="%m/%d/%Y %H:%M:%S", 91 | level=logging.INFO, 92 | ) 93 | logger.info(accelerator.state, main_process_only=False) 94 | if accelerator.is_local_main_process: 95 | transformers.utils.logging.set_verbosity_info() 96 | else: 97 | transformers.utils.logging.set_verbosity_error() 98 | 99 | if args.seed is not None: 100 | set_seed(args.seed) 101 | 102 | if accelerator.is_main_process: 103 | if args.output_dir is not None: 104 | os.makedirs(args.output_dir, exist_ok=True) 105 | accelerator.wait_for_everyone() 106 | 107 | raw_datasets = 
load_dataset(args.dataset_name, args.dataset_config_name) 108 | 109 | if 'test' in raw_datasets.keys(): 110 | raw_datasets.pop('test') 111 | print("Length of Training Set", raw_datasets['train']) 112 | 113 | subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 114 | # subset_task_index = range(args.num_data_samples) 115 | 116 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 117 | if args.use_validation: 118 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 119 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 120 | elif 'validation' in raw_datasets.keys(): 121 | raw_datasets.pop('validation') 122 | 123 | eval_datasets = deepcopy(raw_datasets) 124 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 125 | 126 | # Load pretrained model and tokenizer 127 | if args.config_name: 128 | config = AutoConfig.from_pretrained(args.config_name) 129 | elif args.model_name_or_path: 130 | config = AutoConfig.from_pretrained(args.model_name_or_path) 131 | else: 132 | config = CONFIG_MAPPING[args.model_type]() 133 | logger.warning("You are instantiating a new config instance from scratch.") 134 | 135 | if args.tokenizer_name: 136 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 137 | elif args.model_name_or_path: 138 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 139 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 140 | tokenizer.pad_token = tokenizer.eos_token 141 | else: 142 | raise ValueError() 143 | 144 | if args.model_name_or_path and args.use_pretrained_weights: 145 | if 'pythia' in args.model_name_or_path: 146 | model_author, model_name = args.model_name_or_path.split('/') 147 | model = GPTNeoXForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 148 | else: 149 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 150 | else: 151 | logger.info("Training new model from scratch") 152 | if 'pythia' in args.model_name_or_path: 153 | model = GPTNeoXForCausalLM(config) 154 | else: 155 | model = AutoModelForCausalLM.from_config(config) 156 | 157 | model.resize_token_embeddings(len(tokenizer)) 158 | 159 | if not args.use_pretrained_weights: 160 | with torch.no_grad(): 161 | for name, param in model.named_parameters(): 162 | if 'norm' not in name and 'bias' not in name: 163 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 164 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 165 | param *= math.sqrt(2 / (5 * 2048)) / 0.02 166 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 167 | param *= 2 / 16 / math.sqrt(2048) / 0.02 168 | 169 | for name, param in model.named_parameters(): 170 | if 'norm' not in name and 'bias' not in name: 171 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 172 | 173 | column_names = raw_datasets["train"].column_names 174 | eval_column_names = eval_datasets["train"].column_names 175 | 176 | test_text_column_name = 'highlights' 177 | text_column_name = eval_text_column_name = "article" 178 | 179
| if args.block_size is None: 180 | block_size = tokenizer.model_max_length 181 | else: 182 | block_size = args.block_size 183 | if args.block_size > tokenizer.model_max_length: 184 | logger.warning( 185 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 186 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 187 | ) 188 | block_size = tokenizer.model_max_length 189 | print('Block size:', block_size) 190 | 191 | def tokenize_function(examples): 192 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 193 | 194 | def tokenize_function_eval(examples): 195 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 196 | 197 | def tokenize_function_test(examples): 198 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 199 | 200 | with accelerator.main_process_first(): 201 | tokenized_datasets = raw_datasets.map( 202 | tokenize_function, 203 | batched=True, 204 | num_proc=args.preprocessing_num_workers, 205 | remove_columns=column_names, 206 | load_from_cache_file=not args.overwrite_cache, 207 | desc="Running tokenizer on dataset", 208 | ) 209 | 210 | with accelerator.main_process_first(): 211 | eval_tokenized_datasets = eval_datasets.map( 212 | tokenize_function_eval, 213 | batched=True, 214 | num_proc=args.preprocessing_num_workers, 215 | remove_columns=eval_column_names, 216 | load_from_cache_file=not args.overwrite_cache, 217 | desc="Running tokenizer on dataset", 218 | ) 219 | 220 | with accelerator.main_process_first(): 221 | test_tokenized_datasets = eval_datasets.map( 222 | tokenize_function_test, 223 | batched=True, 224 | num_proc=args.preprocessing_num_workers, 225 | remove_columns=eval_column_names, 226 | load_from_cache_file=not args.overwrite_cache, 227 | desc="Running tokenizer on dataset", 228 | ) 229 | 230 | def preprocess_function(examples): 231 | examples["labels"] = examples["input_ids"].copy() 232 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 233 | return examples 234 | 235 | with accelerator.main_process_first(): 236 | lm_datasets = tokenized_datasets.map( 237 | preprocess_function, 238 | batched=True, 239 | num_proc=args.preprocessing_num_workers, 240 | load_from_cache_file=not args.overwrite_cache, 241 | desc=f"Not grouping text.", 242 | ) 243 | 244 | with accelerator.main_process_first(): 245 | eval_lm_datasets = eval_tokenized_datasets.map( 246 | preprocess_function, 247 | batched=True, 248 | num_proc=args.preprocessing_num_workers, 249 | load_from_cache_file=not args.overwrite_cache, 250 | desc=f"Not grouping text.", 251 | ) 252 | 253 | with accelerator.main_process_first(): 254 | test_lm_datasets = test_tokenized_datasets.map( 255 | preprocess_function, 256 | batched=True, 257 | num_proc=args.preprocessing_num_workers, 258 | load_from_cache_file=not args.overwrite_cache, 259 | desc=f"Not grouping text.", 260 | ) 261 | 262 | train_dataset = lm_datasets["train"] 263 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 264 | 265 | eval_dataset = eval_lm_datasets["train"] 266 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 267 | 268 | test_dataset = test_lm_datasets['train'] 269 | test_dataloader = DataLoader(test_dataset, 
shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 270 | 271 | 272 | for index in random.sample(range(len(train_dataset)), 1): 273 | index = 0 274 | # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 275 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], skip_special_tokens=True)}.") 276 | for index in random.sample(range(len(eval_dataset)), 1): 277 | # logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 278 | logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 279 | 280 | no_decay = ["bias", "layer_norm.weight"] 281 | optimizer_grouped_parameters = [ 282 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 283 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 284 | ] 285 | 286 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 287 | 288 | # Prepare everything with our `accelerator`. 289 | model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, test_dataloader) 290 | 291 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 292 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 293 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 294 | 295 | # Train! 296 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 297 | 298 | logger.info("***** Running training *****") 299 | logger.info(f" Num examples = {len(train_dataset)}") 300 | logger.info(f" Num Epochs = {args.num_train_epochs}") 301 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 302 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 303 | logger.info(f" Total optimization steps = {args.max_train_steps}") 304 | 305 | # Only show the progress bar once on each machine. 
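# Interleaved training schedule implemented below: tasks (articles) are visited in a fixed
# order (shuffle=False), and each task receives --num-grad-steps consecutive SGD updates
# before the loop moves on. At every gradient step the central block_size // 2 tokens of the
# padded article are re-sliced with a random offset drawn from [-random_window, random_window],
# so successive updates see slightly shifted windows whenever --random-window > 0. After each
# task, the loss on every eval example ('article' text) and test example ('highlights' summary)
# is recorded in row `global_train_step` of eval_losses_all / test_losses_all.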
306 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 307 | completed_steps = 0 308 | starting_epoch = 0 309 | 310 | # update the progress_bar if load from checkpoint 311 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 312 | completed_steps = starting_epoch * num_update_steps_per_epoch 313 | 314 | total_num_task_steps = args.num_train_epochs * args.num_data_samples + 1 315 | 316 | train_losses = [] 317 | eval_losses_all = torch.zeros(total_num_task_steps, len(eval_dataloader)) 318 | test_losses_all = torch.zeros(total_num_task_steps, len(test_dataloader)) 319 | 320 | # Initial Eval 321 | model.eval() 322 | with torch.no_grad(): 323 | for eval_step, eval_batch in enumerate(eval_dataloader): 324 | 325 | eval_batch['input_ids'] = eval_batch['input_ids'][:, args.block_size // 4 : 3 * args.block_size // 4] 326 | eval_batch['attention_mask'] = eval_batch['attention_mask'][:, args.block_size // 4 : 3 * args.block_size // 4] 327 | eval_batch['labels'] = eval_batch['labels'][:, args.block_size // 4 : 3 * args.block_size // 4] 328 | 329 | outputs = model(**eval_batch) 330 | loss = outputs.loss 331 | eval_losses_all[0, eval_step] = loss.detach() 332 | 333 | logger.info(f"Mean eval loss: {torch.mean(eval_losses_all[0, :])}") 334 | 335 | for test_step, batch in enumerate(test_dataloader): 336 | outputs = model(**batch) 337 | loss = outputs.loss 338 | test_losses_all[0, test_step] = loss.detach() 339 | logger.info(f"Mean test loss: {torch.mean(test_losses_all[0, :])}") 340 | 341 | for epoch in range(starting_epoch, args.num_train_epochs): 342 | 343 | for step, batch in enumerate(train_dataloader): 344 | 345 | global_train_step = epoch * args.num_data_samples + step + 1 346 | 347 | model.train() 348 | 349 | orig_batch = batch.copy() 350 | 351 | for grad_step in range(args.num_grad_steps): 352 | 353 | assert model.training 354 | 355 | start_idx = random.randrange(-args.random_window, args.random_window+1) 356 | batch['input_ids'] = orig_batch['input_ids'][:, args.block_size // 4 + start_idx : 3 * args.block_size // 4 + start_idx] 357 | batch['attention_mask'] = orig_batch['attention_mask'][:, args.block_size // 4 + start_idx : 3 * args.block_size // 4 + start_idx] 358 | batch['labels'] = orig_batch['labels'][:, args.block_size // 4 + start_idx : 3 * args.block_size // 4 + start_idx] 359 | 360 | outputs = model(**batch) 361 | loss = outputs.loss 362 | # keep track of the loss at each epoch 363 | train_losses.append(loss.detach().unsqueeze(0)) 364 | optimizer.zero_grad() 365 | accelerator.backward(loss) 366 | optimizer.step() 367 | 368 | # Checks if the accelerator has performed an optimization step behind the scenes 369 | if accelerator.sync_gradients: 370 | progress_bar.update(1) 371 | completed_steps += 1 372 | 373 | if args.eval_every_step: 374 | model.eval() 375 | with torch.no_grad(): 376 | for eval_step, eval_batch in enumerate(eval_dataloader): 377 | eval_batch['input_ids'] = eval_batch['input_ids'][:, args.block_size // 4 : 3 * args.block_size // 4] 378 | eval_batch['attention_mask'] = eval_batch['attention_mask'][:, args.block_size // 4 : 3 * args.block_size // 4] 379 | eval_batch['labels'] = eval_batch['labels'][:, args.block_size // 4 : 3 * args.block_size // 4] 380 | 381 | eval_outputs = model(**eval_batch) 382 | eval_loss = eval_outputs.loss 383 | eval_losses_all[global_train_step, eval_step] = eval_loss.detach() 384 | 385 | for test_step, test_batch in enumerate(test_dataloader): 386 | test_outputs = model(**test_batch) 387 | 
test_loss = test_outputs.loss 388 | test_losses_all[global_train_step, test_step] = test_loss.detach() 389 | 390 | # Logging 391 | output_dir = f"epoch_{epoch}" 392 | if args.output_dir is not None: 393 | output_dir = os.path.join(args.output_dir, output_dir) 394 | 395 | os.makedirs(output_dir, exist_ok=True) 396 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 397 | # accelerator.save_state(output_dir) 398 | 399 | # save train_losses 400 | train_losses_ckpt = torch.cat(train_losses) 401 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 402 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 403 | 404 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 405 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 406 | 407 | 408 | if args.output_dir is not None: 409 | output_dir = os.path.join(args.output_dir, f'final') 410 | # save model and tokenizer 411 | accelerator.wait_for_everyone() 412 | # unwrapped_model = accelerator.unwrap_model(model) 413 | # unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 414 | if accelerator.is_main_process: 415 | tokenizer.save_pretrained(output_dir) 416 | 417 | # save train_losses 418 | train_losses_ckpt = torch.cat(train_losses) 419 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 420 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 421 | 422 | eval_losses_all_ckpt = eval_losses_all.cpu().numpy() 423 | test_losses_all_ckpt = test_losses_all.cpu().numpy() 424 | 425 | # save results 426 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 427 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 428 | 429 | 430 | if __name__ == "__main__": 431 | main() 432 | -------------------------------------------------------------------------------- /llm-experiments/training/train_interleave_ablation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from torch.utils.data import DataLoader 11 | from tqdm.auto import tqdm 12 | from datasets import load_dataset 13 | import transformers 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import set_seed 17 | from transformers import ( 18 | CONFIG_MAPPING, 19 | MODEL_MAPPING, 20 | AutoConfig, 21 | AutoModelForCausalLM, 22 | AutoTokenizer, 23 | default_data_collator, 24 | ) 25 | from copy import deepcopy 26 | 27 | logger = get_logger(__name__) 28 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 29 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 30 | 31 | def parse_args(): 32 | 33 | parser = argparse.ArgumentParser(description="Finetune large language models on causal language modeling tasks") 34 | parser.add_argument("--dataset_name", type=str, default=None, help="The name of the dataset to use (via the datasets library).") 35 | parser.add_argument("--dataset_config_name", type=str, default=None, help="The configuration name of the dataset to use (via the datasets library).") 36 | parser.add_argument("--train_file", type=str, default=None, help="A csv or a json file containing the training data.") 37 | parser.add_argument("--model_name_or_path", type=str, help="Path 
to pretrained model or model identifier from huggingface.co/models.", required=False) 38 | parser.add_argument("--revision", type=str, default='main', help="Model Branch") 39 | parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name") 40 | parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name") 41 | parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).") 42 | parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader.") 43 | parser.add_argument("--learning_rate", type=float, default=0.0001, help="Initial learning rate (after the potential warmup period) to use.") 44 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 45 | parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.") 46 | parser.add_argument("--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.") 47 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 48 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 49 | parser.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", choices=MODEL_TYPES) 50 | parser.add_argument("--block_size", type=int, default=None, help="The training dataset will be truncated to blocks of this size (after tokenization) for training.") 51 | parser.add_argument("--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.") 52 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 53 | parser.add_argument("--checkpointing_steps", type=str, default=None, help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.") 54 | parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="If the training should continue from a checkpoint folder.") 55 | parser.add_argument("--save_prefix", type=str, default='', help="Informative string prefix for saving purposes.") 56 | parser.add_argument("--use_pretrained_weights", action=argparse.BooleanOptionalAction, help="Whether to use pretrained weights.") 57 | parser.set_defaults(use_pretrained_weights=False) 58 | 59 | parser.add_argument("--eval_every_step", action=argparse.BooleanOptionalAction, help="Whether to eval every step.") 60 | parser.set_defaults(eval_every_step=True) 61 | 62 | parser.add_argument("--use_validation", action=argparse.BooleanOptionalAction, help="Whether to eval on validation set.") 63 | parser.set_defaults(use_validation=False) 64 | 65 | parser.add_argument("--seen_file", type=str, default=None, help="A csv or a json file containing the seen examples.") 66 | parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader.") 67 | parser.add_argument("--eval_freq", type=int, default=1, help="Number of epochs before every recall experiment.") 68 | parser.add_argument("--save_freq", type=int, default=10, help="Number of epochs before every model and optimizer save.") 69 | 
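# Interleaving and ablation-specific arguments follow: --num-grad-steps sets how many
# consecutive SGD updates each example receives before the loop moves to the next one,
# --num-data-samples sets how many training examples are cycled through per epoch, and
# --num_hidden_layers / --num_attention_heads / --hidden_size override the pretrained
# config so that width, depth, and head-count ablations can be trained from scratch
# (use_pretrained_weights defaults to False in this script).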
70 | parser.add_argument("--num-grad-steps", type=int, default=1, help="Number of gradient updates for each data point.") 71 | parser.add_argument("--num-data-samples", type=int, default=1, help="Number of tasks to interleave.") 72 | parser.add_argument("--num-eval-data-samples", type=int, default=100, help="Number of tasks to interleave.") 73 | parser.add_argument("--start-shuffle-data-samples", type=int, default=-1, help="Number of tasks to interleave.") 74 | 75 | parser.add_argument("--num_hidden_layers", type=int, default=16, help="Number of hidden layers.") 76 | parser.add_argument("--num_attention_heads", type=int, default=8, help="Number of attention heads.") 77 | parser.add_argument("--hidden_size", type=int, default=2048, help="Hidden size.") 78 | 79 | args = parser.parse_args() 80 | 81 | if args.start_shuffle_data_samples == -1: 82 | args.start_shuffle_data_samples = args.num_data_samples - 1 83 | 84 | # Sanity checks 85 | if args.dataset_name is None and args.train_file is None: 86 | raise ValueError("Need either a dataset name or a training file.") 87 | else: 88 | if args.train_file is not None: 89 | extension = args.train_file.split(".")[-1] 90 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." 91 | 92 | return args 93 | 94 | 95 | def main(): 96 | args = parse_args() 97 | print(args) 98 | 99 | eval_sample = [0] 100 | # eval_sample = range(args.num_data_samples) 101 | 102 | accelerator = Accelerator() 103 | 104 | logging.basicConfig( 105 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 106 | datefmt="%m/%d/%Y %H:%M:%S", 107 | level=logging.INFO, 108 | ) 109 | logger.info(accelerator.state, main_process_only=False) 110 | if accelerator.is_local_main_process: 111 | transformers.utils.logging.set_verbosity_info() 112 | else: 113 | transformers.utils.logging.set_verbosity_error() 114 | 115 | if args.seed is not None: 116 | set_seed(args.seed) 117 | 118 | if accelerator.is_main_process: 119 | if args.output_dir is not None: 120 | os.makedirs(args.output_dir, exist_ok=True) 121 | accelerator.wait_for_everyone() 122 | 123 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 124 | 125 | raw_datasets.pop('test') 126 | # subset_task_index = random.sample(range(len(raw_datasets['train'])), args.num_data_samples) 127 | subset_task_index = range(args.num_data_samples) 128 | raw_datasets['train'] = raw_datasets['train'].select(subset_task_index) 129 | if args.use_validation: 130 | subset_task_index_eval = random.sample(range(len(raw_datasets['validation'])), args.num_eval_data_samples) 131 | raw_datasets['validation'] = raw_datasets['validation'].select(subset_task_index_eval) 132 | else: 133 | raw_datasets.pop('validation') 134 | 135 | eval_datasets = deepcopy(raw_datasets) 136 | eval_datasets['train'] = eval_datasets['train'].select(eval_sample) 137 | 138 | # Load pretrained model and tokenizer 139 | if args.config_name: 140 | config = AutoConfig.from_pretrained(args.config_name) 141 | elif args.model_name_or_path: 142 | config = AutoConfig.from_pretrained(args.model_name_or_path) 143 | else: 144 | config = CONFIG_MAPPING[args.model_type]() 145 | logger.warning("You are instantiating a new config instance from scratch.") 146 | 147 | if args.tokenizer_name: 148 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, model_max_length=2048) 149 | elif args.model_name_or_path: 150 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not 
args.use_slow_tokenizer, model_max_length=2048) 151 | if args.model_name_or_path.startswith("gpt2") or args.model_name_or_path.startswith("EleutherAI"): 152 | tokenizer.pad_token = tokenizer.eos_token 153 | else: 154 | raise ValueError() 155 | 156 | 157 | config.num_hidden_layers = args.num_hidden_layers # Default 16 158 | config.num_attention_heads = args.num_attention_heads # Default 8 159 | config.hidden_size = args.hidden_size # Default 2048 160 | 161 | 162 | if args.model_name_or_path and args.use_pretrained_weights: 163 | if 'pythia' in args.model_name_or_path: 164 | model_author, model_name = args.model_name_or_path.split('/') 165 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, revision=args.revision, cache_dir=f"./{model_name}/{args.revision}") 166 | else: 167 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config) 168 | else: 169 | logger.info("Training new model from scratch") 170 | model = AutoModelForCausalLM.from_config(config) 171 | 172 | model.resize_token_embeddings(len(tokenizer)) 173 | 174 | if not args.use_pretrained_weights: 175 | with torch.no_grad(): 176 | for name, param in model.named_parameters(): 177 | if 'norm' not in name and 'bias' not in name: 178 | # print(f'Layer {name}, Weight Scale {torch.max(param).data}') 179 | if 'query_key_value' in name or 'dense_h_to_4h' in name or 'embed_in' in name or 'embed_out' in name: 180 | param *= math.sqrt(2 / (5 * args.hidden_size)) / 0.02 181 | elif 'attention.dense' in name or 'dense_4h_to_h' in name: 182 | param *= 2 / args.num_hidden_layers / math.sqrt(args.hidden_size) / 0.02 183 | 184 | for name, param in model.named_parameters(): 185 | if 'norm' not in name and 'bias' not in name: 186 | print(f'Layer {name}, Weight Scale {torch.max(param).data}') 187 | 188 | column_names = raw_datasets["train"].column_names 189 | text_column_name = eval_text_column_name = "article" 190 | 191 | eval_column_names = eval_datasets["train"].column_names 192 | 193 | test_text_column_name = 'highlights' 194 | 195 | if args.block_size is None: 196 | block_size = tokenizer.model_max_length 197 | else: 198 | block_size = args.block_size 199 | if args.block_size > tokenizer.model_max_length: 200 | logger.warning( 201 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 202 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 
203 | ) 204 | block_size = tokenizer.model_max_length 205 | print('Block size:', block_size) 206 | 207 | def tokenize_function(examples): 208 | return tokenizer(examples[text_column_name], padding='max_length', truncation=True, max_length=block_size) 209 | 210 | def tokenize_function_eval(examples): 211 | return tokenizer(examples[eval_text_column_name], truncation=True, max_length=block_size) 212 | 213 | def tokenize_function_test(examples): 214 | return tokenizer(examples[test_text_column_name], truncation=True, max_length=block_size) 215 | 216 | with accelerator.main_process_first(): 217 | tokenized_datasets = raw_datasets.map( 218 | tokenize_function, 219 | batched=True, 220 | num_proc=args.preprocessing_num_workers, 221 | remove_columns=column_names, 222 | load_from_cache_file=not args.overwrite_cache, 223 | desc="Running tokenizer on dataset", 224 | ) 225 | 226 | with accelerator.main_process_first(): 227 | eval_tokenized_datasets = eval_datasets.map( 228 | tokenize_function_eval, 229 | batched=True, 230 | num_proc=args.preprocessing_num_workers, 231 | remove_columns=eval_column_names, 232 | load_from_cache_file=not args.overwrite_cache, 233 | desc="Running tokenizer on dataset", 234 | ) 235 | 236 | with accelerator.main_process_first(): 237 | test_tokenized_datasets = eval_datasets.map( 238 | tokenize_function_test, 239 | batched=True, 240 | num_proc=args.preprocessing_num_workers, 241 | remove_columns=eval_column_names, 242 | load_from_cache_file=not args.overwrite_cache, 243 | desc="Running tokenizer on dataset", 244 | ) 245 | 246 | def preprocess_function(examples): 247 | examples["labels"] = examples["input_ids"].copy() 248 | examples["labels"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in examples["labels"]] 249 | return examples 250 | 251 | with accelerator.main_process_first(): 252 | lm_datasets = tokenized_datasets.map( 253 | preprocess_function, 254 | batched=True, 255 | num_proc=args.preprocessing_num_workers, 256 | load_from_cache_file=not args.overwrite_cache, 257 | desc=f"Not grouping text.", 258 | ) 259 | 260 | with accelerator.main_process_first(): 261 | eval_lm_datasets = eval_tokenized_datasets.map( 262 | preprocess_function, 263 | batched=True, 264 | num_proc=args.preprocessing_num_workers, 265 | load_from_cache_file=not args.overwrite_cache, 266 | desc=f"Not grouping text.", 267 | ) 268 | 269 | with accelerator.main_process_first(): 270 | test_lm_datasets = test_tokenized_datasets.map( 271 | preprocess_function, 272 | batched=True, 273 | num_proc=args.preprocessing_num_workers, 274 | load_from_cache_file=not args.overwrite_cache, 275 | desc=f"Not grouping text.", 276 | ) 277 | 278 | train_dataset = lm_datasets["train"] 279 | train_dataloader = DataLoader(train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size) 280 | 281 | eval_dataset = eval_lm_datasets["train"] 282 | eval_dataloader = DataLoader(eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 283 | 284 | test_dataset = test_lm_datasets['train'] 285 | test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size) 286 | 287 | 288 | for index in random.sample(range(len(train_dataset)), 1): 289 | logger.info(f"Sample {index} of the training set (decoded): {tokenizer.decode(train_dataset[index]['input_ids'], skip_special_tokens=True)}.") 290 | for index in random.sample(range(len(eval_dataset)), 1): 291 | 
logger.info(f"Sample {index} of the validation set (decoded): {tokenizer.decode(eval_dataset[index]['input_ids'], skip_special_tokens=True)}.") 292 | 293 | no_decay = ["bias", "layer_norm.weight"] 294 | optimizer_grouped_parameters = [ 295 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay}, 296 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 297 | ] 298 | 299 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 300 | 301 | # Scheduler and math around the number of training steps. 302 | overrode_max_train_steps = False 303 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 304 | if args.max_train_steps is None: 305 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 306 | overrode_max_train_steps = True 307 | 308 | # Prepare everything with our `accelerator`. 309 | model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, test_dataloader) 310 | 311 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 312 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) * args.num_grad_steps 313 | if overrode_max_train_steps: 314 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 315 | 316 | # Afterwards we recalculate our number of training epochs 317 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 318 | 319 | # Figure out how many steps we should save the Accelerator states 320 | checkpointing_steps = args.checkpointing_steps 321 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 322 | checkpointing_steps = int(checkpointing_steps) 323 | 324 | # Train! 325 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes 326 | 327 | logger.info("***** Running training *****") 328 | logger.info(f" Num examples = {len(train_dataset)}") 329 | logger.info(f" Num Epochs = {args.num_train_epochs}") 330 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 331 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 332 | logger.info(f" Total optimization steps = {args.max_train_steps}") 333 | 334 | # Only show the progress bar once on each machine. 
335 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 336 | completed_steps = 0 337 | starting_epoch = 0 338 | 339 | # Potentially load in the weights and states from a previous save 340 | if args.resume_from_checkpoint: 341 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 342 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 343 | accelerator.load_state(args.resume_from_checkpoint) 344 | path = os.path.basename(args.resume_from_checkpoint) 345 | else: 346 | # Get the most recent checkpoint 347 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 348 | dirs.sort(key=os.path.getctime) 349 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 350 | # Extract `epoch_{i}` or `step_{i}` 351 | training_difference = os.path.splitext(path)[0] 352 | 353 | if "epoch" in training_difference: 354 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 355 | resume_step = None 356 | else: 357 | resume_step = int(training_difference.replace("step_", "")) 358 | starting_epoch = resume_step // len(train_dataloader) 359 | resume_step -= starting_epoch * len(train_dataloader) 360 | 361 | # update the progress_bar if load from checkpoint 362 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 363 | completed_steps = starting_epoch * num_update_steps_per_epoch 364 | 365 | train_losses = [] 366 | train_losses_all = [] 367 | 368 | eval_losses = [] 369 | eval_losses_all = [] 370 | 371 | test_losses = [] 372 | test_losses_all = [] 373 | 374 | # Initial Eval 375 | model.eval() 376 | with torch.no_grad(): 377 | for _, batch in enumerate(eval_dataloader): 378 | outputs = model(**batch) 379 | loss = outputs.loss 380 | eval_losses.append(loss.detach().unsqueeze(0)) 381 | eval_losses_all.append(loss.detach().unsqueeze(0)) 382 | eval_losses_ckpt = torch.cat(eval_losses) 383 | eval_losses_ckpt = eval_losses_ckpt.cpu().numpy() 384 | logger.info(f"Mean eval loss: {np.mean(eval_losses_ckpt)}") 385 | 386 | for _, batch in enumerate(test_dataloader): 387 | outputs = model(**batch) 388 | loss = outputs.loss 389 | test_losses.append(loss.detach().unsqueeze(0)) 390 | test_losses_all.append(loss.detach().unsqueeze(0)) 391 | test_losses_ckpt = torch.cat(test_losses) 392 | test_losses_ckpt = test_losses_ckpt.cpu().numpy() 393 | logger.info(f"Mean TEST loss: {np.mean(test_losses_ckpt)}") 394 | 395 | for epoch in range(starting_epoch, args.num_train_epochs): 396 | 397 | for step, batch in enumerate(train_dataloader): 398 | if args.resume_from_checkpoint and epoch == starting_epoch: 399 | if resume_step is not None and step < resume_step: 400 | progress_bar.update(1) 401 | completed_steps += 1 402 | continue 403 | 404 | model.train() 405 | 406 | for _ in range(args.num_grad_steps): 407 | 408 | assert model.training 409 | outputs = model(**batch) 410 | loss = outputs.loss 411 | # keep track of the loss at each epoch 412 | train_losses.append(loss.detach().unsqueeze(0)) 413 | train_losses_all.append(loss.detach().unsqueeze(0)) 414 | optimizer.zero_grad() 415 | accelerator.backward(loss) 416 | optimizer.step() 417 | # optimizer.zero_grad() 418 | 419 | # Checks if the accelerator has performed an optimization step behind the scenes 420 | if accelerator.sync_gradients: 421 | progress_bar.update(1) 422 | completed_steps += 1 423 | 424 | if args.eval_every_step: 425 | # if step % 10 == 0: 426 | model.eval() 427 | with torch.no_grad(): 428 | eval_losses = [] 
429 | for eval_step, eval_batch in enumerate(eval_dataloader): 430 | eval_outputs = model(**eval_batch) 431 | eval_loss = eval_outputs.loss 432 | eval_losses.append(eval_loss.detach().unsqueeze(0)) 433 | eval_losses_all.append(eval_loss.detach().unsqueeze(0)) 434 | 435 | test_losses = [] 436 | for test_step, test_batch in enumerate(test_dataloader): 437 | test_outputs = model(**test_batch) 438 | test_loss = test_outputs.loss 439 | test_losses.append(test_loss.detach().unsqueeze(0)) 440 | test_losses_all.append(test_loss.detach().unsqueeze(0)) 441 | 442 | 443 | if isinstance(checkpointing_steps, int): 444 | if completed_steps % checkpointing_steps == 0: 445 | output_dir = f"step_{completed_steps}" 446 | if args.output_dir is not None: 447 | output_dir = os.path.join(args.output_dir, output_dir) 448 | 449 | # save model and tokenizer 450 | accelerator.wait_for_everyone() 451 | unwrapped_model = accelerator.unwrap_model(model) 452 | unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 453 | if accelerator.is_main_process: 454 | tokenizer.save_pretrained(output_dir) 455 | 456 | # save train_losses 457 | train_losses_ckpt = torch.cat(train_losses) 458 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 459 | 460 | train_losses_all_ckpt = torch.cat(train_losses_all) 461 | train_losses_all_ckpt = train_losses_all_ckpt.cpu().numpy() 462 | 463 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 464 | 465 | save_path = os.path.join(output_dir, 'train_losses.npz') 466 | np.savez(save_path, train_losses=train_losses_all_ckpt, completed_steps=completed_steps) 467 | 468 | # re-initialize losses 469 | train_losses = [] 470 | 471 | if completed_steps >= args.max_train_steps: 472 | break 473 | 474 | if args.checkpointing_steps == "epoch": 475 | output_dir = f"epoch_{epoch}" 476 | if args.output_dir is not None: 477 | output_dir = os.path.join(args.output_dir, output_dir) 478 | 479 | os.makedirs(output_dir, exist_ok=True) 480 | # if epoch == 0 or (epoch+1) % args.save_freq == 0: 481 | # accelerator.save_state(output_dir) 482 | 483 | # save train_losses 484 | train_losses_ckpt = torch.cat(train_losses) 485 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 486 | logger.info(f"Mean train loss: {np.mean(train_losses_ckpt)}") 487 | 488 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 489 | np.savez(save_path, train_losses_ckpt=train_losses_ckpt, completed_steps=completed_steps) 490 | 491 | if args.output_dir is not None: 492 | output_dir = os.path.join(args.output_dir, f'final') 493 | # save model and tokenizer 494 | accelerator.wait_for_everyone() 495 | unwrapped_model = accelerator.unwrap_model(model) 496 | unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save) 497 | if accelerator.is_main_process: 498 | tokenizer.save_pretrained(output_dir) 499 | 500 | # save train_losses 501 | train_losses_ckpt = torch.cat(train_losses) 502 | train_losses_ckpt = train_losses_ckpt.cpu().numpy() 503 | logger.info(f"Final mean train loss: {np.mean(train_losses_ckpt)}") 504 | 505 | eval_losses_all_ckpt = torch.cat(eval_losses_all) 506 | eval_losses_all_ckpt = eval_losses_all_ckpt.cpu().numpy() 507 | 508 | test_losses_all_ckpt = torch.cat(test_losses_all) 509 | test_losses_all_ckpt = test_losses_all_ckpt.cpu().numpy() 510 | 511 | # save results 512 | save_path = os.path.join(output_dir, args.save_prefix + '_results.npz') 513 | np.savez(save_path, 
train_losses_ckpt=train_losses_ckpt, eval_losses_ckpt=eval_losses_all_ckpt, test_losses_ckpt=test_losses_all_ckpt, completed_steps=completed_steps) 514 | 515 | 516 | if __name__ == "__main__": 517 | main() 518 | --------------------------------------------------------------------------------