├── .gitignore ├── INSTALL.md ├── LICENSE ├── README.md ├── dense_ft ├── README.md ├── sparse_trainer.py └── trainer.py ├── image_classifiers ├── README.md ├── datasets.py ├── download_weights.sh ├── engine.py ├── layerwrapper.py ├── main.py ├── models │ ├── convnext.py │ ├── deit.py │ ├── mlp_mixer.py │ ├── swin_transformer.py │ └── vision_transformer.py ├── optim_factory.py ├── prune_utils.py └── utils.py ├── lib ├── ablate.py ├── data.py ├── eval.py ├── layerwrapper.py ├── prune.py ├── prune_opt.py └── sparsegpt.py ├── lora_ft ├── README.md ├── evaluate_ppl.py ├── finetune_lm.py └── script.sh ├── main.py ├── main_opt.py └── scripts ├── ablate_weight_update.sh ├── llama_13b.sh ├── llama_30b.sh ├── llama_65b.sh └── llama_7b.sh /.gitignore: -------------------------------------------------------------------------------- 1 | /*__pycache__/ 2 | .env 3 | *.pyc 4 | *.DS_Store -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | Step 1: Create a new conda environment: 3 | ``` 4 | conda create -n prune_llm python=3.9 5 | conda activate prune_llm 6 | ``` 7 | Step 2: Install relevant packages 8 | ``` 9 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 10 | pip install transformers==4.28.0 datasets==2.11.0 wandb sentencepiece 11 | pip install accelerate==0.18.0 12 | ``` 13 | There are known [issues](https://github.com/huggingface/transformers/issues/22222) with the transformers library on loading the LLaMA tokenizer correctly. Please follow the mentioned suggestions to resolve this issue. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 CMU Locus Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pruning LLMs by Weights and Activations 2 | Official PyTorch implementation of **Wanda** (Pruning by **W**eights **and a**ctivations), as presented in our paper: 3 | 4 | **A Simple and Effective Pruning Approach for Large Language Models**
5 | *Mingjie Sun\*, Zhuang Liu\*, Anna Bair, J. Zico Kolter* (* indicates equal contribution)
6 | Carnegie Mellon University, Meta AI Research and Bosch Center for AI
7 | [Paper](https://arxiv.org/abs/2306.11695) - [Project page](https://eric-mingjie.github.io/wanda/home.html) 8 | 9 | ```bibtex 10 | @article{sun2023wanda, 11 | title={A Simple and Effective Pruning Approach for Large Language Models}, 12 | author={Sun, Mingjie and Liu, Zhuang and Bair, Anna and Kolter, J. Zico}, 13 | year={2023}, 14 | journal={arXiv preprint arXiv:2306.11695} 15 | } 16 | ``` 17 | 18 | --- 19 |


23 | 24 | Compared to magnitude pruning which removes weights solely based on their magnitudes, our pruning approach **Wanda** removes weights on a *per-output* basis, by the product of weight magnitudes and input activation norms. 25 | 26 | ## Update 27 | - [x] (9.22.2023) Add [support](https://github.com/locuslab/wanda#pruning-llama-2) for LLaMA-2. 28 | - [x] (9.22.2023) Add [code](https://github.com/locuslab/wanda#ablation-on-obs-weight-update) to reproduce the ablation study on OBS weight update in the paper. 29 | - [x] (10.6.2023) Add new [support](https://github.com/locuslab/wanda#ablation-on-obs-weight-update) for the weight update analysis in the ablation study. Feel free to try it out! 30 | - [x] (10.6.2023) Add [support](https://github.com/locuslab/wanda#zero-shot-evaluation) for zero-shot evaluation. 31 | - [x] (10.20.2023) Add code for pruning OPT models. 32 | - [x] (10.23.2023) Add code for [LoRA fine-tuning](lora_ft). 33 | 34 | ## Setup 35 | Installation instructions can be found in [INSTALL.md](INSTALL.md). 36 | 37 | ## Usage 38 | The [scripts](scripts) directory contains all the bash commands to replicate the main results (Table 2) in our paper. 39 | 40 | Below is an example command for pruning LLaMA-7B with Wanda, to achieve unstructured 50% sparsity. 41 | ```sh 42 | python main.py \ 43 | --model decapoda-research/llama-7b-hf \ 44 | --prune_method wanda \ 45 | --sparsity_ratio 0.5 \ 46 | --sparsity_type unstructured \ 47 | --save out/llama_7b/unstructured/wanda/ 48 | ``` 49 | We provide a quick overview of the arguments: 50 | - `--model`: The identifier for the LLaMA model on the Hugging Face model hub. 51 | - `--cache_dir`: Directory for loading or storing LLM weights. The default is `llm_weights`. 52 | - `--prune_method`: We have implemented three pruning methods, namely [`magnitude`, `wanda`, `sparsegpt`]. 53 | - `--sparsity_ratio`: Denotes the percentage of weights to be pruned. 54 | - `--sparsity_type`: Specifies the type of sparsity [`unstructured`, `2:4`, `4:8`]. 55 | - `--use_variant`: Whether to use the Wanda variant, default is `False`. 56 | - `--save`: Specifies the directory where the result will be stored. 57 | 58 | For structured N:M sparsity, set the argument `--sparsity_type` to "2:4" or "4:8". 
An illustrative command is provided below: 59 | ```sh 60 | python main.py \ 61 | --model decapoda-research/llama-7b-hf \ 62 | --prune_method wanda \ 63 | --sparsity_ratio 0.5 \ 64 | --sparsity_type 2:4 \ 65 | --save out/llama_7b/2-4/wanda/ 66 | ``` 67 | 68 | ### Pruning LLaMA-2 69 | For [LLaMA-2](https://ai.meta.com/llama/) models, replace `--model` with `meta-llama/Llama-2-7b-hf` (take `7b` as an example): 70 | ```sh 71 | python main.py \ 72 | --model meta-llama/Llama-2-7b-hf \ 73 | --prune_method wanda \ 74 | --sparsity_ratio 0.5 \ 75 | --sparsity_type unstructured \ 76 | --save out/llama2_7b/unstructured/wanda/ 77 | ``` 78 | LLaMA-2 results: (LLaMA-2-34b is not released as of 9.22.2023) 79 | |sparsity| ppl | llama2-7b | llama2-13b | llama2-70b | 80 | |------|------------------|----------|------------|------------| 81 | |-| dense | 5.12 | 4.57 | 3.12 | 82 | |unstructured 50%| magnitude | 14.89 | 6.37 | 4.98 | 83 | |unstructured 50%| sparsegpt | 6.51 | 5.63 | **3.98** | 84 | |unstructured 50%| wanda | **6.42** | **5.56** | **3.98** | 85 | |4:8| magnitude | 16.48 | 6.76 | 5.58 | 86 | |4:8| sparsegpt | 8.12 | 6.60 | 4.59 | 87 | |4:8| wanda | **7.97** | **6.55** | **4.47** | 88 | |2:4| magnitude | 54.59 | 8.33 | 6.33 | 89 | |2:4| sparsegpt | **10.17** | 8.32 | 5.40 | 90 | |2:4| wanda | 11.02 | **8.27** | **5.16** | 91 | 92 | ### Ablation on OBS weight update 93 | To reproduce the analysis on weight update, we provide our implementation for this ablation. All commands can be found in [this script](scripts/ablate_weight_update.sh). 94 | ```sh 95 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter 96 | do 97 | CUDA_VISIBLE_DEVICES=0 python main.py \ 98 | --model decapoda-research/llama-7b-hf \ 99 | --sparsity_ratio 0.5 \ 100 | --sparsity_type unstructured \ 101 | --prune_method ${method} \ 102 | --save out/llama_7b_ablation/unstructured/ 103 | done 104 | ``` 105 | Here `ablate_{mag/wanda}_{seq/iter}` means that we use magnitude pruning or wanda to obtain the pruned mask at each layer, then apply weight update procedure with either a sequential style or an iterative style every 128 input channels. For details, please see Section 5 of our [paper](https://arxiv.org/abs/2306.11695). 106 | 107 | ### Zero-Shot Evaluation 108 | For evaluating zero-shot tasks, we modify the [EleutherAI LM Harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/master) framework so that it could evaluate pruned LLM models. We provide the modified repo in [this link](https://drive.google.com/file/d/1zugbLyGZKsH1L19L9biHLfaGGFnEc7XL/view?usp=sharing). Make sure to download, extract and install this custom `lm_eval` package from the source code. 109 | 110 | For reproducibility, we used [commit `df3da98`](https://github.com/EleutherAI/lm-evaluation-harness/tree/df3da98c5405deafd519c2ddca52bb7c3fe36bef) on the main branch. All tasks were evaluated on task version of 0 except for BoolQ, where the task version is 1. 111 | 112 | On a high level, the functionality we provide is adding two arguments `pretrained_model` and `tokenizer` in this [function](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/evaluator.py#L17). We can then call this `simple_evaluate` function API from our [codebase](https://github.com/locuslab/wanda/blob/main/lib/eval.py#L148) to evaluate sparse pruned LLMs. To evaluate zero-shot tasks in addition to the WikiText perplexity, pass the `--eval_zero_shot` argument. 
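For reference, the snippet below is a minimal sketch of that call path, assuming the customized `lm_eval` fork linked above is installed; the model type string, task list and `model_args` value are illustrative, while the `pretrained_model`/`tokenizer` keywords are the two arguments added in the fork. [`lib/eval.py`](lib/eval.py) remains the authoritative implementation.
```python
# Sketch only -- assumes the customized lm_eval fork is installed; `pretrained_model`
# and `tokenizer` are the two extra keywords that fork adds so the harness can wrap an
# already-pruned, in-memory model instead of reloading weights from the hub.
from lm_eval import evaluator

def eval_zero_shot(model, tokenizer, task_list=("boolq", "rte", "hellaswag", "winogrande")):
    results = evaluator.simple_evaluate(
        model="hf-causal-experimental",   # HF causal-LM adapter name in the harness
        model_args="",                    # left empty here; the fork uses the objects below
        tasks=list(task_list),
        num_fewshot=0,                    # zero-shot
        no_cache=True,
        pretrained_model=model,
        tokenizer=tokenizer,
    )
    return results["results"]
```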
113 | 114 | ### Speedup Evaluation 115 | The pruning speed for each method is evaluated by the cumulated time spent on pruning (for each layer), without the forward passes. 116 | 117 | For inference speedup with structured sparsity, we refer the reader to this [blog post](https://pytorch.org/tutorials/prototype/semi_structured_sparse.html), where structured sparsity is supported by `PyTorch >= 2.1`. You can switch between the CUTLASS or CuSPARSELt kernel [here](https://github.com/pytorch/pytorch/blob/v2.1.0/torch/sparse/semi_structured.py#L55). 118 | 119 | Last, for pruning image classifiers, see directory [image_classifiers](image_classifiers) for details. 120 | 121 | ## Acknowledgement 122 | This repository is build upon the [SparseGPT](https://github.com/IST-DASLab/sparsegpt) repository. 123 | 124 | ## License 125 | This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information. 126 | 127 | ## Questions 128 | Feel free to discuss papers/code with us through issues/emails! 129 | 130 | mingjies at cs.cmu.edu 131 | liuzhuangthu at gmail.com -------------------------------------------------------------------------------- /dense_ft/README.md: -------------------------------------------------------------------------------- 1 | ## Dense Fine-tuning 2 | 3 | We provide a sparse Trainer in [sparse_trainer.py](sparse_trainer.py), which can be used as a drop-in-replacement of huggingface Trainer. In `training_step` function, when the gradients are obtained, `mask_grad()` will zero out the gradient corresponding to the pruned weights. -------------------------------------------------------------------------------- /dense_ft/sparse_trainer.py: -------------------------------------------------------------------------------- 1 | from transformers.trainer import Trainer 2 | import torch 3 | import torch.nn as nn 4 | 5 | def find_layers(module, layers=[nn.Linear], name=''): 6 | if type(module) in layers: 7 | return {name: module} 8 | res = {} 9 | for name1, child in module.named_children(): 10 | res.update(find_layers( 11 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 12 | )) 13 | return res 14 | 15 | def fix_grad_nan_inf(model): 16 | layers = model.model.layers 17 | count = 0 18 | total_params = 0 19 | for m in model.parameters(): 20 | if m.requires_grad: 21 | if torch.isnan(m.grad).any() or torch.isinf(m.grad).any(): 22 | m.grad.zero_() 23 | 24 | 25 | def mask_grad(model): 26 | layers = model.model.layers 27 | count = 0 28 | total_params = 0 29 | for i in range(len(layers)): 30 | layer = layers[i] 31 | subset = find_layers(layer) 32 | 33 | sub_count = 0 34 | sub_params = 0 35 | for name in subset: 36 | W = subset[name].weight.data 37 | mask = (W==0) 38 | subset[name].weight.grad[mask]= 0 39 | 40 | def check_sparsity(model): 41 | use_cache = model.config.use_cache 42 | model.config.use_cache = False 43 | 44 | layers = model.model.layers 45 | count = 0 46 | total_params = 0 47 | for i in range(len(layers)): 48 | layer = layers[i] 49 | subset = find_layers(layer) 50 | 51 | sub_count = 0 52 | sub_params = 0 53 | for name in subset: 54 | W = subset[name].weight.data 55 | count += (W==0).sum().item() 56 | total_params += W.numel() 57 | 58 | sub_count += (W==0).sum().item() 59 | sub_params += W.numel() 60 | 61 | # print(f"layer {i} sparsity {float(sub_count)/sub_params:.4f}") 62 | 63 | model.config.use_cache = use_cache 64 | return float(count)/total_params 65 | 66 | class SparseTrainer(Trainer): 67 | def __init__(self, model= None, args= None, data_collator= None, train_dataset= None, eval_dataset= None, 68 | tokenizer= None, model_init= None, compute_metrics= None, callbacks= None, optimizers= (None, None), 69 | preprocess_logits_for_metrics= None 70 | ): 71 | super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, 72 | optimizers, preprocess_logits_for_metrics) 73 | self.counter = 0 74 | 75 | def training_step(self, model, inputs): 76 | """ 77 | Perform a training step on a batch of inputs. 78 | 79 | Subclass and override to inject custom behavior. 80 | 81 | Args: 82 | model (`nn.Module`): 83 | The model to train. 84 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 85 | The inputs and targets of the model. 86 | 87 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 88 | argument `labels`. Check your model's documentation for all accepted arguments. 89 | 90 | Return: 91 | `torch.Tensor`: The tensor with training loss on this batch. 92 | 93 | access optimizer through: self.optimizer.optimizer.param_groups[0] 94 | """ 95 | self.counter += 1 96 | model.train() 97 | inputs = self._prepare_inputs(inputs) 98 | 99 | with self.compute_loss_context_manager(): 100 | loss = self.compute_loss(model, inputs) 101 | 102 | if self.args.n_gpu > 1: 103 | loss = loss.mean() # mean() to average on multi-gpu parallel training 104 | 105 | if self.do_grad_scaling: ## False 106 | self.scaler.scale(loss).backward() 107 | elif self.use_apex: ## False 108 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 109 | scaled_loss.backward() 110 | else: 111 | self.accelerator.backward(loss) 112 | # pass 113 | 114 | mask_grad(model) ### mask the gradients 115 | 116 | return loss.detach() / self.args.gradient_accumulation_steps 117 | 118 | def compute_loss(self, model, inputs, return_outputs=False): 119 | """ 120 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 121 | 122 | Subclass and override for custom behavior. 
123 | 124 | ## model type: transformers.models.llama.modeling_llama.LlamaForCausalLM 125 | ## outputs[0]: a single scalar 126 | ## outputs[1]: shape (bs, 2048, 32000) 127 | 128 | ## inputs["input_ids"] shape: (bs, 2048) 129 | ## inputs["attention_mask] shape: (bs, 2048) 130 | """ 131 | if self.label_smoother is not None and "labels" in inputs: 132 | labels = inputs.pop("labels") 133 | else: 134 | labels = None 135 | 136 | outputs = model(**inputs) 137 | # Save past state if it exists 138 | # TODO: this needs to be fixed and made cleaner later. 139 | if self.args.past_index >= 0: 140 | self._past = outputs[self.args.past_index] 141 | 142 | if labels is not None: 143 | if is_peft_available() and isinstance(model, PeftModel): 144 | model_name = unwrap_model(model.base_model)._get_name() 145 | else: 146 | model_name = unwrap_model(model)._get_name() 147 | if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): 148 | loss = self.label_smoother(outputs, labels, shift_labels=True) 149 | else: 150 | loss = self.label_smoother(outputs, labels) 151 | else: 152 | if isinstance(outputs, dict) and "loss" not in outputs: 153 | raise ValueError( 154 | "The model did not return a loss from the inputs, only the following keys: " 155 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." 156 | ) 157 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 158 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 159 | 160 | return (loss, outputs) if return_outputs else loss -------------------------------------------------------------------------------- /image_classifiers/README.md: -------------------------------------------------------------------------------- 1 | # Pruning Image Classifiers 2 | Here we provide the code for pruning ConvNeXt and ViT. This part is built on the [dropout](https://github.com/facebookresearch/dropout) repository. 3 | 4 | ## Environment 5 | We additionally install `timm` for loading pretrained image classifiers. 6 | ``` 7 | pip install timm==0.4.12 8 | ``` 9 | 10 | ## Download Weights 11 | Run the script [download_weights.sh](download_weights.sh) to download pretrained weights for ConvNeXt-B, DeiT-B and ViT-L, which we used in the paper. 12 | 13 | ## Usage 14 | Here is the command for pruning ConvNeXt/ViT models: 15 | ``` 16 | python main.py --model [ARCH] \ 17 | --data_path [PATH to ImageNet] \ 18 | --resume [PATH to the pretrained weights] \ 19 | --prune_metric wanda \ 20 | --prune_granularity row \ 21 | --sparsity 0.5 22 | ``` 23 | where: 24 | - `--model`: network architecture, choices [`convnext_base`, `deit_base_patch16_224`, `vit_large_patch16_224`]. 25 | - `--resume`: model path to downloaded pretrained weights. 26 | - `--prune_metric`: [`magnitude`, `wanda`]. 27 | - `--prune_granularity`: [`layer`, `row`]. -------------------------------------------------------------------------------- /image_classifiers/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
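# Dataset helpers used by the pruning entry point (main.py): build_dataset() returns a
# (dataset, nb_classes) pair for CIFAR-100, ImageNet ('IMNET') or a generic image folder,
# and build_transform() wraps timm's create_transform for training and the standard
# resize / center-crop / normalize pipeline for evaluation.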
7 | 8 | import os 9 | from torchvision import datasets, transforms 10 | 11 | from timm.data.constants import \ 12 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 13 | from timm.data import create_transform 14 | 15 | def build_dataset(is_train, args): 16 | transform = build_transform(is_train, args) 17 | 18 | print("Transform = ") 19 | if isinstance(transform, tuple): 20 | for trans in transform: 21 | print(" - - - - - - - - - - ") 22 | for t in trans.transforms: 23 | print(t) 24 | else: 25 | for t in transform.transforms: 26 | print(t) 27 | print("---------------------------") 28 | 29 | if args.data_set == 'CIFAR': 30 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 31 | nb_classes = 100 32 | elif args.data_set == 'IMNET': 33 | print("reading from datapath", args.data_path) 34 | root = os.path.join(args.data_path, 'train' if is_train else 'val_dirs') 35 | dataset = datasets.ImageFolder(root, transform=transform) 36 | nb_classes = 1000 37 | elif args.data_set == "image_folder": 38 | root = args.data_path if is_train else args.eval_data_path 39 | dataset = datasets.ImageFolder(root, transform=transform) 40 | nb_classes = args.nb_classes 41 | assert len(dataset.class_to_idx) == nb_classes 42 | else: 43 | raise NotImplementedError() 44 | print("Number of the class = %d" % nb_classes) 45 | 46 | return dataset, nb_classes 47 | 48 | 49 | def build_transform(is_train, args): 50 | resize_im = args.input_size > 32 51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 54 | 55 | if is_train: 56 | # this should always dispatch to transforms_imagenet_train 57 | transform = create_transform( 58 | input_size=args.input_size, 59 | is_training=True, 60 | color_jitter=args.color_jitter, 61 | auto_augment=args.aa, 62 | interpolation=args.train_interpolation, 63 | re_prob=args.reprob, 64 | re_mode=args.remode, 65 | re_count=args.recount, 66 | mean=mean, 67 | std=std, 68 | ) 69 | if not resize_im: 70 | transform.transforms[0] = transforms.RandomCrop( 71 | args.input_size, padding=4) 72 | return transform 73 | 74 | t = [] 75 | if resize_im: 76 | # warping (no cropping) when evaluated at 384 or larger 77 | if args.input_size >= 384: 78 | t.append( 79 | transforms.Resize((args.input_size, args.input_size), 80 | interpolation=transforms.InterpolationMode.BICUBIC), 81 | ) 82 | print(f"Warping {args.input_size} size input images...") 83 | else: 84 | if args.crop_pct is None: 85 | args.crop_pct = 224 / 256 86 | size = int(args.input_size / args.crop_pct) 87 | t.append( 88 | # to maintain same ratio w.r.t. 
224 images 89 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 90 | ) 91 | t.append(transforms.CenterCrop(args.input_size)) 92 | 93 | t.append(transforms.ToTensor()) 94 | t.append(transforms.Normalize(mean, std)) 95 | return transforms.Compose(t) 96 | -------------------------------------------------------------------------------- /image_classifiers/download_weights.sh: -------------------------------------------------------------------------------- 1 | mkdir -p model_weights/vit/ 2 | mkdir -p model_weights/convnext 3 | mkdir -p model_weights/deit/ 4 | 5 | cd model_weights/vit 6 | wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth 7 | 8 | cd ../convnext/ 9 | wget https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth 10 | 11 | cd ../deit 12 | wget https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth 13 | 14 | cd ../.. -------------------------------------------------------------------------------- /image_classifiers/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import math 9 | from typing import Iterable, Optional 10 | import torch 11 | from timm.data import Mixup 12 | from timm.utils import accuracy, ModelEma 13 | 14 | import utils 15 | import torch.nn as nn 16 | 17 | from prune_utils import check_sparsity 18 | 19 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 23 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, schedules={}, 24 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False): 25 | model.train(True) 26 | metric_logger = utils.MetricLogger(delimiter=" ") 27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 10 31 | 32 | optimizer.zero_grad() 33 | 34 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 35 | # if data_iter_step > 10: 36 | # break 37 | 38 | step = data_iter_step // update_freq 39 | if step >= num_training_steps_per_epoch: 40 | continue 41 | it = start_steps + step # global training iteration 42 | # Update LR & WD for the first acc 43 | if data_iter_step % update_freq == 0: 44 | if lr_schedule_values is not None or wd_schedule_values is not None: 45 | for i, param_group in enumerate(optimizer.param_groups): 46 | if lr_schedule_values is not None: 47 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 48 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 49 | param_group["weight_decay"] = wd_schedule_values[it] 50 | if 'dp' in schedules: 51 | model.module.update_drop_path(schedules['dp'][it]) 52 | if 'do' in schedules: 53 | model.module.update_dropout(schedules['do'][it]) 54 | 55 | samples = samples.to(device, non_blocking=True) 56 | targets = targets.to(device, non_blocking=True) 57 | 
58 | if mixup_fn is not None: 59 | samples, targets = mixup_fn(samples, targets) 60 | 61 | if use_amp: 62 | with torch.cuda.amp.autocast(): 63 | output = model(samples) 64 | loss = criterion(output, targets) 65 | else: # full precision 66 | output = model(samples) 67 | loss = criterion(output, targets) 68 | 69 | loss_value = loss.item() 70 | 71 | if not math.isfinite(loss_value): # this could trigger if using AMP 72 | print("Loss is {}, stopping training".format(loss_value)) 73 | assert math.isfinite(loss_value) 74 | 75 | if use_amp: 76 | # this attribute is added by timm on one optimizer (adahessian) 77 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 78 | loss /= update_freq 79 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 80 | parameters=model.parameters(), create_graph=is_second_order, 81 | update_grad=(data_iter_step + 1) % update_freq == 0) 82 | 83 | if (data_iter_step + 1) % update_freq == 0: 84 | optimizer.zero_grad() 85 | if model_ema is not None: 86 | model_ema.update(model) 87 | else: # full precision 88 | loss /= update_freq 89 | loss.backward() 90 | 91 | ############################################################################### 92 | ## code i added 93 | for k, m in enumerate(model.modules()): 94 | if isinstance(m, nn.Linear): 95 | weight_copy = m.weight.data.abs().clone() 96 | mask = weight_copy.gt(0).float().cuda() 97 | m.weight.grad.data.mul_(mask) 98 | ################################################################################ 99 | 100 | if (data_iter_step + 1) % update_freq == 0: 101 | optimizer.step() 102 | optimizer.zero_grad() 103 | if model_ema is not None: 104 | model_ema.update(model) 105 | 106 | torch.cuda.synchronize() 107 | 108 | if mixup_fn is None: 109 | class_acc = (output.max(-1)[-1] == targets).float().mean() 110 | else: 111 | class_acc = None 112 | metric_logger.update(loss=loss_value) 113 | metric_logger.update(class_acc=class_acc) 114 | min_lr = 10. 115 | max_lr = 0. 
116 | for group in optimizer.param_groups: 117 | min_lr = min(min_lr, group["lr"]) 118 | max_lr = max(max_lr, group["lr"]) 119 | 120 | metric_logger.update(lr=max_lr) 121 | metric_logger.update(min_lr=min_lr) 122 | weight_decay_value = None 123 | for group in optimizer.param_groups: 124 | if group["weight_decay"] > 0: 125 | weight_decay_value = group["weight_decay"] 126 | metric_logger.update(weight_decay=weight_decay_value) 127 | 128 | if 'dp' in schedules: 129 | metric_logger.update(drop_path=model.module.drop_path) 130 | 131 | if 'do' in schedules: 132 | metric_logger.update(dropout=model.module.drop_rate) 133 | 134 | if use_amp: 135 | metric_logger.update(grad_norm=grad_norm) 136 | 137 | if log_writer is not None: 138 | log_writer.update(loss=loss_value, head="loss") 139 | log_writer.update(class_acc=class_acc, head="loss") 140 | log_writer.update(lr=max_lr, head="opt") 141 | log_writer.update(min_lr=min_lr, head="opt") 142 | log_writer.update(weight_decay=weight_decay_value, head="opt") 143 | if use_amp: 144 | log_writer.update(grad_norm=grad_norm, head="opt") 145 | log_writer.set_step() 146 | 147 | if wandb_logger: 148 | wandb_logger._wandb.log({ 149 | 'Rank-0 Batch Wise/train_loss': loss_value, 150 | 'Rank-0 Batch Wise/train_max_lr': max_lr, 151 | 'Rank-0 Batch Wise/train_min_lr': min_lr 152 | }, commit=False) 153 | if class_acc: 154 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False) 155 | if use_amp: 156 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False) 157 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it}) 158 | 159 | # gather the stats from all processes 160 | metric_logger.synchronize_between_processes() 161 | print("Averaged stats:", metric_logger) 162 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 163 | 164 | @torch.no_grad() 165 | def evaluate(data_loader, model, device, use_amp=False): 166 | criterion = torch.nn.CrossEntropyLoss() 167 | 168 | metric_logger = utils.MetricLogger(delimiter=" ") 169 | header = 'Test:' 170 | 171 | # switch to evaluation mode 172 | model.eval() 173 | for batch in metric_logger.log_every(data_loader, 10, header): 174 | images = batch[0] 175 | target = batch[-1] 176 | 177 | images = images.to(device, non_blocking=True) 178 | target = target.to(device, non_blocking=True) 179 | 180 | # compute output 181 | if use_amp: 182 | with torch.cuda.amp.autocast(): 183 | output = model(images) 184 | loss = criterion(output, target) 185 | else: 186 | output = model(images) 187 | loss = criterion(output, target) 188 | 189 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 190 | 191 | batch_size = images.shape[0] 192 | metric_logger.update(loss=loss.item()) 193 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 194 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 195 | # gather the stats from all processes 196 | metric_logger.synchronize_between_processes() 197 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 198 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 199 | 200 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /image_classifiers/layerwrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | 
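# The WrappedLayer class below gathers the Wanda statistics for a single nn.Linear:
# add_batch() accumulates in `scaler_row` the (token-averaged) squared L2 norm of each
# input channel of the calibration activations, which the pruning code combines with
# the weight magnitude into the per-weight score |W| * sqrt(scaler_row) described in
# the top-level README. prune() zeroes the masked weights and refits the layer bias
# from the mean output shift measured on a cached calibration batch (self.inp1 /
# self.out1, which are expected to be attached by the caller).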
7 | torch.backends.cuda.matmul.allow_tf32 = False 8 | torch.backends.cudnn.allow_tf32 = False 9 | 10 | 11 | class WrappedLayer: 12 | 13 | def __init__(self, layer, layer_id=0, layer_name="none", p_norm=2): 14 | self.layer = layer 15 | self.dev = self.layer.weight.device 16 | self.rows = layer.weight.data.shape[0] 17 | self.columns = layer.weight.data.shape[1] 18 | 19 | self.scaler_row = torch.zeros((self.columns), device=self.dev) 20 | self.nsamples = 0 21 | 22 | self.layer_id = layer_id 23 | self.layer_name = layer_name 24 | self.p_norm = p_norm 25 | 26 | def add_batch(self, inp, out): 27 | assert inp.shape[-1] == self.columns 28 | inp = inp.reshape((-1,self.columns)) 29 | tmp = inp.shape[0] 30 | inp = inp.t() 31 | 32 | if self.p_norm == 2: 33 | self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / tmp 34 | elif self.p_norm == 1: 35 | self.scaler_row += torch.norm(inp, p=1, dim=1) ** 1 / tmp 36 | 37 | if torch.isinf(self.scaler_row).sum() > 0: 38 | print("encountered torch.isinf error") 39 | raise ValueError 40 | 41 | def prune(self, W_mask): 42 | self.layer.weight.data[W_mask] = 0 43 | out_ = self.layer(self.inp1) 44 | 45 | dist = (self.out1 - out_).squeeze(dim=0) 46 | 47 | bias = torch.mean(dist, dim=0) 48 | self.layer.bias = nn.Parameter(bias) 49 | 50 | def free(self): 51 | if DEBUG: 52 | self.inp1 = None 53 | self.out1 = None 54 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /image_classifiers/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
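# Evaluation-only entry point for pruning image classifiers: main() loads a pretrained
# checkpoint (--resume), draws --nsamples random calibration images from the training
# set, prunes the network in one shot with --prune_metric (magnitude or wanda) at the
# requested --sparsity, reports the measured sparsity, and then evaluates top-1/top-5
# accuracy on the validation set; no fine-tuning happens in this script.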
7 | 8 | import argparse 9 | import datetime 10 | import numpy as np 11 | import time 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | import json 16 | import os 17 | 18 | from pathlib import Path 19 | 20 | from timm.data.mixup import Mixup 21 | from timm.models import create_model 22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 23 | from timm.utils import ModelEma 24 | from optim_factory import create_optimizer, LayerDecayValueAssigner 25 | 26 | from datasets import build_dataset 27 | from engine import train_one_epoch, evaluate 28 | 29 | import utils 30 | 31 | import models.convnext 32 | import models.vision_transformer 33 | import models.swin_transformer 34 | import models.mlp_mixer 35 | import models.deit 36 | 37 | from prune_utils import prune_convnext, prune_deit, prune_vit, check_sparsity 38 | 39 | def str2bool(v): 40 | """ 41 | Converts string to bool type; enables command line 42 | arguments in the format of '--arg1 true --arg2 false' 43 | """ 44 | if isinstance(v, bool): 45 | return v 46 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 47 | return True 48 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 49 | return False 50 | else: 51 | raise argparse.ArgumentTypeError('Boolean value expected.') 52 | 53 | def get_args_parser(): 54 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 55 | parser.add_argument('--batch_size', default=256, type=int, 56 | help='Per GPU batch size') 57 | parser.add_argument('--epochs', default=300, type=int) 58 | parser.add_argument('--update_freq', default=1, type=int, 59 | help='gradient accumulation steps') 60 | 61 | # Model parameters 62 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL', 63 | help='Name of model to train') 64 | parser.add_argument('--input_size', default=224, type=int, 65 | help='image input size') 66 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 67 | help="Layer scale initial values") 68 | 69 | ########################## settings specific to this project ########################## 70 | 71 | # dropout and stochastic depth drop rate; set at most one to non-zero 72 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT', 73 | help='Drop path rate (default: 0.0)') 74 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', 75 | help='Drop path rate (default: 0.0)') 76 | 77 | # early / late dropout and stochastic depth settings 78 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode') 79 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'], 80 | help='drop schedule for early dropout / s.d. 
only') 81 | parser.add_argument('--cutoff_epoch', type=int, default=0, 82 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts') 83 | 84 | ####################################################################################### 85 | 86 | # EMA related parameters 87 | parser.add_argument('--model_ema', type=str2bool, default=False) 88 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 89 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 90 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 91 | 92 | # Optimization parameters 93 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 94 | help='Optimizer (default: "adamw"') 95 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 96 | help='Optimizer Epsilon (default: 1e-8)') 97 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 98 | help='Optimizer Betas (default: None, use opt default)') 99 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 100 | help='Clip gradient norm (default: None, no clipping)') 101 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 102 | help='SGD momentum (default: 0.9)') 103 | parser.add_argument('--weight_decay', type=float, default=0.05, 104 | help='weight decay (default: 0.05)') 105 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 106 | weight decay. We use a cosine schedule for WD and using a larger decay by 107 | the end of training improves performance for ViTs.""") 108 | 109 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR', 110 | help='learning rate (default: 4e-3), with total batch size 4096') 111 | parser.add_argument('--layer_decay', type=float, default=1.0) 112 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 113 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 114 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N', 115 | help='epochs to warmup LR, if scheduler supports') 116 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 117 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 118 | 119 | # Augmentation parameters 120 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 121 | help='Color jitter factor (default: 0.4)') 122 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 123 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 124 | parser.add_argument('--smoothing', type=float, default=0.1, 125 | help='Label smoothing (default: 0.1)') 126 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 127 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 128 | 129 | # Evaluation parameters 130 | parser.add_argument('--crop_pct', type=float, default=None) 131 | 132 | # * Random Erase params 133 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 134 | help='Random erase prob (default: 0.25)') 135 | parser.add_argument('--remode', type=str, default='pixel', 136 | help='Random erase mode (default: "pixel")') 137 | parser.add_argument('--recount', type=int, default=1, 138 | help='Random erase count (default: 1)') 139 | parser.add_argument('--resplit', type=str2bool, default=False, 140 | help='Do not random erase first (clean) augmentation split') 141 | 142 | # * Mixup params 143 | parser.add_argument('--mixup', type=float, default=0.8, 144 | help='mixup alpha, mixup enabled if > 0.') 145 | parser.add_argument('--cutmix', type=float, default=1.0, 146 | help='cutmix alpha, cutmix enabled if > 0.') 147 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 148 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 149 | parser.add_argument('--mixup_prob', type=float, default=1.0, 150 | help='Probability of performing mixup or cutmix when either/both is enabled') 151 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 152 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 153 | parser.add_argument('--mixup_mode', type=str, default='batch', 154 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 155 | 156 | # * Finetuning params 157 | parser.add_argument('--finetune', default='', 158 | help='finetune from checkpoint') 159 | parser.add_argument('--head_init_scale', default=1.0, type=float, 160 | help='classifier head initial scale, typically adjusted in fine-tuning') 161 | parser.add_argument('--model_key', default='model|module', type=str, 162 | help='which key to load from saved state dict, usually model or model_ema') 163 | parser.add_argument('--model_prefix', default='', type=str) 164 | 165 | # Dataset parameters 166 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 167 | help='dataset path') 168 | parser.add_argument('--eval_data_path', default=None, type=str, 169 | help='dataset path for evaluation') 170 | parser.add_argument('--nb_classes', default=1000, type=int, 171 | help='number of the classification types') 172 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 173 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'], 174 | type=str, help='ImageNet dataset path') 175 | parser.add_argument('--output_dir', default='', 176 | help='path where to save, empty for no saving') 177 | parser.add_argument('--device', default='cuda', 178 | help='device to use for training / testing') 179 | parser.add_argument('--seed', default=0, type=int) 180 | 181 | parser.add_argument('--resume', default='', 182 | help='resume from checkpoint') 183 | parser.add_argument('--auto_resume', type=str2bool, default=True) 184 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 185 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 186 | parser.add_argument('--save_ckpt_num', default=3, type=int) 187 | 188 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 189 | help='start epoch') 190 | parser.add_argument('--eval', type=str2bool, default=False, 191 | help='Perform evaluation only') 192 | parser.add_argument('--dist_eval', type=str2bool, default=True, 193 | help='Enabling distributed evaluation') 194 | parser.add_argument('--disable_eval', type=str2bool, default=False, 195 | help='Disabling evaluation during training') 196 | parser.add_argument('--num_workers', default=10, type=int) 197 | parser.add_argument('--pin_mem', type=str2bool, default=True, 198 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 199 | 200 | # distributed training parameters 201 | parser.add_argument('--world_size', default=1, type=int, 202 | help='number of distributed processes') 203 | parser.add_argument('--local_rank', default=-1, type=int) 204 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 205 | parser.add_argument('--dist_url', default='env://', 206 | help='url used to set up distributed training') 207 | 208 | parser.add_argument('--use_amp', type=str2bool, default=False, 209 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 210 | 211 | # Weights and Biases arguments 212 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 213 | help="enable logging to Weights and Biases") 214 | parser.add_argument('--project', default='convnext', type=str, 215 | help="The name of the W&B project where you're sending the new run.") 216 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 217 | help="Save model checkpoints as W&B Artifacts.") 218 | 219 | # arguments for pruning 220 | parser.add_argument("--nsamples", type=int, 
default=4096) 221 | parser.add_argument("--sparsity", type=float, default=0.) 222 | parser.add_argument("--prune_metric", type=str, choices=["magnitude", "wanda"]) 223 | parser.add_argument("--prune_granularity", type=str) 224 | parser.add_argument("--blocksize", type=int, default=1) 225 | 226 | return parser 227 | 228 | def main(args): 229 | utils.init_distributed_mode(args) 230 | print(args) 231 | device = torch.device(args.device) 232 | 233 | # fix the seed for reproducibility 234 | seed = args.seed + utils.get_rank() 235 | torch.manual_seed(seed) 236 | np.random.seed(seed) 237 | cudnn.benchmark = True 238 | 239 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 240 | if args.disable_eval: 241 | args.dist_eval = False 242 | dataset_val = None 243 | else: 244 | dataset_val, _ = build_dataset(is_train=False, args=args) 245 | 246 | num_tasks = utils.get_world_size() 247 | global_rank = utils.get_rank() 248 | 249 | sampler_train = torch.utils.data.DistributedSampler( 250 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 251 | ) 252 | print("Sampler_train = %s" % str(sampler_train)) 253 | if args.dist_eval: 254 | if len(dataset_val) % num_tasks != 0: 255 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 256 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 257 | 'equal num of samples per-process.') 258 | sampler_val = torch.utils.data.DistributedSampler( 259 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 260 | else: 261 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 262 | 263 | if global_rank == 0 and args.enable_wandb: 264 | wandb_logger = utils.WandbLogger(args) 265 | else: 266 | wandb_logger = None 267 | 268 | 269 | data_loader_train = torch.utils.data.DataLoader( 270 | dataset_train, sampler=sampler_train, 271 | batch_size=args.batch_size, 272 | num_workers=args.num_workers, 273 | pin_memory=args.pin_mem, 274 | drop_last=True, 275 | ) 276 | 277 | if dataset_val is not None: 278 | data_loader_val = torch.utils.data.DataLoader( 279 | dataset_val, sampler=sampler_val, 280 | batch_size=int(1.5 * args.batch_size), 281 | num_workers=args.num_workers, 282 | pin_memory=args.pin_mem, 283 | drop_last=False 284 | ) 285 | else: 286 | data_loader_val = None 287 | 288 | model = utils.build_model(args, pretrained=False) 289 | model.cuda() 290 | if args.distributed: 291 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 292 | model_without_ddp = model.module 293 | 294 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 295 | print('number of params:', n_parameters) 296 | 297 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 298 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 299 | 300 | # At most one of dropout and stochastic depth should be enabled. 301 | assert(args.dropout == 0 or args.drop_path == 0) 302 | # ConvNeXt does not support dropout. 
303 | assert(args.dropout == 0 if args.model.startswith("convnext") else True) 304 | 305 | 306 | if "convnext" in args.model: 307 | checkpoint = torch.load(args.resume, map_location='cpu') 308 | model.load_state_dict(checkpoint["model"]) 309 | elif "vit" in args.model: 310 | checkpoint = torch.load(args.resume, map_location='cpu') 311 | model.load_state_dict(checkpoint) 312 | elif "deit" in args.model: 313 | checkpoint = torch.load(args.resume, map_location='cpu') 314 | model.load_state_dict(checkpoint["model"]) 315 | 316 | ################################################################################ 317 | np.random.seed(0) 318 | calibration_ids = np.random.choice(len(dataset_train), args.nsamples) 319 | calib_data = [] 320 | for i in calibration_ids: 321 | calib_data.append(dataset_train[i][0].unsqueeze(dim=0)) 322 | calib_data = torch.cat(calib_data, dim=0).to(device) 323 | 324 | tick = time.time() 325 | if args.sparsity != 0: 326 | with torch.no_grad(): 327 | if "convnext" in args.model: 328 | prune_convnext(args, model, calib_data, device) 329 | elif "vit" in args.model: 330 | prune_vit(args, model, calib_data, device) 331 | elif "deit" in args.model: 332 | prune_vit(args, model, calib_data, device) 333 | ################################################################################ 334 | 335 | actual_sparsity = check_sparsity(model) 336 | print(f"actual sparsity {actual_sparsity}") 337 | 338 | print(f"Eval only mode") 339 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 340 | val_acc = test_stats["acc1"] 341 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 342 | return 343 | 344 | if __name__ == '__main__': 345 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()]) 346 | args = parser.parse_args() 347 | if args.output_dir: 348 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 349 | main(args) -------------------------------------------------------------------------------- /image_classifiers/models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from timm.models.layers import trunc_normal_, DropPath 12 | from timm.models.registry import register_model 13 | 14 | class Block(nn.Module): 15 | r""" ConvNeXt Block. There are two equivalent implementations: 16 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 17 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 18 | We use (2) as we find it slightly faster in PyTorch 19 | 20 | Args: 21 | dim (int): Number of input channels. 22 | drop_path (float): Stochastic depth rate. Default: 0.0 23 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
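drop_rate (float): Dropout rate for the nn.Dropout applied inside the block. Default: 0.0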
24 | """ 25 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, drop_rate=0.): 26 | super().__init__() 27 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 28 | self.norm = LayerNorm(dim, eps=1e-6) 29 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 30 | self.act = nn.GELU() 31 | self.pwconv2 = nn.Linear(4 * dim, dim) 32 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 33 | requires_grad=True) if layer_scale_init_value > 0 else None 34 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 35 | self.dropout = nn.Dropout(drop_rate) 36 | 37 | def forward(self, x): 38 | input = x 39 | x = self.dwconv(x) 40 | x = self.dropout(x) 41 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 42 | x = self.norm(x) 43 | x = self.pwconv1(x) 44 | x = self.act(x) 45 | x = self.dropout(x) 46 | x = self.pwconv2(x) 47 | x = self.dropout(x) 48 | if self.gamma is not None: 49 | x = self.gamma * x 50 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 51 | x = input + self.drop_path(x) 52 | return x 53 | 54 | 55 | class ConvNeXt(nn.Module): 56 | r""" ConvNeXt 57 | A PyTorch impl of : `A ConvNet for the 2020s` - 58 | https://arxiv.org/pdf/2201.03545.pdf 59 | 60 | Args: 61 | in_chans (int): Number of input image channels. Default: 3 62 | num_classes (int): Number of classes for classification head. Default: 1000 63 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 64 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 65 | drop_path_rate (float): Stochastic depth rate. Default: 0. 66 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 67 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 68 | drop_rate (float): Dropout rate 69 | """ 70 | def __init__(self, in_chans=3, num_classes=1000, 71 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 72 | layer_scale_init_value=1e-6, head_init_scale=1., drop_rate=0. 
73 | ): 74 | super().__init__() 75 | 76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 77 | stem = nn.Sequential( 78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 80 | ) 81 | self.downsample_layers.append(stem) 82 | for i in range(3): 83 | downsample_layer = nn.Sequential( 84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 85 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 86 | ) 87 | self.downsample_layers.append(downsample_layer) 88 | 89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 90 | self.depths = depths 91 | self.drop_path = drop_path_rate 92 | self.drop_rate = drop_rate 93 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 94 | cur = 0 95 | for i in range(4): 96 | stage = nn.Sequential( 97 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 98 | layer_scale_init_value=layer_scale_init_value, drop_rate=drop_rate) for j in range(depths[i])] 99 | ) 100 | self.stages.append(stage) 101 | cur += depths[i] 102 | 103 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 104 | self.head = nn.Linear(dims[-1], num_classes) 105 | 106 | self.apply(self._init_weights) 107 | self.head.weight.data.mul_(head_init_scale) 108 | self.head.bias.data.mul_(head_init_scale) 109 | 110 | def _init_weights(self, m): 111 | if isinstance(m, (nn.Conv2d, nn.Linear)): 112 | trunc_normal_(m.weight, std=.02) 113 | nn.init.constant_(m.bias, 0) 114 | 115 | def forward_features(self, x): 116 | for i in range(4): 117 | x = self.downsample_layers[i](x) 118 | x = self.stages[i](x) 119 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 120 | 121 | def forward(self, x): 122 | x = self.forward_features(x) 123 | x = self.head(x) 124 | return x 125 | 126 | def update_drop_path(self, drop_path_rate): 127 | self.drop_path = drop_path_rate 128 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 129 | cur = 0 130 | for i in range(4): 131 | for j in range(self.depths[i]): 132 | self.stages[i][j].drop_path.drop_prob = dp_rates[cur + j] 133 | cur += self.depths[i] 134 | 135 | class LayerNorm(nn.Module): 136 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 137 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 138 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 139 | with shape (batch_size, channels, height, width). 
140 | """ 141 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 142 | super().__init__() 143 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 144 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 145 | self.eps = eps 146 | self.data_format = data_format 147 | if self.data_format not in ["channels_last", "channels_first"]: 148 | raise NotImplementedError 149 | self.normalized_shape = (normalized_shape, ) 150 | 151 | def forward(self, x): 152 | if self.data_format == "channels_last": 153 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 154 | elif self.data_format == "channels_first": 155 | u = x.mean(1, keepdim=True) 156 | s = (x - u).pow(2).mean(1, keepdim=True) 157 | x = (x - u) / torch.sqrt(s + self.eps) 158 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 159 | return x 160 | 161 | model_urls = { 162 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 163 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 164 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 165 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 166 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 167 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 168 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 169 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 170 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 171 | } 172 | 173 | @register_model 174 | def convnext_atto(pretrained=False, **kwargs): 175 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 176 | return model 177 | 178 | @register_model 179 | def convnext_mini(pretrained=False, **kwargs): 180 | model = ConvNeXt(depths=[2, 2, 4, 2], dims=[48, 96, 192, 384], **kwargs) 181 | return model 182 | 183 | @register_model 184 | def convnext_femto(pretrained=False, **kwargs): 185 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 186 | return model 187 | 188 | @register_model 189 | def convnext_pico(pretrained=False, **kwargs): 190 | # timm pico variant 191 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 192 | return model 193 | 194 | @register_model 195 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 196 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 197 | return model 198 | 199 | @register_model 200 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 201 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 202 | return model 203 | 204 | @register_model 205 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 206 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 207 | return model 208 | 209 | @register_model 210 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 211 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 212 | return model 213 | 214 | @register_model 215 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 216 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 217 | return model 218 | 
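# Usage sketch (assumption: the ConvNeXt-B checkpoint fetched by download_weights.sh sits
# at model_weights/convnext/convnext_base_1k_224_ema.pth, mirroring how main.py loads it):
#
#     model = convnext_base(num_classes=1000)
#     ckpt = torch.load("model_weights/convnext/convnext_base_1k_224_ema.pth", map_location="cpu")
#     model.load_state_dict(ckpt["model"])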
-------------------------------------------------------------------------------- /image_classifiers/models/deit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | from timm.models.vision_transformer import VisionTransformer, _cfg 8 | from timm.models.registry import register_model 9 | from timm.models.layers import trunc_normal_ 10 | 11 | 12 | __all__ = [ 13 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 14 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', 15 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', 16 | 'deit_base_distilled_patch16_384', 17 | ] 18 | 19 | 20 | class DistilledVisionTransformer(VisionTransformer): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 24 | num_patches = self.patch_embed.num_patches 25 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 26 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 27 | 28 | trunc_normal_(self.dist_token, std=.02) 29 | trunc_normal_(self.pos_embed, std=.02) 30 | self.head_dist.apply(self._init_weights) 31 | 32 | def forward_features(self, x): 33 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 34 | # with slight modifications to add the dist_token 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | dist_token = self.dist_token.expand(B, -1, -1) 40 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 41 | 42 | x = x + self.pos_embed 43 | x = self.pos_drop(x) 44 | 45 | for blk in self.blocks: 46 | x = blk(x) 47 | 48 | x = self.norm(x) 49 | return x[:, 0], x[:, 1] 50 | 51 | def forward(self, x): 52 | x, x_dist = self.forward_features(x) 53 | x = self.head(x) 54 | x_dist = self.head_dist(x_dist) 55 | if self.training: 56 | return x, x_dist 57 | else: 58 | # during inference, return the average of both classifier predictions 59 | return (x + x_dist) / 2 60 | 61 | 62 | @register_model 63 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 64 | model = VisionTransformer( 65 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 66 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 67 | model.default_cfg = _cfg() 68 | if pretrained: 69 | checkpoint = torch.hub.load_state_dict_from_url( 70 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 71 | map_location="cpu", check_hash=True 72 | ) 73 | model.load_state_dict(checkpoint["model"]) 74 | return model 75 | 76 | 77 | @register_model 78 | def deit_small_patch16_224(pretrained=False, **kwargs): 79 | model = VisionTransformer( 80 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 81 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 82 | model.default_cfg = _cfg() 83 | if pretrained: 84 | checkpoint = torch.hub.load_state_dict_from_url( 85 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 86 | map_location="cpu", check_hash=True 87 | ) 88 | model.load_state_dict(checkpoint["model"]) 89 | return model 90 | 91 | 92 | @register_model 93 | def 
deit_base_patch16_224(pretrained=False, **kwargs): 94 | model = VisionTransformer( 95 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 96 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 97 | model.default_cfg = _cfg() 98 | if pretrained: 99 | checkpoint = torch.hub.load_state_dict_from_url( 100 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 101 | map_location="cpu", check_hash=True 102 | ) 103 | model.load_state_dict(checkpoint["model"]) 104 | return model 105 | 106 | 107 | @register_model 108 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 109 | model = DistilledVisionTransformer( 110 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 111 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 112 | model.default_cfg = _cfg() 113 | if pretrained: 114 | checkpoint = torch.hub.load_state_dict_from_url( 115 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", 116 | map_location="cpu", check_hash=True 117 | ) 118 | model.load_state_dict(checkpoint["model"]) 119 | return model 120 | 121 | 122 | @register_model 123 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 124 | model = DistilledVisionTransformer( 125 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 126 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 127 | model.default_cfg = _cfg() 128 | if pretrained: 129 | checkpoint = torch.hub.load_state_dict_from_url( 130 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", 131 | map_location="cpu", check_hash=True 132 | ) 133 | model.load_state_dict(checkpoint["model"]) 134 | return model 135 | 136 | 137 | @register_model 138 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 139 | model = DistilledVisionTransformer( 140 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 141 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 142 | model.default_cfg = _cfg() 143 | if pretrained: 144 | checkpoint = torch.hub.load_state_dict_from_url( 145 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth", 146 | map_location="cpu", check_hash=True 147 | ) 148 | model.load_state_dict(checkpoint["model"]) 149 | return model 150 | 151 | 152 | @register_model 153 | def deit_base_patch16_384(pretrained=False, **kwargs): 154 | model = VisionTransformer( 155 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 156 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 157 | model.default_cfg = _cfg() 158 | if pretrained: 159 | checkpoint = torch.hub.load_state_dict_from_url( 160 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth", 161 | map_location="cpu", check_hash=True 162 | ) 163 | model.load_state_dict(checkpoint["model"]) 164 | return model 165 | 166 | 167 | @register_model 168 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 169 | model = DistilledVisionTransformer( 170 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 171 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 172 | model.default_cfg = _cfg() 173 | if pretrained: 174 | checkpoint = torch.hub.load_state_dict_from_url( 175 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", 176 | map_location="cpu", check_hash=True 177 | ) 178 | 
model.load_state_dict(checkpoint["model"]) 179 | return model -------------------------------------------------------------------------------- /image_classifiers/models/mlp_mixer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import math 9 | from copy import deepcopy 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply 17 | from timm.models.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple 18 | from timm.models.registry import register_model 19 | 20 | 21 | def _cfg(url='', **kwargs): 22 | return { 23 | 'url': url, 24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 25 | 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True, 26 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 27 | 'first_conv': 'stem.proj', 'classifier': 'head', 28 | **kwargs 29 | } 30 | 31 | 32 | default_cfgs = dict( 33 | mixer_s32_224=_cfg(), 34 | mixer_s16_224=_cfg(), 35 | mixer_b32_224=_cfg(), 36 | mixer_b16_224=_cfg( 37 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth', 38 | ), 39 | mixer_b16_224_in21k=_cfg( 40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth', 41 | num_classes=21843 42 | ), 43 | mixer_l32_224=_cfg(), 44 | mixer_l16_224=_cfg( 45 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth', 46 | ), 47 | mixer_l16_224_in21k=_cfg( 48 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', 49 | num_classes=21843 50 | ), 51 | 52 | # Mixer ImageNet-21K-P pretraining 53 | mixer_b16_224_miil_in21k=_cfg( 54 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth', 55 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 56 | ), 57 | mixer_b16_224_miil=_cfg( 58 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth', 59 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 60 | ), 61 | 62 | gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 63 | gmixer_24_224=_cfg( 64 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth', 65 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 66 | 67 | resmlp_12_224=_cfg( 68 | url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth', 69 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 70 | resmlp_24_224=_cfg( 71 | url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth', 72 | #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth', 73 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 74 | resmlp_36_224=_cfg( 75 | url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth', 76 | 
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 77 | resmlp_big_24_224=_cfg( 78 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth', 79 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 80 | 81 | resmlp_12_distilled_224=_cfg( 82 | url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth', 83 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 84 | resmlp_24_distilled_224=_cfg( 85 | url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth', 86 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 87 | resmlp_36_distilled_224=_cfg( 88 | url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth', 89 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 90 | resmlp_big_24_distilled_224=_cfg( 91 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth', 92 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 93 | 94 | resmlp_big_24_224_in22ft1k=_cfg( 95 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth', 96 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 97 | 98 | gmlp_ti16_224=_cfg(), 99 | gmlp_s16_224=_cfg( 100 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth', 101 | ), 102 | gmlp_b16_224=_cfg(), 103 | ) 104 | 105 | 106 | class MixerBlock(nn.Module): 107 | """ Residual Block w/ token mixing and channel MLPs 108 | Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 109 | """ 110 | def __init__( 111 | self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp, 112 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): 113 | super().__init__() 114 | tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] 115 | self.norm1 = norm_layer(dim) 116 | self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) 117 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 118 | self.norm2 = norm_layer(dim) 119 | self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) 120 | 121 | def forward(self, x): 122 | x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) 123 | x = x + self.drop_path(self.mlp_channels(self.norm2(x))) 124 | return x 125 | 126 | 127 | class MlpMixer(nn.Module): 128 | 129 | def __init__( 130 | self, 131 | num_classes=1000, 132 | img_size=224, 133 | in_chans=3, 134 | patch_size=16, 135 | num_blocks=8, 136 | embed_dim=512, 137 | mlp_ratio=(0.5, 4.0), 138 | block_layer=MixerBlock, 139 | mlp_layer=Mlp, 140 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 141 | act_layer=nn.GELU, 142 | drop_rate=0., 143 | drop_path_rate=0., 144 | nlhb=False, 145 | stem_norm=False, 146 | ): 147 | super().__init__() 148 | self.num_classes = num_classes 149 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 150 | 151 | self.stem = PatchEmbed( 152 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, 153 | embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None) 154 | # FIXME drop_path (stochastic depth scaling rule or all the same?) 
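        # Editor's note: the schedule built just below answers the FIXME above --
        # torch.linspace(0, drop_path_rate, num_blocks) applies the stochastic-depth
        # decay rule (drop probability grows linearly from 0 in the first block to
        # drop_path_rate in the last), and update_drop_path() rebuilds the same schedule.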
155 | self.drop_path = drop_path_rate 156 | self.drop_rate = drop_rate 157 | self.num_blocks = num_blocks 158 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] 159 | self.blocks = nn.Sequential(*[ 160 | block_layer( 161 | embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, 162 | act_layer=act_layer, drop=drop_rate, drop_path=dpr[i]) 163 | for i in range(num_blocks)]) 164 | self.norm = norm_layer(embed_dim) 165 | self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 166 | 167 | self.init_weights(nlhb=nlhb) 168 | 169 | def init_weights(self, nlhb=False): 170 | head_bias = -math.log(self.num_classes) if nlhb else 0. 171 | named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first 172 | 173 | def get_classifier(self): 174 | return self.head 175 | 176 | def reset_classifier(self, num_classes, global_pool=''): 177 | self.num_classes = num_classes 178 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 179 | 180 | def forward_features(self, x): 181 | x = self.stem(x) 182 | x = self.blocks(x) 183 | x = self.norm(x) 184 | x = x.mean(dim=1) 185 | return x 186 | 187 | def forward(self, x): 188 | x = self.forward_features(x) 189 | x = self.head(x) 190 | return x 191 | 192 | def update_drop_path(self, drop_path_rate): 193 | self.drop_path = drop_path_rate 194 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)] 195 | cur = 0 196 | for block in self.blocks: 197 | block.drop_path.drop_prob = dpr[cur] 198 | cur += 1 199 | assert cur == self.num_blocks 200 | 201 | def update_dropout(self, drop_rate): 202 | self.drop_rate = drop_rate 203 | for module in self.modules(): 204 | if isinstance(module, nn.Dropout): 205 | module.p = drop_rate 206 | 207 | 208 | def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): 209 | """ Mixer weight initialization (trying to match Flax defaults) 210 | """ 211 | if isinstance(module, nn.Linear): 212 | if name.startswith('head'): 213 | nn.init.zeros_(module.weight) 214 | nn.init.constant_(module.bias, head_bias) 215 | else: 216 | if flax: 217 | # Flax defaults 218 | lecun_normal_(module.weight) 219 | if module.bias is not None: 220 | nn.init.zeros_(module.bias) 221 | else: 222 | # like MLP init in vit (my original init) 223 | nn.init.xavier_uniform_(module.weight) 224 | if module.bias is not None: 225 | if 'mlp' in name: 226 | nn.init.normal_(module.bias, std=1e-6) 227 | else: 228 | nn.init.zeros_(module.bias) 229 | elif isinstance(module, nn.Conv2d): 230 | lecun_normal_(module.weight) 231 | if module.bias is not None: 232 | nn.init.zeros_(module.bias) 233 | elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): 234 | nn.init.ones_(module.weight) 235 | nn.init.zeros_(module.bias) 236 | elif hasattr(module, 'init_weights'): 237 | # NOTE if a parent module contains init_weights method, it can override the init of the 238 | # child modules as this will be called in depth-first order. 
239 | module.init_weights() 240 | 241 | 242 | def checkpoint_filter_fn(state_dict, model): 243 | """ Remap checkpoints if needed """ 244 | if 'patch_embed.proj.weight' in state_dict: 245 | # Remap FB ResMlp models -> timm 246 | out_dict = {} 247 | for k, v in state_dict.items(): 248 | k = k.replace('patch_embed.', 'stem.') 249 | k = k.replace('attn.', 'linear_tokens.') 250 | k = k.replace('mlp.', 'mlp_channels.') 251 | k = k.replace('gamma_', 'ls') 252 | if k.endswith('.alpha') or k.endswith('.beta'): 253 | v = v.reshape(1, 1, -1) 254 | out_dict[k] = v 255 | return out_dict 256 | return state_dict 257 | 258 | 259 | def _create_mixer(variant, pretrained=False, **kwargs): 260 | if kwargs.get('features_only', None): 261 | raise RuntimeError('features_only not implemented for MLP-Mixer models.') 262 | 263 | model = build_model_with_cfg( 264 | MlpMixer, variant, pretrained, 265 | default_cfg=default_cfgs[variant], 266 | pretrained_filter_fn=checkpoint_filter_fn, 267 | **kwargs) 268 | return model 269 | 270 | @register_model 271 | def mixer_t32(pretrained=False, **kwargs): 272 | """ Mixer-S/32 224x224 273 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 274 | """ 275 | model = MlpMixer(patch_size=32, num_blocks=8, embed_dim=256, **kwargs) 276 | return model 277 | 278 | 279 | @register_model 280 | def mixer_s32(pretrained=False, **kwargs): 281 | """ Mixer-S/32 224x224 282 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 283 | """ 284 | model = MlpMixer(patch_size=32, num_blocks=8, embed_dim=512, **kwargs) 285 | return model 286 | 287 | 288 | @register_model 289 | def mixer_s16(pretrained=False, **kwargs): 290 | """ Mixer-S/16 224x224 291 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 292 | """ 293 | model = MlpMixer(patch_size=16, num_blocks=8, embed_dim=512, **kwargs) 294 | return model 295 | 296 | 297 | @register_model 298 | def mixer_b32(pretrained=False, **kwargs): 299 | """ Mixer-B/32 224x224 300 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 301 | """ 302 | model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs) 303 | model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) 304 | return model 305 | 306 | 307 | @register_model 308 | def mixer_b16(pretrained=False, **kwargs): 309 | """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. 310 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 311 | """ 312 | model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) 313 | model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) 314 | return model 315 | 316 | -------------------------------------------------------------------------------- /image_classifiers/models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | import torch.nn as nn 10 | from functools import partial 11 | 12 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.models.helpers import load_pretrained 14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 15 | from timm.models.resnet import resnet26d, resnet50d 16 | from timm.models.registry import register_model 17 | 18 | 19 | def _cfg(url='', **kwargs): 20 | return { 21 | 'url': url, 22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 23 | 'crop_pct': .9, 'interpolation': 'bicubic', 24 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 25 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 26 | **kwargs 27 | } 28 | 29 | 30 | default_cfgs = { 31 | # patch models 32 | 'vit_small_patch16_224': _cfg( 33 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 34 | ), 35 | 'vit_base_patch16_224': _cfg( 36 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 37 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 38 | ), 39 | 'vit_base_patch16_384': _cfg( 40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', 41 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 42 | 'vit_base_patch32_384': _cfg( 43 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', 44 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 45 | 'vit_large_patch16_224': _cfg( 46 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 47 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 48 | 'vit_large_patch16_384': _cfg( 49 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', 50 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 51 | 'vit_large_patch32_384': _cfg( 52 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 53 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 54 | 'vit_huge_patch16_224': _cfg(), 55 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), 56 | # hybrid models 57 | 'vit_small_resnet26d_224': _cfg(), 58 | 'vit_small_resnet50d_s3_224': _cfg(), 59 | 'vit_base_resnet26d_224': _cfg(), 60 | 'vit_base_resnet50d_224': _cfg(), 61 | } 62 | 63 | 64 | class Mlp(nn.Module): 65 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 66 | super().__init__() 67 | out_features = out_features or in_features 68 | hidden_features = hidden_features or in_features 69 | self.fc1 = nn.Linear(in_features, hidden_features) 70 | self.act = act_layer() 71 | self.fc2 = nn.Linear(hidden_features, out_features) 72 | self.drop = nn.Dropout(drop) 73 | 74 | def forward(self, x): 75 | x = self.fc1(x) 76 | x = self.act(x) 77 | x = self.drop(x) 78 | x = self.fc2(x) 79 | x = self.drop(x) 80 | return x 81 | 82 | 83 | class Attention(nn.Module): 84 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 85 | super().__init__() 86 | self.num_heads = num_heads 87 | head_dim = dim // num_heads 88 | # NOTE scale factor was wrong in my original version, 
can set manually to be compat with prev weights 89 | self.scale = qk_scale or head_dim ** -0.5 90 | 91 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 92 | self.attn_drop = nn.Dropout(attn_drop) 93 | self.proj = nn.Linear(dim, dim) 94 | self.proj_drop = nn.Dropout(proj_drop) 95 | 96 | def forward(self, x): 97 | B, N, C = x.shape 98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 100 | 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | 105 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 106 | x = self.proj(x) 107 | x = self.proj_drop(x) 108 | return x 109 | 110 | 111 | class Block(nn.Module): 112 | 113 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 114 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 115 | super().__init__() 116 | self.norm1 = norm_layer(dim) 117 | self.attn = Attention( 118 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 119 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 120 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 121 | self.norm2 = norm_layer(dim) 122 | mlp_hidden_dim = int(dim * mlp_ratio) 123 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 124 | 125 | def forward(self, x): 126 | x = x + self.drop_path(self.attn(self.norm1(x))) 127 | x = x + self.drop_path(self.mlp(self.norm2(x))) 128 | return x 129 | 130 | 131 | class PatchEmbed(nn.Module): 132 | """ Image to Patch Embedding 133 | """ 134 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 135 | super().__init__() 136 | img_size = to_2tuple(img_size) 137 | patch_size = to_2tuple(patch_size) 138 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 139 | self.img_size = img_size 140 | self.patch_size = patch_size 141 | self.num_patches = num_patches 142 | 143 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 144 | 145 | def forward(self, x): 146 | B, C, H, W = x.shape 147 | # FIXME look at relaxing size constraints 148 | assert H == self.img_size[0] and W == self.img_size[1], \ 149 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 150 | x = self.proj(x).flatten(2).transpose(1, 2) 151 | return x 152 | 153 | 154 | class HybridEmbed(nn.Module): 155 | """ CNN Feature Map Embedding 156 | Extract feature map from CNN, flatten, project to embedding dim. 157 | """ 158 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | assert isinstance(backbone, nn.Module) 161 | img_size = to_2tuple(img_size) 162 | self.img_size = img_size 163 | self.backbone = backbone 164 | if feature_size is None: 165 | with torch.no_grad(): 166 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 167 | # map for all networks, the feature metadata has reliable channel and stride info, but using 168 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
169 | training = backbone.training 170 | if training: 171 | backbone.eval() 172 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 173 | feature_size = o.shape[-2:] 174 | feature_dim = o.shape[1] 175 | backbone.train(training) 176 | else: 177 | feature_size = to_2tuple(feature_size) 178 | feature_dim = self.backbone.feature_info.channels()[-1] 179 | self.num_patches = feature_size[0] * feature_size[1] 180 | self.proj = nn.Linear(feature_dim, embed_dim) 181 | 182 | def forward(self, x): 183 | x = self.backbone(x)[-1] 184 | x = x.flatten(2).transpose(1, 2) 185 | x = self.proj(x) 186 | return x 187 | 188 | 189 | class VisionTransformer(nn.Module): 190 | """ Vision Transformer with support for patch or hybrid CNN input stage 191 | """ 192 | 193 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 194 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 195 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): 196 | super().__init__() 197 | self.num_classes = num_classes 198 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 199 | # I add these two lines 200 | self.drop_rate=drop_rate 201 | attn_drop_rate=drop_rate 202 | if hybrid_backbone is not None: 203 | self.patch_embed = HybridEmbed( 204 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 205 | else: 206 | self.patch_embed = PatchEmbed( 207 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 208 | num_patches = self.patch_embed.num_patches 209 | 210 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 211 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 212 | self.pos_drop = nn.Dropout(p=drop_rate) 213 | self.depth = depth 214 | 215 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 216 | self.blocks = nn.ModuleList([ 217 | Block( 218 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 219 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 220 | for i in range(depth)]) 221 | self.norm = norm_layer(embed_dim) 222 | 223 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 224 | #self.repr = nn.Linear(embed_dim, representation_size) 225 | #self.repr_act = nn.Tanh() 226 | 227 | # Classifier head 228 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 229 | 230 | trunc_normal_(self.pos_embed, std=.02) 231 | trunc_normal_(self.cls_token, std=.02) 232 | self.apply(self._init_weights) 233 | 234 | def _init_weights(self, m): 235 | if isinstance(m, nn.Linear): 236 | trunc_normal_(m.weight, std=.02) 237 | if isinstance(m, nn.Linear) and m.bias is not None: 238 | nn.init.constant_(m.bias, 0) 239 | elif isinstance(m, nn.LayerNorm): 240 | nn.init.constant_(m.bias, 0) 241 | nn.init.constant_(m.weight, 1.0) 242 | 243 | @torch.jit.ignore 244 | def no_weight_decay(self): 245 | return {'pos_embed', 'cls_token'} 246 | 247 | def get_classifier(self): 248 | return self.head 249 | 250 | def reset_classifier(self, num_classes, global_pool=''): 251 | self.num_classes = num_classes 252 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 253 | 254 | def forward_features(self, x): 255 | B = x.shape[0] 256 | x = self.patch_embed(x) 257 | 258 | cls_tokens = 
self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 259 | x = torch.cat((cls_tokens, x), dim=1) 260 | x = x + self.pos_embed 261 | x = self.pos_drop(x) 262 | 263 | for blk in self.blocks: 264 | x = blk(x) 265 | 266 | x = self.norm(x) 267 | return x[:, 0] 268 | 269 | def forward(self, x): 270 | x = self.forward_features(x) 271 | x = self.head(x) 272 | return x 273 | 274 | def update_drop_path(self, drop_path_rate): 275 | self.drop_path = drop_path_rate 276 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, self.depth)] 277 | for i in range(self.depth): 278 | self.blocks[i].drop_path.drop_prob = dp_rates[i] 279 | 280 | def update_dropout(self, drop_rate): 281 | self.drop_rate = drop_rate 282 | for module in self.modules(): 283 | if isinstance(module, nn.Dropout): 284 | module.p = drop_rate 285 | 286 | 287 | def _conv_filter(state_dict, patch_size=16): 288 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 289 | out_dict = {} 290 | for k, v in state_dict.items(): 291 | if 'patch_embed.proj.weight' in k: 292 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 293 | out_dict[k] = v 294 | return out_dict 295 | 296 | @register_model 297 | def vit_tiny(pretrained=False, **kwargs): 298 | model = VisionTransformer( 299 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 300 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 301 | return model 302 | 303 | @register_model 304 | def vit_small(pretrained=False, **kwargs): 305 | model = VisionTransformer( 306 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 307 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 308 | return model 309 | 310 | @register_model 311 | def vit_base(pretrained=False, **kwargs): 312 | model = VisionTransformer( 313 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 314 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 315 | return model 316 | 317 | @register_model 318 | def vit_large(pretrained=False, **kwargs): 319 | model = VisionTransformer( 320 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 321 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 322 | return model -------------------------------------------------------------------------------- /image_classifiers/optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | from torch import optim as optim 10 | 11 | from timm.optim.adafactor import Adafactor 12 | from timm.optim.adahessian import Adahessian 13 | from timm.optim.adamp import AdamP 14 | from timm.optim.lookahead import Lookahead 15 | from timm.optim.nadam import Nadam 16 | # from timm.optim.novograd import NovoGrad 17 | from timm.optim.nvnovograd import NvNovoGrad 18 | from timm.optim.radam import RAdam 19 | from timm.optim.rmsprop_tf import RMSpropTF 20 | from timm.optim.sgdp import SGDP 21 | 22 | import json 23 | 24 | try: 25 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 26 | has_apex = True 27 | except ImportError: 28 | has_apex = False 29 | 30 | 31 | def get_num_layer_for_convnext(var_name): 32 | """ 33 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 34 | consecutive blocks, including possible neighboring downsample layers; 35 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 36 | """ 37 | num_max_layer = 12 38 | if var_name.startswith("downsample_layers"): 39 | stage_id = int(var_name.split('.')[1]) 40 | if stage_id == 0: 41 | layer_id = 0 42 | elif stage_id == 1 or stage_id == 2: 43 | layer_id = stage_id + 1 44 | elif stage_id == 3: 45 | layer_id = 12 46 | return layer_id 47 | 48 | elif var_name.startswith("stages"): 49 | stage_id = int(var_name.split('.')[1]) 50 | block_id = int(var_name.split('.')[2]) 51 | if stage_id == 0 or stage_id == 1: 52 | layer_id = stage_id + 1 53 | elif stage_id == 2: 54 | layer_id = 3 + block_id // 3 55 | elif stage_id == 3: 56 | layer_id = 12 57 | return layer_id 58 | else: 59 | return num_max_layer + 1 60 | 61 | class LayerDecayValueAssigner(object): 62 | def __init__(self, values): 63 | self.values = values 64 | 65 | def get_scale(self, layer_id): 66 | return self.values[layer_id] 67 | 68 | def get_layer_id(self, var_name): 69 | return get_num_layer_for_convnext(var_name) 70 | 71 | 72 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 73 | parameter_group_names = {} 74 | parameter_group_vars = {} 75 | 76 | for name, param in model.named_parameters(): 77 | if not param.requires_grad: 78 | continue # frozen weights 79 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 80 | group_name = "no_decay" 81 | this_weight_decay = 0. 82 | else: 83 | group_name = "decay" 84 | this_weight_decay = weight_decay 85 | if get_num_layer is not None: 86 | layer_id = get_num_layer(name) 87 | group_name = "layer_%d_%s" % (layer_id, group_name) 88 | else: 89 | layer_id = None 90 | 91 | if group_name not in parameter_group_names: 92 | if get_layer_scale is not None: 93 | scale = get_layer_scale(layer_id) 94 | else: 95 | scale = 1. 
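            # Editor's note: each (layer id, decay/no_decay) combination becomes its own
            # parameter group; the group's "lr_scale" entry (from get_layer_scale) lets the
            # training loop apply layer-wise learning-rate decay per group.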
96 | 97 | parameter_group_names[group_name] = { 98 | "weight_decay": this_weight_decay, 99 | "params": [], 100 | "lr_scale": scale 101 | } 102 | parameter_group_vars[group_name] = { 103 | "weight_decay": this_weight_decay, 104 | "params": [], 105 | "lr_scale": scale 106 | } 107 | 108 | parameter_group_vars[group_name]["params"].append(param) 109 | parameter_group_names[group_name]["params"].append(name) 110 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 111 | return list(parameter_group_vars.values()) 112 | 113 | 114 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 115 | opt_lower = args.opt.lower() 116 | weight_decay = args.weight_decay 117 | # if weight_decay and filter_bias_and_bn: 118 | if filter_bias_and_bn: 119 | skip = {} 120 | if skip_list is not None: 121 | skip = skip_list 122 | elif hasattr(model, 'no_weight_decay'): 123 | skip = model.no_weight_decay() 124 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 125 | weight_decay = 0. 126 | else: 127 | parameters = model.parameters() 128 | 129 | if 'fused' in opt_lower: 130 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 131 | 132 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 133 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 134 | opt_args['eps'] = args.opt_eps 135 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 136 | opt_args['betas'] = args.opt_betas 137 | 138 | opt_split = opt_lower.split('_') 139 | opt_lower = opt_split[-1] 140 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 141 | opt_args.pop('eps', None) 142 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 143 | elif opt_lower == 'momentum': 144 | opt_args.pop('eps', None) 145 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 146 | elif opt_lower == 'adam': 147 | optimizer = optim.Adam(parameters, **opt_args) 148 | elif opt_lower == 'adamw': 149 | optimizer = optim.AdamW(parameters, **opt_args) 150 | elif opt_lower == 'nadam': 151 | optimizer = Nadam(parameters, **opt_args) 152 | elif opt_lower == 'radam': 153 | optimizer = RAdam(parameters, **opt_args) 154 | elif opt_lower == 'adamp': 155 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 156 | elif opt_lower == 'sgdp': 157 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 158 | elif opt_lower == 'adadelta': 159 | optimizer = optim.Adadelta(parameters, **opt_args) 160 | elif opt_lower == 'adafactor': 161 | if not args.lr: 162 | opt_args['lr'] = None 163 | optimizer = Adafactor(parameters, **opt_args) 164 | elif opt_lower == 'adahessian': 165 | optimizer = Adahessian(parameters, **opt_args) 166 | elif opt_lower == 'rmsprop': 167 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 168 | elif opt_lower == 'rmsproptf': 169 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 170 | elif opt_lower == 'novograd': 171 | optimizer = NovoGrad(parameters, **opt_args) 172 | elif opt_lower == 'nvnovograd': 173 | optimizer = NvNovoGrad(parameters, **opt_args) 174 | elif opt_lower == 'fusedsgd': 175 | opt_args.pop('eps', None) 176 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 177 | elif opt_lower == 'fusedmomentum': 178 | opt_args.pop('eps', None) 179 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 180 | elif opt_lower == 'fusedadam': 181 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 182 | elif opt_lower == 'fusedadamw': 183 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 184 | elif opt_lower == 'fusedlamb': 185 | optimizer = FusedLAMB(parameters, **opt_args) 186 | elif opt_lower == 'fusednovograd': 187 | opt_args.setdefault('betas', (0.95, 0.98)) 188 | optimizer = FusedNovoGrad(parameters, **opt_args) 189 | else: 190 | assert False and "Invalid optimizer" 191 | 192 | if len(opt_split) > 1: 193 | if opt_split[0] == 'lookahead': 194 | optimizer = Lookahead(optimizer) 195 | 196 | return optimizer 197 | -------------------------------------------------------------------------------- /image_classifiers/prune_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layerwrapper import WrappedLayer 4 | 5 | def find_layers(module, layers=[nn.Linear], name=''): 6 | if type(module) in layers: 7 | return {name: module} 8 | res = {} 9 | for name1, child in module.named_children(): 10 | res.update(find_layers( 11 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 12 | )) 13 | return res 14 | 15 | def check_sparsity(model): 16 | subset = find_layers(model, layers=[nn.Linear]) 17 | zero_cnt = 0 18 | fc_params = 0 19 | for name in subset: 20 | W = subset[name].weight.data 21 | if W.shape[0] == 1000: 22 | continue 23 | zero_cnt += (W==0).sum().item() 24 | fc_params += W.numel() 25 | return float(zero_cnt) / fc_params 26 | 27 | def compute_mask(W_metric, prune_granularity, sparsity): 28 | if prune_granularity == "layer": 29 | thres = torch.sort(W_metric.flatten().cuda())[0][int(W_metric.numel() * sparsity)].cpu() 30 | W_mask = (W_metric <= thres) 31 | return W_mask 32 | elif prune_granularity == "row": 33 | W_mask = (torch.zeros_like(W_metric)==1) 34 | sort_res = torch.sort(W_metric, dim=-1, stable=True) 35 | 36 | indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)] 37 | W_mask.scatter_(1, indices, True) 38 | return W_mask 39 | 40 | def prune_deit(args, model, calib_data, device): 41 | inps = calib_data 42 | bs = inps.shape[0] 43 | require_forward = (args.prune_metric in ["wanda"]) 44 | 45 | metric_stats = [] 46 | for blk in model.blocks: 47 | subset = find_layers(blk) 48 | res_per_layer = {} 49 | for name in subset: 50 | res_per_layer[name] = torch.abs(subset[name].weight.data) 51 | metric_stats.append(res_per_layer) 52 | 53 | thresh = None 54 | ##################################### 55 | inps = model.patch_embed(inps) 56 | 57 | cls_tokens = model.cls_token.expand(bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 58 | dist_token = model.dist_token.expand(bs, -1, -1) 59 | inps = torch.cat((cls_tokens, dist_token, inps), dim=1) 60 | 61 | inps = inps + model.pos_embed 62 | inps = model.pos_drop(inps) 63 | 64 | for block_id, blk in enumerate(model.blocks): 65 | subset = find_layers(blk) 66 | 67 | if require_forward: 68 | wrapped_layers = {} 69 | for name in subset: 70 | wrapped_layers[name] = WrappedLayer(subset[name]) 71 | 72 | def add_batch(name): 73 | def tmp(_, inp, out): 74 | wrapped_layers[name].add_batch(inp[0].data, out.data) 75 | return tmp 76 | 77 | handles = [] 78 | for name in wrapped_layers: 79 | handles.append(subset[name].register_forward_hook(add_batch(name))) 80 | 81 | if bs > 256: 82 | tmp_res = [] 83 | for i1 in range(0, bs, 256): 84 | j1 = min(i1+256, bs) 
85 | tmp_res.append(blk(inps[i1:j1])) 86 | inps = torch.cat(tmp_res, dim=0) 87 | else: 88 | inps = blk(inps) 89 | 90 | for h in handles: 91 | h.remove() 92 | 93 | ################# pruning ################### 94 | for name in subset: 95 | if args.prune_metric == "wanda": 96 | metric_stats[block_id][name] *= torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1))) 97 | 98 | W_mask = compute_mask(metric_stats[block_id][name], args.prune_granularity, args.sparsity) 99 | 100 | subset[name].weight.data[W_mask] = 0 101 | 102 | def prune_vit(args, model, calib_data, device): 103 | inps = calib_data 104 | bs = inps.shape[0] 105 | require_forward = (args.prune_metric in ["wanda"]) 106 | 107 | metric_stats = [] 108 | for blk in model.blocks: 109 | subset = find_layers(blk) 110 | res_per_layer = {} 111 | for name in subset: 112 | res_per_layer[name] = torch.abs(subset[name].weight.data) 113 | metric_stats.append(res_per_layer) 114 | 115 | thresh = None 116 | ##################################### 117 | inps = model.patch_embed(inps) 118 | 119 | cls_tokens = model.cls_token.expand(bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 120 | inps = torch.cat((cls_tokens, inps), dim=1) 121 | inps = inps + model.pos_embed 122 | inps = model.pos_drop(inps) 123 | 124 | for block_id, blk in enumerate(model.blocks): 125 | print(f"block {block_id}") 126 | subset = find_layers(blk) 127 | 128 | if require_forward: 129 | wrapped_layers = {} 130 | for name in subset: 131 | wrapped_layers[name] = WrappedLayer(subset[name]) 132 | 133 | def add_batch(name): 134 | def tmp(_, inp, out): 135 | wrapped_layers[name].add_batch(inp[0].data, out.data) 136 | return tmp 137 | 138 | handles = [] 139 | for name in wrapped_layers: 140 | handles.append(subset[name].register_forward_hook(add_batch(name))) 141 | 142 | if bs > 256: 143 | tmp_res = [] 144 | for i1 in range(0, bs, 256): 145 | j1 = min(i1+256, bs) 146 | tmp_res.append(blk(inps[i1:j1])) 147 | inps = torch.cat(tmp_res, dim=0) 148 | else: 149 | inps = blk(inps) 150 | 151 | for h in handles: 152 | h.remove() 153 | 154 | ################# pruning ################### 155 | for name in subset: 156 | if args.prune_metric == "wanda": 157 | metric_stats[block_id][name] *= torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1))) 158 | 159 | W_mask = compute_mask(metric_stats[block_id][name], args.prune_granularity, args.sparsity) 160 | 161 | subset[name].weight.data[W_mask] = 0 162 | ############################################## 163 | 164 | def prune_convnext(args, model, calib_data, device): 165 | inps = calib_data 166 | bs = inps.shape[0] 167 | require_forward = (args.prune_metric in ["wanda"]) 168 | 169 | ############################################################## 170 | metric_stats = [] 171 | for block_id in range(4): 172 | subset = find_layers(model.stages[block_id]) 173 | res_per_layer = {} 174 | for name in subset: 175 | res_per_layer[name] = torch.abs(subset[name].weight.data) 176 | metric_stats.append(res_per_layer) 177 | ############################################################## 178 | 179 | thresh = None 180 | for block_id in range(4): 181 | print(f"block {block_id}") 182 | subset = find_layers(model.stages[block_id]) 183 | 184 | if require_forward: 185 | layer = model.downsample_layers[block_id] 186 | if bs > 1024: 187 | tmp_res = [] 188 | for i1 in range(0, bs, 512): 189 | j1 = min(i1+512, bs) 190 | tmp_res.append(layer(inps[i1:j1])) 191 | inps = torch.cat(tmp_res, dim=0) 192 | else: 193 | inps = layer(inps) 194 | 195 | wrapped_layers = {} 196 | for 
name in subset: 197 | wrapped_layers[name] = WrappedLayer(subset[name]) 198 | 199 | def add_batch(name): 200 | def tmp(_, inp, out): 201 | wrapped_layers[name].add_batch(inp[0].data, out.data) 202 | return tmp 203 | 204 | handles = [] 205 | for name in wrapped_layers: 206 | handles.append(subset[name].register_forward_hook(add_batch(name))) 207 | layer = model.stages[block_id] 208 | if bs > 1024: 209 | tmp_res = [] 210 | for i1 in range(0, bs, 512): 211 | j1 = min(i1+512, bs) 212 | tmp_res.append(layer(inps[i1:j1])) 213 | inps = torch.cat(tmp_res, dim=0) 214 | else: 215 | inps = layer(inps) 216 | for h in handles: 217 | h.remove() 218 | 219 | ################# pruning ################### 220 | for name in subset: 221 | if args.prune_metric == "wanda": 222 | metric_stats[block_id][name] *= torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1))) 223 | 224 | W_mask = compute_mask(metric_stats[block_id][name], args.prune_granularity, args.sparsity) 225 | 226 | subset[name].weight.data[W_mask] = 0 227 | ############################################## -------------------------------------------------------------------------------- /image_classifiers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | import math 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | import numpy as np 14 | from timm.utils import get_state_dict 15 | 16 | from pathlib import Path 17 | from timm.models import create_model 18 | import torch 19 | import torch.distributed as dist 20 | from torch._six import inf 21 | 22 | class SmoothedValue(object): 23 | """Track a series of values and provide access to smoothed values over a 24 | window or the global series average. 25 | """ 26 | 27 | def __init__(self, window_size=20, fmt=None): 28 | if fmt is None: 29 | fmt = "{median:.4f} ({global_avg:.4f})" 30 | self.deque = deque(maxlen=window_size) 31 | self.total = 0.0 32 | self.count = 0 33 | self.fmt = fmt 34 | 35 | def update(self, value, n=1): 36 | self.deque.append(value) 37 | self.count += n 38 | self.total += value * n 39 | 40 | def synchronize_between_processes(self): 41 | """ 42 | Warning: does not synchronize the deque! 
43 | """ 44 | if not is_dist_avail_and_initialized(): 45 | return 46 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 47 | dist.barrier() 48 | dist.all_reduce(t) 49 | t = t.tolist() 50 | self.count = int(t[0]) 51 | self.total = t[1] 52 | 53 | @property 54 | def median(self): 55 | d = torch.tensor(list(self.deque)) 56 | return d.median().item() 57 | 58 | @property 59 | def avg(self): 60 | d = torch.tensor(list(self.deque), dtype=torch.float32) 61 | return d.mean().item() 62 | 63 | @property 64 | def global_avg(self): 65 | return self.total / self.count 66 | 67 | @property 68 | def max(self): 69 | return max(self.deque) 70 | 71 | @property 72 | def value(self): 73 | return self.deque[-1] 74 | 75 | def __str__(self): 76 | return self.fmt.format( 77 | median=self.median, 78 | avg=self.avg, 79 | global_avg=self.global_avg, 80 | max=self.max, 81 | value=self.value) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if v is None: 92 | continue 93 | if isinstance(v, torch.Tensor): 94 | v = v.item() 95 | assert isinstance(v, (float, int)) 96 | self.meters[k].update(v) 97 | 98 | def __getattr__(self, attr): 99 | if attr in self.meters: 100 | return self.meters[attr] 101 | if attr in self.__dict__: 102 | return self.__dict__[attr] 103 | raise AttributeError("'{}' object has no attribute '{}'".format( 104 | type(self).__name__, attr)) 105 | 106 | def __str__(self): 107 | loss_str = [] 108 | for name, meter in self.meters.items(): 109 | loss_str.append( 110 | "{}: {}".format(name, str(meter)) 111 | ) 112 | return self.delimiter.join(loss_str) 113 | 114 | def synchronize_between_processes(self): 115 | for meter in self.meters.values(): 116 | meter.synchronize_between_processes() 117 | 118 | def add_meter(self, name, meter): 119 | self.meters[name] = meter 120 | 121 | def log_every(self, iterable, print_freq, header=None): 122 | i = 0 123 | if not header: 124 | header = '' 125 | start_time = time.time() 126 | end = time.time() 127 | iter_time = SmoothedValue(fmt='{avg:.4f}') 128 | data_time = SmoothedValue(fmt='{avg:.4f}') 129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 130 | log_msg = [ 131 | header, 132 | '[{0' + space_fmt + '}/{1}]', 133 | 'eta: {eta}', 134 | '{meters}', 135 | 'time: {time}', 136 | 'data: {data}' 137 | ] 138 | if torch.cuda.is_available(): 139 | log_msg.append('max mem: {memory:.0f}') 140 | log_msg = self.delimiter.join(log_msg) 141 | MB = 1024.0 * 1024.0 142 | for obj in iterable: 143 | data_time.update(time.time() - end) 144 | yield obj 145 | iter_time.update(time.time() - end) 146 | if i % print_freq == 0 or i == len(iterable) - 1: 147 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 149 | if torch.cuda.is_available(): 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time), 154 | memory=torch.cuda.max_memory_allocated() / MB)) 155 | else: 156 | print(log_msg.format( 157 | i, len(iterable), eta=eta_string, 158 | meters=str(self), 159 | time=str(iter_time), data=str(data_time))) 160 | i += 1 161 | end = time.time() 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('{} Total time: {} ({:.4f} s / it)'.format( 165 | header, 
total_time_str, total_time / len(iterable))) 166 | 167 | 168 | class TensorboardLogger(object): 169 | def __init__(self, log_dir): 170 | self.writer = SummaryWriter(logdir=log_dir) 171 | self.step = 0 172 | 173 | def set_step(self, step=None): 174 | if step is not None: 175 | self.step = step 176 | else: 177 | self.step += 1 178 | 179 | def update(self, head='scalar', step=None, **kwargs): 180 | for k, v in kwargs.items(): 181 | if v is None: 182 | continue 183 | if isinstance(v, torch.Tensor): 184 | v = v.item() 185 | assert isinstance(v, (float, int)) 186 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 187 | 188 | def flush(self): 189 | self.writer.flush() 190 | 191 | 192 | class WandbLogger(object): 193 | def __init__(self, args): 194 | self.args = args 195 | 196 | try: 197 | import wandb 198 | self._wandb = wandb 199 | except ImportError: 200 | raise ImportError( 201 | "To use the Weights and Biases Logger please install wandb." 202 | "Run `pip install wandb` to install it." 203 | ) 204 | 205 | # Initialize a W&B run 206 | if self._wandb.run is None: 207 | self._wandb.init( 208 | project=args.project, 209 | config=args 210 | ) 211 | 212 | def log_epoch_metrics(self, metrics, commit=True): 213 | """ 214 | Log train/test metrics onto W&B. 215 | """ 216 | # Log number of model parameters as W&B summary 217 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 218 | metrics.pop('n_parameters', None) 219 | 220 | # Log current epoch 221 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 222 | metrics.pop('epoch') 223 | 224 | for k, v in metrics.items(): 225 | if 'train' in k: 226 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 227 | elif 'test' in k: 228 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 229 | 230 | self._wandb.log({}) 231 | 232 | def log_checkpoints(self): 233 | output_dir = self.args.output_dir 234 | model_artifact = self._wandb.Artifact( 235 | self._wandb.run.id + "_model", type="model" 236 | ) 237 | 238 | model_artifact.add_dir(output_dir) 239 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 240 | 241 | def set_steps(self): 242 | # Set global training step 243 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 244 | # Set epoch-wise step 245 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 246 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 247 | 248 | 249 | def setup_for_distributed(is_master): 250 | """ 251 | This function disables printing when not in master process 252 | """ 253 | import builtins as __builtin__ 254 | builtin_print = __builtin__.print 255 | 256 | def print(*args, **kwargs): 257 | force = kwargs.pop('force', False) 258 | if is_master or force: 259 | builtin_print(*args, **kwargs) 260 | 261 | __builtin__.print = print 262 | 263 | 264 | def is_dist_avail_and_initialized(): 265 | if not dist.is_available(): 266 | return False 267 | if not dist.is_initialized(): 268 | return False 269 | return True 270 | 271 | 272 | def get_world_size(): 273 | if not is_dist_avail_and_initialized(): 274 | return 1 275 | return dist.get_world_size() 276 | 277 | 278 | def get_rank(): 279 | if not is_dist_avail_and_initialized(): 280 | return 0 281 | return dist.get_rank() 282 | 283 | 284 | def is_main_process(): 285 | return get_rank() == 0 286 | 287 | 288 | def save_on_master(*args, **kwargs): 289 | if is_main_process(): 290 | torch.save(*args, **kwargs) 291 | 292 | 293 | def 
init_distributed_mode(args): 294 | 295 | if args.dist_on_itp: 296 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 297 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 298 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 299 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 300 | os.environ['LOCAL_RANK'] = str(args.gpu) 301 | os.environ['RANK'] = str(args.rank) 302 | os.environ['WORLD_SIZE'] = str(args.world_size) 303 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 304 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 305 | args.rank = int(os.environ["RANK"]) 306 | args.world_size = int(os.environ['WORLD_SIZE']) 307 | args.gpu = int(os.environ['LOCAL_RANK']) 308 | elif 'SLURM_PROCID' in os.environ: 309 | args.rank = int(os.environ['SLURM_PROCID']) 310 | args.gpu = args.rank % torch.cuda.device_count() 311 | 312 | os.environ['RANK'] = str(args.rank) 313 | os.environ['LOCAL_RANK'] = str(args.gpu) 314 | os.environ['WORLD_SIZE'] = str(args.world_size) 315 | else: 316 | print('Not using distributed mode') 317 | args.distributed = False 318 | return 319 | 320 | args.distributed = True 321 | 322 | torch.cuda.set_device(args.gpu) 323 | args.dist_backend = 'nccl' 324 | print('| distributed init (rank {}): {}, gpu {}'.format( 325 | args.rank, args.dist_url, args.gpu), flush=True) 326 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 327 | world_size=args.world_size, rank=args.rank) 328 | torch.distributed.barrier() 329 | setup_for_distributed(args.rank == 0) 330 | 331 | 332 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 333 | missing_keys = [] 334 | unexpected_keys = [] 335 | error_msgs = [] 336 | # copy state_dict so _load_from_state_dict can modify it 337 | metadata = getattr(state_dict, '_metadata', None) 338 | state_dict = state_dict.copy() 339 | if metadata is not None: 340 | state_dict._metadata = metadata 341 | 342 | def load(module, prefix=''): 343 | local_metadata = {} if metadata is None else metadata.get( 344 | prefix[:-1], {}) 345 | module._load_from_state_dict( 346 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 347 | for name, child in module._modules.items(): 348 | if child is not None: 349 | load(child, prefix + name + '.') 350 | 351 | load(model, prefix=prefix) 352 | 353 | warn_missing_keys = [] 354 | ignore_missing_keys = [] 355 | for key in missing_keys: 356 | keep_flag = True 357 | for ignore_key in ignore_missing.split('|'): 358 | if ignore_key in key: 359 | keep_flag = False 360 | break 361 | if keep_flag: 362 | warn_missing_keys.append(key) 363 | else: 364 | ignore_missing_keys.append(key) 365 | 366 | missing_keys = warn_missing_keys 367 | 368 | if len(missing_keys) > 0: 369 | print("Weights of {} not initialized from pretrained model: {}".format( 370 | model.__class__.__name__, missing_keys)) 371 | if len(unexpected_keys) > 0: 372 | print("Weights from pretrained model not used in {}: {}".format( 373 | model.__class__.__name__, unexpected_keys)) 374 | if len(ignore_missing_keys) > 0: 375 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 376 | model.__class__.__name__, ignore_missing_keys)) 377 | if len(error_msgs) > 0: 378 | print('\n'.join(error_msgs)) 379 | 380 | 381 | class NativeScalerWithGradNormCount: 382 | state_dict_key = "amp_scaler" 383 | 384 | def __init__(self): 385 | self._scaler = 
torch.cuda.amp.GradScaler() 386 | 387 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 388 | self._scaler.scale(loss).backward(create_graph=create_graph) 389 | 390 | ######################################################## 391 | ## Code I added 392 | for param in parameters: 393 | weight_copy = param.data.abs().clone() 394 | mask = weight_copy.gt(0).float().cuda() 395 | sparsity = mask.sum() / mask.numel() 396 | if sparsity > 0.3: 397 | # non-trivial sparsity 398 | param.grad.data.mul_(mask) 399 | ######################################################## 400 | 401 | if update_grad: 402 | if clip_grad is not None: 403 | assert parameters is not None 404 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 405 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 406 | else: 407 | self._scaler.unscale_(optimizer) 408 | norm = get_grad_norm_(parameters) 409 | self._scaler.step(optimizer) 410 | self._scaler.update() 411 | else: 412 | norm = None 413 | return norm 414 | 415 | def state_dict(self): 416 | return self._scaler.state_dict() 417 | 418 | def load_state_dict(self, state_dict): 419 | self._scaler.load_state_dict(state_dict) 420 | 421 | 422 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 423 | if isinstance(parameters, torch.Tensor): 424 | parameters = [parameters] 425 | parameters = [p for p in parameters if p.grad is not None] 426 | norm_type = float(norm_type) 427 | if len(parameters) == 0: 428 | return torch.tensor(0.) 429 | device = parameters[0].grad.device 430 | if norm_type == inf: 431 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 432 | else: 433 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 434 | return total_norm 435 | 436 | 437 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 438 | start_warmup_value=0, warmup_steps=-1): 439 | warmup_schedule = np.array([]) 440 | warmup_iters = warmup_epochs * niter_per_ep 441 | if warmup_steps > 0: 442 | warmup_iters = warmup_steps 443 | print("Set warmup steps = %d" % warmup_iters) 444 | if warmup_epochs > 0: 445 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 446 | 447 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 448 | schedule = np.array( 449 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 450 | 451 | schedule = np.concatenate((warmup_schedule, schedule)) 452 | 453 | assert len(schedule) == epochs * niter_per_ep 454 | return schedule 455 | 456 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 457 | output_dir = Path(args.output_dir) 458 | epoch_name = str(epoch) 459 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 460 | for checkpoint_path in checkpoint_paths: 461 | to_save = { 462 | 'model': model_without_ddp.state_dict(), 463 | 'optimizer': optimizer.state_dict(), 464 | 'epoch': epoch, 465 | 'scaler': loss_scaler.state_dict(), 466 | 'args': args, 467 | } 468 | 469 | if model_ema is not None: 470 | to_save['model_ema'] = get_state_dict(model_ema) 471 | 472 | save_on_master(to_save, checkpoint_path) 473 | 474 | if is_main_process() and isinstance(epoch, int): 475 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 476 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del) 
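        # rotate saved checkpoints: delete the one written save_ckpt_num * save_ckpt_freq epochs ago, keeping only the most recent save_ckpt_num saves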
477 | if os.path.exists(old_ckpt): 478 | os.remove(old_ckpt) 479 | 480 | 481 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 482 | output_dir = Path(args.output_dir) 483 | if args.auto_resume and len(args.resume) == 0: 484 | import glob 485 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 486 | latest_ckpt = -1 487 | for ckpt in all_checkpoints: 488 | t = ckpt.split('-')[-1].split('.')[0] 489 | if t.isdigit(): 490 | latest_ckpt = max(int(t), latest_ckpt) 491 | if latest_ckpt >= 0: 492 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 493 | print("Auto resume checkpoint: %s" % args.resume) 494 | 495 | if args.resume: 496 | if args.resume.startswith('https'): 497 | checkpoint = torch.hub.load_state_dict_from_url( 498 | args.resume, map_location='cpu', check_hash=True) 499 | else: 500 | checkpoint = torch.load(args.resume, map_location='cpu') 501 | model_without_ddp.load_state_dict(checkpoint['model']) 502 | print("Resume checkpoint %s" % args.resume) 503 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 504 | optimizer.load_state_dict(checkpoint['optimizer']) 505 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 506 | args.start_epoch = checkpoint['epoch'] + 1 507 | else: 508 | assert args.eval, 'Does not support resuming with checkpoint-best' 509 | if hasattr(args, 'model_ema') and args.model_ema: 510 | if 'model_ema' in checkpoint.keys(): 511 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 512 | else: 513 | model_ema.ema.load_state_dict(checkpoint['model']) 514 | if 'scaler' in checkpoint: 515 | loss_scaler.load_state_dict(checkpoint['scaler']) 516 | print("With optim & sched!") 517 | 518 | def reg_scheduler(base_value, final_value, epochs, niter_per_ep, early_epochs=0, early_value=None, 519 | mode='linear', early_mode='regular'): 520 | early_schedule = np.array([]) 521 | early_iters = early_epochs * niter_per_ep 522 | if early_value is None: 523 | early_value = final_value 524 | if early_epochs > 0: 525 | print(f"Set early value to {early_mode} {early_value}") 526 | if early_mode == 'regular': 527 | early_schedule = np.array([early_value] * early_iters) 528 | elif early_mode == 'linear': 529 | early_schedule = np.linspace(early_value, base_value, early_iters) 530 | elif early_mode == 'cosine': 531 | early_schedule = np.array( 532 | [base_value + 0.5 * (early_value - base_value) * (1 + math.cos(math.pi * i / early_iters)) for i in np.arange(early_iters)]) 533 | regular_epochs = epochs - early_epochs 534 | iters = np.arange(regular_epochs * niter_per_ep) 535 | schedule = np.linspace(base_value, final_value, len(iters)) 536 | schedule = np.concatenate((early_schedule, schedule)) 537 | 538 | assert len(schedule) == epochs * niter_per_ep 539 | return schedule 540 | 541 | def build_model(args, pretrained=False): 542 | if args.model.startswith("convnext"): 543 | model = create_model( 544 | args.model, 545 | pretrained=pretrained, 546 | num_classes=args.nb_classes, 547 | layer_scale_init_value=args.layer_scale_init_value, 548 | head_init_scale=args.head_init_scale, 549 | drop_path_rate=args.drop_path, 550 | drop_rate=args.dropout, 551 | ) 552 | else: 553 | model = create_model( 554 | args.model, 555 | pretrained=pretrained, 556 | num_classes=args.nb_classes, 557 | drop_path_rate=args.drop_path, 558 | drop_rate =args.dropout 559 | ) 560 | return model -------------------------------------------------------------------------------- 
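The pruning helpers that follow (`lib/ablate.py` and `lib/prune.py`) build their masks from the same per-weight score: weight magnitude multiplied by the square root of the accumulated input-activation norm (`scaler_row`). The snippet below is a self-contained toy sketch of that score and of the unstructured mask construction — random tensors and a 50% ratio are used purely for illustration, and in the repository the activation statistics come from `add_batch` during calibration rather than from `torch.rand`.

```python
# Toy sketch of the Wanda score and unstructured mask (illustration only, not repository code).
import torch

torch.manual_seed(0)
out_features, in_features, sparsity = 4, 8, 0.5

W = torch.randn(out_features, in_features)  # weight of a linear layer
scaler_row = torch.rand(in_features)        # per-input-feature squared activation norm (built by add_batch in practice)

# score each weight by |W| * sqrt(||X||^2), compared against other weights in the same output row
W_metric = torch.abs(W) * torch.sqrt(scaler_row.reshape(1, -1))

# unstructured pruning: mark the lowest-scoring `sparsity` fraction of every row for removal
W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
sort_res = torch.sort(W_metric, dim=-1, stable=True)
W_mask.scatter_(1, sort_res[1][:, :int(in_features * sparsity)], True)

W[W_mask] = 0
print(f"achieved sparsity: {(W == 0).float().mean().item():.2f}")  # ~0.50
```

For N:M structured sparsity, `get_wanda_mask` and `prune_wanda` below apply the same score but select the `prune_n` smallest entries within every group of `prune_m` consecutive input columns instead of across the whole row.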
/lib/ablate.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | torch.backends.cuda.matmul.allow_tf32 = False 9 | torch.backends.cudnn.allow_tf32 = False 10 | 11 | class AblateGPT: 12 | 13 | def __init__(self, layer): 14 | self.layer = layer 15 | self.dev = self.layer.weight.device 16 | W = layer.weight.data.clone() 17 | if isinstance(self.layer, nn.Conv2d): 18 | W = W.flatten(1) 19 | if isinstance(self.layer, transformers.Conv1D): 20 | W = W.t() 21 | self.rows = W.shape[0] 22 | self.columns = W.shape[1] 23 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 24 | self.nsamples = 0 25 | 26 | self.scaler_row = torch.zeros((self.columns), device=self.dev) 27 | 28 | def add_batch(self, inp, out): 29 | if len(inp.shape) == 2: 30 | inp = inp.unsqueeze(0) 31 | tmp = inp.shape[0] 32 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 33 | if len(inp.shape) == 3: 34 | inp = inp.reshape((-1, inp.shape[-1])) 35 | inp = inp.t() 36 | self.H *= self.nsamples / (self.nsamples + tmp) 37 | 38 | self.scaler_row *= self.nsamples / (self.nsamples+tmp) 39 | 40 | self.nsamples += tmp 41 | inp = math.sqrt(2 / self.nsamples) * inp.float() 42 | self.H += inp.matmul(inp.t()) 43 | self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples 44 | 45 | def get_wanda_mask(self, sparsity, prunen, prunem): 46 | W_metric = torch.abs(self.layer.weight.data) * torch.sqrt(self.scaler_row.reshape((1,-1))) 47 | W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False 48 | if prunen != 0: 49 | for ii in range(W_metric.shape[1]): 50 | if ii % prunem == 0: 51 | tmp = W_metric[:,ii:(ii+prunem)].float() 52 | W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True) 53 | else: 54 | sort_res = torch.sort(W_metric, dim=-1, stable=True) 55 | indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)] 56 | W_mask.scatter_(1, indices, True) 57 | 58 | return W_mask 59 | 60 | def get_mag_mask(self, sparsity, prunen, prunem): 61 | W = self.layer.weight.data 62 | W_metric = torch.abs(W) 63 | if prunen != 0: 64 | W_mask = (torch.zeros_like(W)==1) 65 | for ii in range(W_metric.shape[1]): 66 | if ii % prunem == 0: 67 | tmp = W_metric[:,ii:(ii+prunem)].float() 68 | W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True) 69 | else: 70 | thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*sparsity)].cpu() 71 | W_mask = (W_metric<=thresh) 72 | 73 | return W_mask 74 | 75 | def fasterprune( 76 | self, args, sparsity, mask=None, prune_n=0, prune_m=0, blocksize=128, percdamp=.01 77 | ): 78 | W = self.layer.weight.data.clone() 79 | if isinstance(self.layer, nn.Conv2d): 80 | W = W.flatten(1) 81 | if isinstance(self.layer, transformers.Conv1D): 82 | W = W.t() 83 | W = W.float() 84 | 85 | tick = time.time() 86 | 87 | H = self.H 88 | del self.H 89 | dead = torch.diag(H) == 0 90 | H[dead, dead] = 1 91 | W[:, dead] = 0 92 | 93 | Losses = torch.zeros(self.rows, device=self.dev) 94 | 95 | damp = percdamp * torch.mean(torch.diag(H)) 96 | diag = torch.arange(self.columns, device=self.dev) 97 | H[diag, diag] += damp 98 | H = torch.linalg.cholesky(H) 99 | H = torch.cholesky_inverse(H) 100 | H = torch.linalg.cholesky(H, upper=True) 101 | Hinv = H 102 | 103 | for i1 in range(0, self.columns, blocksize): 104 | i2 = min(i1 + blocksize, self.columns) 105 | count = i2 - i1 106 | 107 | W1 = W[:, 
i1:i2].clone() 108 | Q1 = torch.zeros_like(W1) 109 | Err1 = torch.zeros_like(W1) 110 | Losses1 = torch.zeros_like(W1) 111 | Hinv1 = Hinv[i1:i2, i1:i2] 112 | 113 | if prune_n == 0 or mask is not None: 114 | if mask is not None: 115 | mask1 = mask[:, i1:i2] 116 | else: 117 | # tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 118 | if "wanda" in args.prune_method: 119 | tmp = torch.abs(W1) * torch.sqrt(self.scaler_row[i1:i2].reshape((1,-1))) 120 | elif "mag" in args.prune_method: 121 | tmp = torch.abs(W1) 122 | thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] 123 | mask1 = tmp <= thresh 124 | else: 125 | mask1 = torch.zeros_like(W1) == 1 126 | 127 | for i in range(count): 128 | w = W1[:, i] 129 | d = Hinv1[i, i] 130 | 131 | if prune_n != 0 and i % prune_m == 0 and mask is None: 132 | # tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2 133 | if "wanda" in args.prune_method: 134 | tmp = torch.abs(W1[:, i:(i+prune_m)]) * torch.sqrt(self.scaler_row[(i+i1):(i+i1+prune_m)].reshape((1,-1))) 135 | elif "mag" in args.prune_method: 136 | tmp = torch.abs(W1[:, i:(i+prune_m)]) 137 | mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True) 138 | 139 | q = w.clone() 140 | q[mask1[:, i]] = 0 141 | 142 | Q1[:, i] = q 143 | Losses1[:, i] = (w - q) ** 2 / d ** 2 144 | 145 | err1 = (w - q) / d 146 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 147 | Err1[:, i] = err1 148 | 149 | W[:, i1:i2] = Q1 150 | Losses += torch.sum(Losses1, 1) / 2 151 | 152 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 153 | 154 | torch.cuda.synchronize() 155 | if isinstance(self.layer, transformers.Conv1D): 156 | W = W.t() 157 | self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 158 | 159 | def free(self): 160 | self.H = None 161 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /lib/data.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py 2 | 3 | import numpy as np 4 | import random 5 | import torch 6 | from datasets import load_dataset 7 | 8 | # Set seed for reproducibility 9 | def set_seed(seed): 10 | np.random.seed(seed) 11 | torch.random.manual_seed(seed) 12 | 13 | # Wrapper for tokenized input IDs 14 | class TokenizerWrapper: 15 | def __init__(self, input_ids): 16 | self.input_ids = input_ids 17 | 18 | # Load and process wikitext2 dataset 19 | def get_wikitext2(nsamples, seed, seqlen, tokenizer): 20 | # Load train and test datasets 21 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 22 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 23 | 24 | # Encode datasets 25 | trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt') 26 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 27 | 28 | # Generate samples from training set 29 | random.seed(seed) 30 | trainloader = [] 31 | for _ in range(nsamples): 32 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 33 | j = i + seqlen 34 | inp = trainenc.input_ids[:, i:j] 35 | tar = inp.clone() 36 | tar[:, :-1] = -100 37 | trainloader.append((inp, tar)) 38 | return trainloader, testenc 39 | 40 | # Load and process c4 dataset 41 | def get_c4(nsamples, seed, seqlen, tokenizer): 42 | # Load train and validation datasets 43 | traindata = load_dataset('allenai/c4', 
'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 44 | valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 45 | 46 | # Generate samples from training set 47 | random.seed(seed) 48 | trainloader = [] 49 | for _ in range(nsamples): 50 | while True: 51 | i = random.randint(0, len(traindata) - 1) 52 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 53 | if trainenc.input_ids.shape[1] > seqlen: 54 | break 55 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 56 | j = i + seqlen 57 | inp = trainenc.input_ids[:, i:j] 58 | tar = inp.clone() 59 | tar[:, :-1] = -100 60 | trainloader.append((inp, tar)) 61 | 62 | # Prepare validation dataset 63 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 64 | valenc = valenc.input_ids[:, :(256 * seqlen)] 65 | valenc = TokenizerWrapper(valenc) 66 | return trainloader, valenc 67 | 68 | # Function to select the appropriate loader based on dataset name 69 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None): 70 | if 'wikitext2' in name: 71 | return get_wikitext2(nsamples, seed, seqlen, tokenizer) 72 | if "c4" in name: 73 | return get_c4(nsamples, seed, seqlen, tokenizer) -------------------------------------------------------------------------------- /lib/eval.py: -------------------------------------------------------------------------------- 1 | # Import necessary modules 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | 6 | # Import get_loaders function from data module within the same directory 7 | from .data import get_loaders 8 | 9 | from collections import defaultdict 10 | import fnmatch 11 | 12 | 13 | # Function to evaluate perplexity (ppl) on a specified model and tokenizer 14 | def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")): 15 | # Set dataset 16 | dataset = "wikitext2" 17 | 18 | # Print status 19 | print(f"evaluating on {dataset}") 20 | 21 | # Get the test loader 22 | _, testloader = get_loaders( 23 | dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer 24 | ) 25 | 26 | # Evaluate ppl in no grad context to avoid updating the model 27 | with torch.no_grad(): 28 | ppl_test = eval_ppl_wikitext(model, testloader, 1, device) 29 | return ppl_test 30 | 31 | # Function to evaluate perplexity (ppl) specifically on the wikitext dataset 32 | def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None): 33 | # Get input IDs 34 | # testenc = testenc.input_ids 35 | 36 | # Calculate number of samples 37 | # nsamples = testenc.numel() // model.seqlen 38 | nsamples = len(trainloader) 39 | 40 | # List to store negative log likelihoods 41 | nlls = [] 42 | print(f"nsamples {nsamples}") 43 | 44 | # Loop through each batch 45 | for i in range(0,nsamples,bs): 46 | if i % 50 == 0: 47 | print(f"sample {i}") 48 | 49 | # Calculate end index 50 | j = min(i+bs, nsamples) 51 | 52 | # Prepare inputs and move to device 53 | # inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device) 54 | inputs = trainloader[i][0].to(device) 55 | inputs = inputs.reshape(j-i, model.seqlen) 56 | 57 | # Forward pass through the model 58 | lm_logits = model(inputs).logits 59 | 60 | # Shift logits and labels for next token prediction 61 | shift_logits = lm_logits[:, :-1, :].contiguous() 62 | shift_labels = inputs[:, 1:] 63 | 64 | # Compute loss 65 | loss_fct = nn.CrossEntropyLoss() 66 | loss = loss_fct(shift_logits.reshape(-1, 
shift_logits.size(-1)), shift_labels.reshape(-1)) 67 | 68 | # Calculate negative log likelihood 69 | neg_log_likelihood = loss.float() * model.seqlen * (j-i) 70 | 71 | # Append to list of negative log likelihoods 72 | nlls.append(neg_log_likelihood) 73 | 74 | # Compute perplexity 75 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 76 | 77 | # Empty CUDA cache to save memory 78 | torch.cuda.empty_cache() 79 | 80 | return ppl.item() 81 | 82 | # Function to evaluate perplexity (ppl) specifically on the wikitext dataset 83 | def eval_ppl_wikitext(model, testenc, bs=1, device=None): 84 | # Get input IDs 85 | testenc = testenc.input_ids 86 | 87 | # Calculate number of samples 88 | nsamples = testenc.numel() // model.seqlen 89 | 90 | # List to store negative log likelihoods 91 | nlls = [] 92 | print(f"nsamples {nsamples}") 93 | 94 | # Loop through each batch 95 | for i in range(0,nsamples,bs): 96 | if i % 50 == 0: 97 | print(f"sample {i}") 98 | 99 | # Calculate end index 100 | j = min(i+bs, nsamples) 101 | 102 | # Prepare inputs and move to device 103 | inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device) 104 | inputs = inputs.reshape(j-i, model.seqlen) 105 | 106 | # Forward pass through the model 107 | lm_logits = model(inputs).logits 108 | 109 | # Shift logits and labels for next token prediction 110 | shift_logits = lm_logits[:, :-1, :].contiguous() 111 | shift_labels = inputs[:, 1:] 112 | 113 | # Compute loss 114 | loss_fct = nn.CrossEntropyLoss() 115 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) 116 | 117 | # Calculate negative log likelihood 118 | neg_log_likelihood = loss.float() * model.seqlen * (j-i) 119 | 120 | # Append to list of negative log likelihoods 121 | nlls.append(neg_log_likelihood) 122 | 123 | # Compute perplexity 124 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 125 | 126 | # Empty CUDA cache to save memory 127 | torch.cuda.empty_cache() 128 | 129 | return ppl.item() 130 | 131 | 132 | def eval_zero_shot(model_name, model, tokenizer, task_list=["boolq","rte","hellaswag","winogrande","arc_challenge","arc_easy","openbookqa"], 133 | num_fewshot=0, use_accelerate=False, add_special_tokens=False): 134 | from lm_eval import tasks, evaluator 135 | def pattern_match(patterns, source_list): 136 | task_names = set() 137 | for pattern in patterns: 138 | for matching in fnmatch.filter(source_list, pattern): 139 | task_names.add(matching) 140 | return list(task_names) 141 | task_names = pattern_match(task_list, tasks.ALL_TASKS) 142 | model_args = f"pretrained={model_name},cache_dir=./llm_weights" 143 | limit = None 144 | if "70b" in model_name or "65b" in model_name: 145 | limit = 2000 146 | if use_accelerate: 147 | model_args = f"pretrained={model_name},cache_dir=./llm_weights,use_accelerate=True" 148 | results = evaluator.simple_evaluate( 149 | model="hf-causal-experimental", 150 | model_args=model_args, 151 | tasks=task_names, 152 | num_fewshot=num_fewshot, 153 | batch_size=None, 154 | device=None, 155 | no_cache=True, 156 | limit=limit, 157 | description_dict={}, 158 | decontamination_ngrams_path=None, 159 | check_integrity=False, 160 | pretrained_model=model, 161 | tokenizer=tokenizer, 162 | add_special_tokens=add_special_tokens 163 | ) 164 | 165 | return results -------------------------------------------------------------------------------- /lib/layerwrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | 4 | # Define WrappedGPT class 5 | class WrappedGPT: 6 | """ 7 | This class wraps a GPT layer for specific operations. 8 | """ 9 | 10 | def __init__(self, layer, layer_id=0, layer_name="none"): 11 | self.layer = layer 12 | self.dev = self.layer.weight.device 13 | self.rows = layer.weight.data.shape[0] 14 | self.columns = layer.weight.data.shape[1] 15 | 16 | self.scaler_row = torch.zeros((self.columns), device=self.dev) 17 | self.nsamples = 0 18 | 19 | self.layer_id = layer_id 20 | self.layer_name = layer_name 21 | 22 | def add_batch(self, inp, out): 23 | if len(inp.shape) == 2: 24 | inp = inp.unsqueeze(0) 25 | tmp = inp.shape[0] 26 | if isinstance(self.layer, nn.Linear): 27 | if len(inp.shape) == 3: 28 | inp = inp.reshape((-1, inp.shape[-1])) 29 | inp = inp.t() 30 | 31 | self.scaler_row *= self.nsamples / (self.nsamples+tmp) 32 | self.nsamples += tmp 33 | 34 | inp = inp.type(torch.float32) 35 | self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples -------------------------------------------------------------------------------- /lib/prune.py: -------------------------------------------------------------------------------- 1 | import time 2 | import heapq 3 | import torch 4 | import torch.nn as nn 5 | from .sparsegpt import SparseGPT 6 | from .layerwrapper import WrappedGPT 7 | from .data import get_loaders 8 | 9 | from .ablate import AblateGPT 10 | 11 | def find_layers(module, layers=[nn.Linear], name=''): 12 | """ 13 | Recursively find the layers of a certain type in a module. 14 | 15 | Args: 16 | module (nn.Module): PyTorch module. 17 | layers (list): List of layer types to find. 18 | name (str): Name of the module. 19 | 20 | Returns: 21 | dict: Dictionary of layers of the given type(s) within the module. 22 | """ 23 | if type(module) in layers: 24 | return {name: module} 25 | res = {} 26 | for name1, child in module.named_children(): 27 | res.update(find_layers( 28 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 29 | )) 30 | return res 31 | 32 | def check_sparsity(model): 33 | use_cache = model.config.use_cache 34 | model.config.use_cache = False 35 | 36 | layers = model.model.layers 37 | count = 0 38 | total_params = 0 39 | for i in range(len(layers)): 40 | layer = layers[i] 41 | subset = find_layers(layer) 42 | 43 | sub_count = 0 44 | sub_params = 0 45 | for name in subset: 46 | W = subset[name].weight.data 47 | count += (W==0).sum().item() 48 | total_params += W.numel() 49 | 50 | sub_count += (W==0).sum().item() 51 | sub_params += W.numel() 52 | 53 | print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}") 54 | 55 | model.config.use_cache = use_cache 56 | return float(count)/total_params 57 | 58 | def prepare_calibration_input(model, dataloader, device): 59 | use_cache = model.config.use_cache 60 | model.config.use_cache = False 61 | layers = model.model.layers 62 | 63 | # dev = model.hf_device_map["model.embed_tokens"] 64 | if "model.embed_tokens" in model.hf_device_map: 65 | device = model.hf_device_map["model.embed_tokens"] 66 | 67 | dtype = next(iter(model.parameters())).dtype 68 | inps = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device) 69 | inps.requires_grad = False 70 | cache = {'i': 0, 'attention_mask': None, "position_ids": None} 71 | 72 | class Catcher(nn.Module): 73 | def __init__(self, module): 74 | super().__init__() 75 | self.module = module 76 | def forward(self, inp, **kwargs): 77 | inps[cache['i']] = inp 78 | cache['i'] += 1 79 | cache['attention_mask'] = kwargs['attention_mask'] 80 | cache['position_ids'] = kwargs['position_ids'] 81 | raise ValueError 82 | layers[0] = Catcher(layers[0]) 83 | for batch in dataloader: 84 | try: 85 | model(batch[0].to(device)) 86 | except ValueError: 87 | pass 88 | layers[0] = layers[0].module 89 | 90 | outs = torch.zeros_like(inps) 91 | attention_mask = cache['attention_mask'] 92 | position_ids = cache['position_ids'] 93 | model.config.use_cache = use_cache 94 | 95 | return inps, outs, attention_mask, position_ids 96 | 97 | def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before): 98 | thres_cumsum = sum_before * alpha 99 | sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1)) 100 | thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1) 101 | W_mask = (W_metric <= thres) 102 | cur_sparsity = (W_mask==True).sum() / W_mask.numel() 103 | return W_mask, cur_sparsity 104 | 105 | def prune_magnitude(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0): 106 | layers = model.model.layers 107 | 108 | for i in range(len(layers)): 109 | layer = layers[i] 110 | subset = find_layers(layer) 111 | 112 | for name in subset: 113 | W = subset[name].weight.data 114 | W_metric = torch.abs(W) 115 | if prune_n != 0: 116 | W_mask = (torch.zeros_like(W)==1) 117 | for ii in range(W_metric.shape[1]): 118 | if ii % prune_m == 0: 119 | tmp = W_metric[:,ii:(ii+prune_m)].float() 120 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True) 121 | else: 122 | thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*args.sparsity_ratio)].cpu() 123 | W_mask = (W_metric<=thresh) 124 | 125 | W[W_mask] = 0 126 | 127 | def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0): 128 | use_cache = model.config.use_cache 129 | model.config.use_cache = False 130 | 131 | print("loading calibdation data") 132 | dataloader, _ = 
get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer) 133 | print("dataset loading complete") 134 | with torch.no_grad(): 135 | inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device) 136 | 137 | layers = model.model.layers 138 | for i in range(len(layers)): 139 | layer = layers[i] 140 | subset = find_layers(layer) 141 | 142 | if f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs; 143 | dev = model.hf_device_map[f"model.layers.{i}"] 144 | inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev) 145 | 146 | wrapped_layers = {} 147 | for name in subset: 148 | wrapped_layers[name] = WrappedGPT(subset[name]) 149 | 150 | def add_batch(name): 151 | def tmp(_, inp, out): 152 | wrapped_layers[name].add_batch(inp[0].data, out.data) 153 | return tmp 154 | 155 | handles = [] 156 | for name in wrapped_layers: 157 | handles.append(subset[name].register_forward_hook(add_batch(name))) 158 | for j in range(args.nsamples): 159 | with torch.no_grad(): 160 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 161 | for h in handles: 162 | h.remove() 163 | 164 | for name in subset: 165 | print(f"pruning layer {i} name {name}") 166 | W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1))) 167 | 168 | W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False 169 | if prune_n != 0: 170 | # structured n:m sparsity 171 | for ii in range(W_metric.shape[1]): 172 | if ii % prune_m == 0: 173 | tmp = W_metric[:,ii:(ii+prune_m)].float() 174 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True) 175 | else: 176 | sort_res = torch.sort(W_metric, dim=-1, stable=True) 177 | 178 | if args.use_variant: 179 | # wanda variant 180 | tmp_metric = torch.cumsum(sort_res[0], dim=1) 181 | sum_before = W_metric.sum(dim=1) 182 | 183 | alpha = 0.4 184 | alpha_hist = [0., 0.8] 185 | W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before) 186 | while (torch.abs(cur_sparsity - args.sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001): 187 | if cur_sparsity > args.sparsity_ratio: 188 | alpha_new = (alpha + alpha_hist[0]) / 2.0 189 | alpha_hist[1] = alpha 190 | else: 191 | alpha_new = (alpha + alpha_hist[1]) / 2.0 192 | alpha_hist[0] = alpha 193 | 194 | alpha = alpha_new 195 | W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before) 196 | print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}") 197 | else: 198 | # unstructured pruning 199 | indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)] 200 | W_mask.scatter_(1, indices, True) 201 | 202 | subset[name].weight.data[W_mask] = 0 ## set weights to zero 203 | 204 | for j in range(args.nsamples): 205 | with torch.no_grad(): 206 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 207 | inps, outs = outs, inps 208 | 209 | model.config.use_cache = use_cache 210 | torch.cuda.empty_cache() 211 | 212 | 213 | @torch.no_grad() 214 | def prune_sparsegpt(args, model, tokenizer, dev, prune_n=0, prune_m=0): 215 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa 216 | print('Starting ...') 217 | dataloader, _ = 
get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer) 218 | 219 | use_cache = model.config.use_cache 220 | model.config.use_cache = False 221 | layers = model.model.layers 222 | 223 | if "model.embed_tokens" in model.hf_device_map: 224 | dev = model.hf_device_map["model.embed_tokens"] 225 | 226 | dtype = next(iter(model.parameters())).dtype 227 | inps = torch.zeros( 228 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 229 | ) 230 | cache = {'i': 0, 'attention_mask': None, "position_ids": None} 231 | 232 | class Catcher(nn.Module): 233 | def __init__(self, module): 234 | super().__init__() 235 | self.module = module 236 | def forward(self, inp, **kwargs): 237 | inps[cache['i']] = inp 238 | cache['i'] += 1 239 | cache['attention_mask'] = kwargs['attention_mask'] 240 | cache['position_ids'] = kwargs['position_ids'] 241 | raise ValueError 242 | layers[0] = Catcher(layers[0]) 243 | for batch in dataloader: 244 | try: 245 | model(batch[0].to(dev)) 246 | except ValueError: 247 | pass 248 | layers[0] = layers[0].module 249 | torch.cuda.empty_cache() 250 | 251 | outs = torch.zeros_like(inps) 252 | attention_mask = cache['attention_mask'] 253 | position_ids = cache['position_ids'] 254 | 255 | print('Ready.') 256 | 257 | for i in range(len(layers)): 258 | layer = layers[i] 259 | if f"model.layers.{i}" in model.hf_device_map: 260 | dev = model.hf_device_map[f"model.layers.{i}"] 261 | print(f"layer {i} device {dev}") 262 | inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev) 263 | 264 | subset = find_layers(layer) 265 | 266 | gpts = {} 267 | for name in subset: 268 | gpts[name] = SparseGPT(subset[name]) 269 | 270 | def add_batch(name): 271 | def tmp(_, inp, out): 272 | gpts[name].add_batch(inp[0].data, out.data) 273 | return tmp 274 | 275 | handles = [] 276 | for name in gpts: 277 | handles.append(subset[name].register_forward_hook(add_batch(name))) 278 | 279 | for j in range(args.nsamples): 280 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 281 | for h in handles: 282 | h.remove() 283 | 284 | for name in gpts: 285 | print(i, name) 286 | print('Pruning ...') 287 | 288 | gpts[name].fasterprune(args.sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128) 289 | gpts[name].free() 290 | 291 | for j in range(args.nsamples): 292 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 293 | 294 | layers[i] = layer 295 | torch.cuda.empty_cache() 296 | 297 | inps, outs = outs, inps 298 | 299 | model.config.use_cache = use_cache 300 | torch.cuda.empty_cache() 301 | 302 | 303 | 304 | @torch.no_grad() 305 | def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0): 306 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa 307 | print('Starting ...') 308 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer) 309 | 310 | use_cache = model.config.use_cache 311 | model.config.use_cache = False 312 | layers = model.model.layers 313 | 314 | if "model.embed_tokens" in model.hf_device_map: 315 | dev = model.hf_device_map["model.embed_tokens"] 316 | 317 | dtype = next(iter(model.parameters())).dtype 318 | inps = torch.zeros( 319 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 320 | ) 321 | cache = 
{'i': 0, 'attention_mask': None, "position_ids": None} 322 | 323 | class Catcher(nn.Module): 324 | def __init__(self, module): 325 | super().__init__() 326 | self.module = module 327 | def forward(self, inp, **kwargs): 328 | inps[cache['i']] = inp 329 | cache['i'] += 1 330 | cache['attention_mask'] = kwargs['attention_mask'] 331 | cache['position_ids'] = kwargs['position_ids'] 332 | raise ValueError 333 | layers[0] = Catcher(layers[0]) 334 | for batch in dataloader: 335 | try: 336 | model(batch[0].to(dev)) 337 | except ValueError: 338 | pass 339 | layers[0] = layers[0].module 340 | torch.cuda.empty_cache() 341 | 342 | outs = torch.zeros_like(inps) 343 | attention_mask = cache['attention_mask'] 344 | position_ids = cache['position_ids'] 345 | 346 | print('Ready.') 347 | 348 | for i in range(len(layers)): 349 | layer = layers[i] 350 | if f"model.layers.{i}" in model.hf_device_map: 351 | dev = model.hf_device_map[f"model.layers.{i}"] 352 | print(f"layer {i} device {dev}") 353 | inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev) 354 | 355 | subset = find_layers(layer) 356 | 357 | gpts = {} 358 | for name in subset: 359 | gpts[name] = AblateGPT(subset[name]) 360 | 361 | def add_batch(name): 362 | def tmp(_, inp, out): 363 | gpts[name].add_batch(inp[0].data, out.data) 364 | return tmp 365 | 366 | handles = [] 367 | for name in gpts: 368 | handles.append(subset[name].register_forward_hook(add_batch(name))) 369 | 370 | for j in range(args.nsamples): 371 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 372 | for h in handles: 373 | h.remove() 374 | 375 | for name in gpts: 376 | print(i, name) 377 | print('Pruning ...') 378 | 379 | if args.prune_method == "ablate_wanda_seq": 380 | prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m) 381 | elif args.prune_method == "ablate_mag_seq": 382 | prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m) 383 | elif "iter" in args.prune_method: 384 | prune_mask = None 385 | 386 | gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128) 387 | gpts[name].free() 388 | 389 | for j in range(args.nsamples): 390 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 391 | 392 | layers[i] = layer 393 | torch.cuda.empty_cache() 394 | 395 | inps, outs = outs, inps 396 | 397 | model.config.use_cache = use_cache 398 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /lib/prune_opt.py: -------------------------------------------------------------------------------- 1 | import time 2 | import heapq 3 | import torch 4 | import torch.nn as nn 5 | from .sparsegpt import SparseGPT 6 | from .layerwrapper import WrappedGPT 7 | from .data import get_loaders 8 | 9 | from .ablate import AblateGPT 10 | 11 | def find_layers(module, layers=[nn.Linear], name=''): 12 | """ 13 | Recursively find the layers of a certain type in a module. 14 | 15 | Args: 16 | module (nn.Module): PyTorch module. 17 | layers (list): List of layer types to find. 18 | name (str): Name of the module. 19 | 20 | Returns: 21 | dict: Dictionary of layers of the given type(s) within the module. 
22 | """ 23 | if type(module) in layers: 24 | return {name: module} 25 | res = {} 26 | for name1, child in module.named_children(): 27 | res.update(find_layers( 28 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 29 | )) 30 | return res 31 | 32 | def check_sparsity(model): 33 | use_cache = model.config.use_cache 34 | model.config.use_cache = False 35 | 36 | layers = model.model.decoder.layers 37 | count = 0 38 | total_params = 0 39 | for i in range(len(layers)): 40 | layer = layers[i] 41 | subset = find_layers(layer) 42 | 43 | sub_count = 0 44 | sub_params = 0 45 | for name in subset: 46 | W = subset[name].weight.data 47 | count += (W==0).sum().item() 48 | total_params += W.numel() 49 | 50 | sub_count += (W==0).sum().item() 51 | sub_params += W.numel() 52 | 53 | print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}") 54 | 55 | model.config.use_cache = use_cache 56 | return float(count)/total_params 57 | 58 | def prepare_calibration_input(model, dataloader, device): 59 | use_cache = model.config.use_cache 60 | model.config.use_cache = False 61 | layers = model.model.decoder.layers 62 | 63 | if "model.embed_tokens" in model.hf_device_map: 64 | device = model.hf_device_map["model.embed_tokens"] 65 | 66 | dtype = next(iter(model.parameters())).dtype 67 | inps = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device) 68 | inps.requires_grad = False 69 | cache = {'i': 0, 'attention_mask': None, "position_ids": None} 70 | 71 | class Catcher(nn.Module): 72 | def __init__(self, module): 73 | super().__init__() 74 | self.module = module 75 | def forward(self, inp, **kwargs): 76 | inps[cache['i']] = inp 77 | cache['i'] += 1 78 | cache['attention_mask'] = kwargs['attention_mask'] 79 | raise ValueError 80 | layers[0] = Catcher(layers[0]) 81 | for batch in dataloader: 82 | try: 83 | model(batch[0].to(device)) 84 | except ValueError: 85 | pass 86 | layers[0] = layers[0].module 87 | 88 | outs = torch.zeros_like(inps) 89 | attention_mask = cache['attention_mask'] 90 | model.config.use_cache = use_cache 91 | 92 | return inps, outs, attention_mask 93 | 94 | def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before): 95 | thres_cumsum = sum_before * alpha 96 | sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1)) 97 | thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1) 98 | W_mask = (W_metric <= thres) 99 | cur_sparsity = (W_mask==True).sum() / W_mask.numel() 100 | return W_mask, cur_sparsity 101 | 102 | def prune_magnitude(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0): 103 | layers = model.model.decoder.layers 104 | 105 | for i in range(len(layers)): 106 | layer = layers[i] 107 | subset = find_layers(layer) 108 | 109 | for name in subset: 110 | W = subset[name].weight.data 111 | W_metric = torch.abs(W) 112 | if prune_n != 0: 113 | W_mask = (torch.zeros_like(W)==1) 114 | for ii in range(W_metric.shape[1]): 115 | if ii % prune_m == 0: 116 | tmp = W_metric[:,ii:(ii+prune_m)].float() 117 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True) 118 | else: 119 | thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*args.sparsity_ratio)].cpu() 120 | W_mask = (W_metric<=thresh) 121 | 122 | W[W_mask] = 0 123 | 124 | def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0): 125 | use_cache = model.config.use_cache 126 | model.config.use_cache = False 127 | 128 | print("loading calibdation data") 129 | 
dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer) 130 | print("dataset loading complete") 131 | with torch.no_grad(): 132 | inps, outs, attention_mask = prepare_calibration_input(model, dataloader, device) 133 | 134 | layers = model.model.decoder.layers 135 | for i in range(len(layers)): 136 | layer = layers[i] 137 | subset = find_layers(layer) 138 | 139 | if f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs; 140 | dev = model.hf_device_map[f"model.layers.{i}"] 141 | inps, outs, attention_mask = inps.to(dev), outs.to(dev), attention_mask.to(dev) 142 | 143 | wrapped_layers = {} 144 | for name in subset: 145 | wrapped_layers[name] = WrappedGPT(subset[name]) 146 | 147 | def add_batch(name): 148 | def tmp(_, inp, out): 149 | wrapped_layers[name].add_batch(inp[0].data, out.data) 150 | return tmp 151 | 152 | handles = [] 153 | for name in wrapped_layers: 154 | handles.append(subset[name].register_forward_hook(add_batch(name))) 155 | for j in range(args.nsamples): 156 | with torch.no_grad(): 157 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 158 | for h in handles: 159 | h.remove() 160 | 161 | for name in subset: 162 | print(f"pruning layer {i} name {name}") 163 | W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1))) 164 | 165 | W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False 166 | if prune_n != 0: 167 | # structured n:m sparsity 168 | for ii in range(W_metric.shape[1]): 169 | if ii % prune_m == 0: 170 | tmp = W_metric[:,ii:(ii+prune_m)].float() 171 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True) 172 | else: 173 | sort_res = torch.sort(W_metric, dim=-1, stable=True) 174 | 175 | # unstructured pruning 176 | indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)] 177 | W_mask.scatter_(1, indices, True) 178 | 179 | subset[name].weight.data[W_mask] = 0 ## set weights to zero 180 | 181 | for j in range(args.nsamples): 182 | with torch.no_grad(): 183 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 184 | inps, outs = outs, inps 185 | 186 | model.config.use_cache = use_cache 187 | torch.cuda.empty_cache() 188 | 189 | @torch.no_grad() 190 | def prune_sparsegpt(args, model, tokenizer, dev, prune_n=0, prune_m=0): 191 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa 192 | print('Starting ...') 193 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer) 194 | 195 | use_cache = model.config.use_cache 196 | model.config.use_cache = False 197 | layers = model.model.decoder.layers 198 | 199 | if "model.embed_tokens" in model.hf_device_map: 200 | dev = model.hf_device_map["model.embed_tokens"] 201 | 202 | dtype = next(iter(model.parameters())).dtype 203 | inps = torch.zeros( 204 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 205 | ) 206 | cache = {'i': 0, 'attention_mask': None, "position_ids": None} 207 | 208 | class Catcher(nn.Module): 209 | def __init__(self, module): 210 | super().__init__() 211 | self.module = module 212 | def forward(self, inp, **kwargs): 213 | inps[cache['i']] = inp 214 | cache['i'] += 1 215 | cache['attention_mask'] = kwargs['attention_mask'] 216 | # cache['position_ids'] = kwargs['position_ids'] 217 | raise 
ValueError 218 | layers[0] = Catcher(layers[0]) 219 | for batch in dataloader: 220 | try: 221 | model(batch[0].to(dev)) 222 | except ValueError: 223 | pass 224 | layers[0] = layers[0].module 225 | torch.cuda.empty_cache() 226 | 227 | outs = torch.zeros_like(inps) 228 | attention_mask = cache['attention_mask'] 229 | 230 | print('Ready.') 231 | 232 | for i in range(len(layers)): 233 | layer = layers[i] 234 | if f"model.layers.{i}" in model.hf_device_map: 235 | dev = model.hf_device_map[f"model.layers.{i}"] 236 | print(f"layer {i} device {dev}") 237 | inps, outs, attention_mask = inps.to(dev), outs.to(dev), attention_mask.to(dev) 238 | 239 | subset = find_layers(layer) 240 | 241 | gpts = {} 242 | for name in subset: 243 | gpts[name] = SparseGPT(subset[name]) 244 | 245 | def add_batch(name): 246 | def tmp(_, inp, out): 247 | gpts[name].add_batch(inp[0].data, out.data) 248 | return tmp 249 | 250 | handles = [] 251 | for name in gpts: 252 | handles.append(subset[name].register_forward_hook(add_batch(name))) 253 | 254 | for j in range(args.nsamples): 255 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 256 | for h in handles: 257 | h.remove() 258 | 259 | for name in gpts: 260 | print(i, name) 261 | print('Pruning ...') 262 | 263 | gpts[name].fasterprune(args.sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128) 264 | gpts[name].free() 265 | 266 | for j in range(args.nsamples): 267 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 268 | 269 | layers[i] = layer 270 | torch.cuda.empty_cache() 271 | 272 | inps, outs = outs, inps 273 | 274 | model.config.use_cache = use_cache 275 | torch.cuda.empty_cache() 276 | 277 | @torch.no_grad() 278 | def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0): 279 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa 280 | print('Starting ...') 281 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer) 282 | 283 | use_cache = model.config.use_cache 284 | model.config.use_cache = False 285 | layers = model.model.decoder.layers 286 | 287 | if "model.embed_tokens" in model.hf_device_map: 288 | dev = model.hf_device_map["model.embed_tokens"] 289 | 290 | dtype = next(iter(model.parameters())).dtype 291 | inps = torch.zeros( 292 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 293 | ) 294 | cache = {'i': 0, 'attention_mask': None, "position_ids": None} 295 | 296 | class Catcher(nn.Module): 297 | def __init__(self, module): 298 | super().__init__() 299 | self.module = module 300 | def forward(self, inp, **kwargs): 301 | inps[cache['i']] = inp 302 | cache['i'] += 1 303 | cache['attention_mask'] = kwargs['attention_mask'] 304 | # cache['position_ids'] = kwargs['position_ids'] 305 | raise ValueError 306 | layers[0] = Catcher(layers[0]) 307 | for batch in dataloader: 308 | try: 309 | model(batch[0].to(dev)) 310 | except ValueError: 311 | pass 312 | layers[0] = layers[0].module 313 | torch.cuda.empty_cache() 314 | 315 | outs = torch.zeros_like(inps) 316 | attention_mask = cache['attention_mask'] 317 | # position_ids = cache['position_ids'] 318 | 319 | print('Ready.') 320 | 321 | for i in range(len(layers)): 322 | layer = layers[i] 323 | if f"model.layers.{i}" in model.hf_device_map: 324 | dev = model.hf_device_map[f"model.layers.{i}"] 325 | print(f"layer {i} device {dev}") 326 | inps, outs, attention_mask = inps.to(dev), outs.to(dev), 
attention_mask.to(dev) 327 | 328 | subset = find_layers(layer) 329 | 330 | gpts = {} 331 | for name in subset: 332 | gpts[name] = AblateGPT(subset[name]) 333 | 334 | def add_batch(name): 335 | def tmp(_, inp, out): 336 | gpts[name].add_batch(inp[0].data, out.data) 337 | return tmp 338 | 339 | handles = [] 340 | for name in gpts: 341 | handles.append(subset[name].register_forward_hook(add_batch(name))) 342 | 343 | for j in range(args.nsamples): 344 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 345 | for h in handles: 346 | h.remove() 347 | 348 | for name in gpts: 349 | print(i, name) 350 | print('Pruning ...') 351 | 352 | if args.prune_method == "ablate_wanda_seq": 353 | prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m) 354 | elif args.prune_method == "ablate_mag_seq": 355 | prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m) 356 | elif "iter" in args.prune_method: 357 | prune_mask = None 358 | 359 | gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask, 360 | prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128) 361 | gpts[name].free() 362 | 363 | for j in range(args.nsamples): 364 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] 365 | 366 | layers[i] = layer 367 | torch.cuda.empty_cache() 368 | 369 | inps, outs = outs, inps 370 | 371 | model.config.use_cache = use_cache 372 | torch.cuda.empty_cache() 373 | -------------------------------------------------------------------------------- /lib/sparsegpt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | torch.backends.cuda.matmul.allow_tf32 = False 9 | torch.backends.cudnn.allow_tf32 = False 10 | 11 | ## SparseGPT: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa 12 | class SparseGPT: 13 | 14 | def __init__(self, layer): 15 | self.layer = layer 16 | self.dev = self.layer.weight.device 17 | W = layer.weight.data.clone() 18 | if isinstance(self.layer, nn.Conv2d): 19 | W = W.flatten(1) 20 | if isinstance(self.layer, transformers.Conv1D): 21 | W = W.t() 22 | self.rows = W.shape[0] 23 | self.columns = W.shape[1] 24 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 25 | self.nsamples = 0 26 | 27 | def add_batch(self, inp, out): 28 | if len(inp.shape) == 2: 29 | inp = inp.unsqueeze(0) 30 | tmp = inp.shape[0] 31 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 32 | if len(inp.shape) == 3: 33 | inp = inp.reshape((-1, inp.shape[-1])) 34 | inp = inp.t() 35 | self.H *= self.nsamples / (self.nsamples + tmp) 36 | self.nsamples += tmp 37 | inp = math.sqrt(2 / self.nsamples) * inp.float() 38 | self.H += inp.matmul(inp.t()) 39 | 40 | def fasterprune( 41 | self, sparsity, prune_n=0, prune_m=0, blocksize=128, percdamp=.01 42 | ): 43 | W = self.layer.weight.data.clone() 44 | if isinstance(self.layer, nn.Conv2d): 45 | W = W.flatten(1) 46 | if isinstance(self.layer, transformers.Conv1D): 47 | W = W.t() 48 | W = W.float() 49 | 50 | tick = time.time() 51 | 52 | H = self.H 53 | del self.H 54 | dead = torch.diag(H) == 0 55 | H[dead, dead] = 1 56 | W[:, dead] = 0 57 | 58 | Losses = torch.zeros(self.rows, device=self.dev) 59 | 60 | damp = percdamp * torch.mean(torch.diag(H)) 61 | diag = torch.arange(self.columns, device=self.dev) 62 | H[diag, diag] += damp 63 | H = torch.linalg.cholesky(H) 64 | H = 
torch.cholesky_inverse(H) 65 | H = torch.linalg.cholesky(H, upper=True) 66 | Hinv = H 67 | 68 | mask = None 69 | 70 | for i1 in range(0, self.columns, blocksize): 71 | i2 = min(i1 + blocksize, self.columns) 72 | count = i2 - i1 73 | 74 | W1 = W[:, i1:i2].clone() 75 | Q1 = torch.zeros_like(W1) 76 | Err1 = torch.zeros_like(W1) 77 | Losses1 = torch.zeros_like(W1) 78 | Hinv1 = Hinv[i1:i2, i1:i2] 79 | 80 | if prune_n == 0: 81 | if mask is not None: 82 | mask1 = mask[:, i1:i2] 83 | else: 84 | tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 85 | thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] 86 | mask1 = tmp <= thresh 87 | else: 88 | mask1 = torch.zeros_like(W1) == 1 89 | 90 | for i in range(count): 91 | w = W1[:, i] 92 | d = Hinv1[i, i] 93 | 94 | if prune_n != 0 and i % prune_m == 0: 95 | tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2 96 | mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True) 97 | 98 | q = w.clone() 99 | q[mask1[:, i]] = 0 100 | 101 | Q1[:, i] = q 102 | Losses1[:, i] = (w - q) ** 2 / d ** 2 103 | 104 | err1 = (w - q) / d 105 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 106 | Err1[:, i] = err1 107 | 108 | W[:, i1:i2] = Q1 109 | Losses += torch.sum(Losses1, 1) / 2 110 | 111 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 112 | 113 | torch.cuda.synchronize() 114 | if isinstance(self.layer, transformers.Conv1D): 115 | W = W.t() 116 | self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 117 | 118 | def free(self): 119 | self.H = None 120 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /lora_ft/README.md: -------------------------------------------------------------------------------- 1 | ## LoRA Fine-tuning of pruned LLMs 2 | Here we provide the script for the lora fine-tuning experiments in the paper. The commands for reproducing our experiments are in [script.sh](script.sh). 3 | 4 | This codebase is based on [run_clm.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling). Here we adapt this code with LoRA fine-tuning on the C4 training dataset. Some custom changes we make in the code include: 5 | - [loc 1](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L374): set up LLaMA-7B for LoRA fine-tuning; 6 | - [loc 2](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L521): set up training arguments for Trainer. 7 | - [loc 3](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L364): load the tokenizer from vicuna, which are the same as the original LLaMA tokenizer but also fix the issues of some special tokens. 8 | - [loc 4](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L319): load the c4 training dataset. 
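As a rough sketch of what the loc 1 setup above amounts to, the snippet below attaches a LoRA adapter to a pruned checkpoint with `peft`; the rank, scaling, dropout, and `target_modules` values are illustrative assumptions, not the settings used in `finetune_lm.py`.

```python
# Illustrative sketch only: hyperparameters are assumptions, not those in finetune_lm.py.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "[PATH to load sparse pruned LLaMA-7B]",  # pruned checkpoint saved with model.save_pretrained(PATH)
    torch_dtype=torch.float16,
    device_map="auto",
)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                  # LoRA rank (assumed)
    lora_alpha=16,                        # scaling factor (assumed)
    lora_dropout=0.05,                    # dropout on the LoRA branch (assumed)
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt (assumed)
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA matrices are trainable; the sparse base weights stay frozen
```

Training itself then runs through the standard Hugging Face `Trainer`, as invoked by the command below.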
9 | 10 | To train a LoRA adapter, run the command: 11 | ```sh 12 | CUDA_VISIBLE_DEVICES=0 python finetune_lm.py \ 13 | --model_name_or_path [PATH to load sparse pruned LLaMA-7B] \ 14 | --config_name "decapoda-research/llama-7b-hf" \ 15 | --dataset_name c4 \ 16 | --num_train_epochs 1 \ 17 | --block_size 1024 \ 18 | --per_device_train_batch_size 1 \ 19 | --per_device_eval_batch_size 8 \ 20 | --do_train \ 21 | --do_eval \ 22 | --max_train_samples 30000 \ 23 | --max_eval_samples 128 \ 24 | --learning_rate 1e-4 \ 25 | --overwrite_output_dir \ 26 | --output_dir [PATH to save the LoRA weights] 27 | ``` 28 | We provide a quick overview of the arguments: 29 | - `--model_name_or_path`: The path/directory where pruned LLaMA-7B are saved with `model.save_pretrained(PATH)`. 30 | - `--block_size`: context size, if you have 80GB gpu, you can set it to 2048; 31 | - `--max_train_samples`: the number of training sequences, 30000 would lead to roughly 12 hours of training on 1 GPU; 32 | - `--learning_rate`: the learning rate for LoRA fine-tuning; 33 | 34 | We provide the code to evaluate LoRA adapter on WikiText validation dataset in [evaluate_ppl.py](evaluate_ppl.py). For zero shot evaluation, additionally pass the `--eval_zero_shot` argument. -------------------------------------------------------------------------------- /lora_ft/evaluate_ppl.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from transformers import AutoModelForCausalLM 3 | from datasets import load_dataset 4 | import torch 5 | import torch.nn as nn 6 | from peft import PeftModel, PeftConfig 7 | from tqdm import tqdm 8 | import sys 9 | import json 10 | import time 11 | import os 12 | 13 | import fnmatch 14 | 15 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 16 | if type(module) in layers: 17 | return {name: module} 18 | res = {} 19 | for name1, child in module.named_children(): 20 | res.update(find_layers( 21 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 22 | )) 23 | return res 24 | 25 | def check_sparsity(model): 26 | use_cache = model.config.use_cache 27 | model.config.use_cache = False 28 | 29 | try: 30 | layers = model.model.layers 31 | except: 32 | layers = model.model.model.layers 33 | count = 0 34 | total_params = 0 35 | for i in range(len(layers)): 36 | layer = layers[i] 37 | subset = find_layers(layer) 38 | 39 | for name in subset: 40 | W = subset[name].weight.data 41 | cur_zeros = (W==0).sum().item() 42 | cur_total = W.numel() 43 | 44 | count += cur_zeros 45 | total_params += cur_total 46 | 47 | print(f"layer {i} name {name} {W.shape} sparsity {float(cur_zeros)/cur_total}") 48 | 49 | print(f"total number of params {total_params}") 50 | model.config.use_cache = use_cache 51 | return float(count)/total_params 52 | 53 | def evaluate_ppl(dataset_name, model, tokenizer, ctx_length): 54 | # max_length = model.seqlen 55 | model_seqlen = ctx_length 56 | max_length = ctx_length 57 | stride = ctx_length 58 | 59 | if dataset_name == "wikitext": 60 | test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 61 | encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt") 62 | seq_len = encodings.input_ids.size(1) 63 | elif dataset_name == "ptb": 64 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 65 | encodings = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 66 | seq_len = encodings.input_ids.size(1) 67 | elif dataset_name == "c4": 68 | valdata = load_dataset( 69 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 70 | ) 71 | encodings = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 72 | # encodings = encodings.input_ids[:, :(256 * model.seqlen)] 73 | seq_len = 256 * model_seqlen 74 | 75 | nlls = [] 76 | prev_end_loc = 0 77 | for begin_loc in tqdm(range(0, seq_len, stride)): 78 | end_loc = min(begin_loc + max_length, seq_len) 79 | trg_len = end_loc - prev_end_loc # may be different from stride on last loop 80 | input_ids = encodings.input_ids[:, begin_loc:end_loc].cuda() 81 | target_ids = input_ids.clone() 82 | target_ids[:, :-trg_len] = -100 83 | 84 | with torch.no_grad(): 85 | outputs = model(input_ids, labels=target_ids) 86 | 87 | neg_log_likelihood = outputs.loss 88 | 89 | nlls.append(neg_log_likelihood) 90 | 91 | prev_end_loc = end_loc 92 | if end_loc == seq_len: 93 | break 94 | 95 | ppl = torch.exp(torch.stack(nlls).mean()) 96 | return ppl.item() 97 | 98 | def eval_llm(model, tokenizer, task_list=["boolq","piqa","hellaswag","winogrande","arc_challenge","arc_easy","openbookqa"], num_fewshot=0): 99 | from lm_eval import tasks, evaluator 100 | def pattern_match(patterns, source_list): 101 | task_names = set() 102 | for pattern in patterns: 103 | for matching in fnmatch.filter(source_list, pattern): 104 | task_names.add(matching) 105 | return list(task_names) 106 | task_names = pattern_match(task_list, tasks.ALL_TASKS) 107 | results = evaluator.simple_evaluate( 108 | model="hf-causal-experimental", 109 | model_args="pretrained=decapoda-research/llama-7b-hf", 110 | tasks=task_names, 111 | num_fewshot=num_fewshot, 112 | batch_size=None, 113 | # device='cuda:0', 114 | device=None, 115 | no_cache=True, 116 | limit=None, 117 | description_dict={}, 118 | decontamination_ngrams_path=None, 119 | check_integrity=False, 120 | pretrained_model=model, 121 | tokenizer=tokenizer 122 | ) 123 | 124 | return results 125 | 126 | def main(args): 127 | model = 
AutoModelForCausalLM.from_pretrained( 128 | args.model, 129 | torch_dtype=torch.float16, cache_dir=args.cache_dir, low_cpu_mem_usage=True, device_map="auto") 130 | tokenizer = AutoTokenizer.from_pretrained( 131 | "lmsys/vicuna-13b-delta-v0", 132 | cache_dir=args.cache_dir, 133 | padding_side="right", 134 | use_fast=True, 135 | ) 136 | 137 | model = PeftModel.from_pretrained(model,args.lora_weights,torch_dtype=torch.float16) 138 | 139 | model.eval() 140 | 141 | ppl = evaluate_ppl("wikitext", model, tokenizer, args.ctx_length) 142 | print(f"perplexity on wikitext {ppl}") 143 | 144 | if args.eval_zero_shot: 145 | task_list_dict = {0: ["boolq", "rte","hellaswag","winogrande", "arc_easy","arc_challenge", "openbookqa"]} 146 | accelerate=False 147 | for num_shot in [0]: 148 | task_list = task_list_dict[num_shot] 149 | results = eval_llm(model, tokenizer, task_list, num_shot) 150 | 151 | if __name__ == "__main__": 152 | import argparse 153 | 154 | parser = argparse.ArgumentParser() 155 | 156 | parser.add_argument( 157 | '--model', type=str 158 | ) 159 | parser.add_argument( 160 | '--cache_dir', type=str, default="llm_weights" 161 | ) 162 | parser.add_argument( 163 | '--lora_weights', type=str, default=None 164 | ) 165 | parser.add_argument( 166 | '--ctx_length', type=int, default=2048 167 | ) 168 | parser.add_argument("--eval_zero_shot", action="store_true") 169 | 170 | args = parser.parse_args() 171 | main(args) -------------------------------------------------------------------------------- /lora_ft/script.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python finetune_lm.py \ 2 | --model_name_or_path [PATH to load sparse pruned LLaMA-7B] \ 3 | --config_name "decapoda-research/llama-7b-hf" \ 4 | --dataset_name c4 \ 5 | --num_train_epochs 1 \ 6 | --block_size 1024 \ 7 | --per_device_train_batch_size 1 \ 8 | --per_device_eval_batch_size 8 \ 9 | --do_train \ 10 | --do_eval \ 11 | --max_train_samples 30000 \ 12 | --max_eval_samples 128 \ 13 | --learning_rate 1e-4 \ 14 | --overwrite_output_dir \ 15 | --output_dir [PATH to save the LoRA weights] 16 | 17 | CUDA_VISIBLE_DEVICES=0 python evaluate_ppl.py \ 18 | --model [PATH to load sparse pruned LLaMA-7B] \ 19 | --lora_weights [PATH to load the LoRA weights] -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | from importlib.metadata import version 7 | 8 | from lib.prune import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers 9 | from lib.eval import eval_ppl, eval_zero_shot 10 | 11 | print('torch', version('torch')) 12 | print('transformers', version('transformers')) 13 | print('accelerate', version('accelerate')) 14 | print('# of gpus: ', torch.cuda.device_count()) 15 | 16 | def get_llm(model_name, cache_dir="llm_weights"): 17 | model = AutoModelForCausalLM.from_pretrained( 18 | model_name, 19 | torch_dtype=torch.float16, 20 | cache_dir=cache_dir, 21 | low_cpu_mem_usage=True, 22 | device_map="auto" 23 | ) 24 | 25 | model.seqlen = model.config.max_position_embeddings 26 | return model 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--model', type=str, help='LLaMA model') 31 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the 
calibration data.') 32 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.') 33 | parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level') 34 | parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"]) 35 | parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt", 36 | "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"]) 37 | parser.add_argument("--cache_dir", default="llm_weights", type=str ) 38 | parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix") 39 | parser.add_argument('--save', type=str, default=None, help='Path to save results.') 40 | parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.') 41 | 42 | parser.add_argument("--eval_zero_shot", action="store_true") 43 | args = parser.parse_args() 44 | 45 | # Setting seeds for reproducibility 46 | np.random.seed(args.seed) 47 | torch.random.manual_seed(args.seed) 48 | 49 | # Handling n:m sparsity 50 | prune_n, prune_m = 0, 0 51 | if args.sparsity_type != "unstructured": 52 | assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity" 53 | prune_n, prune_m = map(int, args.sparsity_type.split(":")) 54 | 55 | model_name = args.model.split("/")[-1] 56 | print(f"loading llm model {args.model}") 57 | model = get_llm(args.model, args.cache_dir) 58 | model.eval() 59 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 60 | 61 | device = torch.device("cuda:0") 62 | if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here. 
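# model.hf_device_map records which GPU each submodule was placed on by device_map="auto";
# the entry for lm_head is the device where the final logits are produced, and it is used
# below as the working device passed to the pruning and evaluation routines.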
63 | device = model.hf_device_map["lm_head"] 64 | print("use device ", device) 65 | 66 | if args.sparsity_ratio != 0: 67 | print("pruning starts") 68 | if args.prune_method == "wanda": 69 | prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 70 | elif args.prune_method == "magnitude": 71 | prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 72 | elif args.prune_method == "sparsegpt": 73 | prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 74 | elif "ablate" in args.prune_method: 75 | prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 76 | 77 | ################################################################ 78 | print("*"*30) 79 | sparsity_ratio = check_sparsity(model) 80 | print(f"sparsity sanity check {sparsity_ratio:.4f}") 81 | print("*"*30) 82 | ################################################################ 83 | ppl_test = eval_ppl(args, model, tokenizer, device) 84 | print(f"wikitext perplexity {ppl_test}") 85 | 86 | if not os.path.exists(args.save): 87 | os.makedirs(args.save) 88 | save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt") 89 | with open(save_filepath, "w") as f: 90 | print("method\tactual_sparsity\tppl_test", file=f, flush=True) 91 | print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True) 92 | 93 | if args.eval_zero_shot: 94 | accelerate=False 95 | if "30b" in args.model or "65b" in args.model or "70b" in args.model: 96 | accelerate=True 97 | 98 | task_list = ["boolq", "rte","hellaswag","winogrande", "arc_easy","arc_challenge", "openbookqa"] 99 | num_shot = 0 100 | results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate) 101 | print("********************************") 102 | print("zero_shot evaluation results") 103 | print(results) 104 | 105 | if args.save_model: 106 | model.save_pretrained(args.save_model) 107 | tokenizer.save_pretrained(args.save_model) 108 | 109 | if __name__ == '__main__': 110 | main() -------------------------------------------------------------------------------- /main_opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | from importlib.metadata import version 7 | 8 | from lib.prune_opt import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers 9 | from lib.eval import eval_ppl, eval_zero_shot 10 | 11 | print('torch', version('torch')) 12 | print('transformers', version('transformers')) 13 | print('accelerate', version('accelerate')) 14 | print('# of gpus: ', torch.cuda.device_count()) 15 | 16 | def get_llm(model_name, cache_dir="llm_weights"): 17 | model = AutoModelForCausalLM.from_pretrained( 18 | model_name, 19 | torch_dtype=torch.float16, 20 | cache_dir=cache_dir, 21 | low_cpu_mem_usage=True, 22 | device_map="auto" 23 | ) 24 | 25 | model.seqlen = model.config.max_position_embeddings 26 | return model 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--model', type=str, help='LLaMA model') 31 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') 32 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.') 33 | parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level') 34 | 
parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"]) 35 | parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt", 36 | "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"]) 37 | parser.add_argument("--cache_dir", default="llm_weights", type=str ) 38 | parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix") 39 | parser.add_argument('--save', type=str, default=None, help='Path to save results.') 40 | parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.') 41 | 42 | parser.add_argument("--eval_zero_shot", action="store_true") 43 | args = parser.parse_args() 44 | 45 | # Setting seeds for reproducibility 46 | np.random.seed(args.seed) 47 | torch.random.manual_seed(args.seed) 48 | 49 | # Handling n:m sparsity 50 | prune_n, prune_m = 0, 0 51 | if args.sparsity_type != "unstructured": 52 | assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity" 53 | prune_n, prune_m = map(int, args.sparsity_type.split(":")) 54 | 55 | model_name = args.model.split("/")[-1] 56 | print(f"loading llm model {args.model}") 57 | model = get_llm(args.model, args.cache_dir) 58 | model.eval() 59 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) 60 | 61 | device = torch.device("cuda:0") 62 | if "30b" in args.model or "66b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here. 63 | device = model.hf_device_map["lm_head"] 64 | print("use device ", device) 65 | 66 | if args.sparsity_ratio != 0: 67 | print("pruning starts") 68 | if args.prune_method == "wanda": 69 | prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 70 | elif args.prune_method == "magnitude": 71 | prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 72 | elif args.prune_method == "sparsegpt": 73 | prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 74 | elif "ablate" in args.prune_method: 75 | prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m) 76 | 77 | ################################################################ 78 | print("*"*30) 79 | sparsity_ratio = check_sparsity(model) 80 | print(f"sparsity sanity check {sparsity_ratio:.4f}") 81 | print("*"*30) 82 | ################################################################ 83 | ppl_test = eval_ppl(args, model, tokenizer, device) 84 | print(f"wikitext perplexity {ppl_test}") 85 | 86 | if not os.path.exists(args.save): 87 | os.makedirs(args.save) 88 | save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt") 89 | with open(save_filepath, "w") as f: 90 | print("method\tactual_sparsity\tppl_test", file=f, flush=True) 91 | print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True) 92 | 93 | if args.eval_zero_shot: 94 | accelerate=False 95 | if "30b" in args.model or "66b" in args.model: 96 | accelerate=True 97 | 98 | task_list = ["boolq", "rte","hellaswag","winogrande", "arc_easy","arc_challenge", "openbookqa"] 99 | num_shot = 0 100 | results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate) 101 | print("********************************") 102 | print("zero_shot evaluation results") 103 | print(results) 104 | 105 | if args.save_model: 106 | model.save_pretrained(args.save_model) 107 | 
tokenizer.save_pretrained(args.save_model) 108 | 109 | if __name__ == '__main__': 110 | main() -------------------------------------------------------------------------------- /scripts/ablate_weight_update.sh: -------------------------------------------------------------------------------- 1 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter 2 | do 3 | CUDA_VISIBLE_DEVICES=0 python main.py \ 4 | --model decapoda-research/llama-7b-hf \ 5 | --nsamples 128 \ 6 | --sparsity_ratio 0.5 \ 7 | --sparsity_type unstructured \ 8 | --prune_method ${method} \ 9 | --save out/llama_7b_ablation/unstructured/ 10 | done 11 | 12 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter 13 | do 14 | CUDA_VISIBLE_DEVICES=0 python main.py \ 15 | --model decapoda-research/llama-7b-hf \ 16 | --nsamples 128 \ 17 | --sparsity_ratio 0.5 \ 18 | --sparsity_type 4:8 \ 19 | --prune_method ${method} \ 20 | --save out/llama_7b_ablation/4:8/ 21 | done 22 | 23 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter 24 | do 25 | CUDA_VISIBLE_DEVICES=0 python main.py \ 26 | --model decapoda-research/llama-7b-hf \ 27 | --nsamples 128 \ 28 | --sparsity_ratio 0.5 \ 29 | --sparsity_type 2:4 \ 30 | --prune_method ${method} \ 31 | --save out/llama_7b_ablation/2:4/ 32 | done 33 | -------------------------------------------------------------------------------- /scripts/llama_13b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set common variables 4 | model="decapoda-research/llama-13b-hf" 5 | sparsity_ratio=0.5 6 | cuda_device=0 7 | 8 | # Set CUDA device visibility 9 | export CUDA_VISIBLE_DEVICES=$cuda_device 10 | 11 | # Define function to run python command 12 | run_python_command () { 13 | python main.py \ 14 | --model $model \ 15 | --prune_method $1 \ 16 | --sparsity_ratio $sparsity_ratio \ 17 | --sparsity_type $2 \ 18 | --save $3 19 | } 20 | 21 | # llama-13b with wanda pruning method 22 | echo "Running with wanda pruning method" 23 | run_python_command "wanda" "unstructured" "out/llama_13b/unstructured/wanda/" 24 | run_python_command "wanda" "2:4" "out/llama_13b/2-4/wanda/" 25 | run_python_command "wanda" "4:8" "out/llama_13b/4-8/wanda/" 26 | echo "Finished wanda pruning method" 27 | 28 | # llama-13b with sparsegpt pruning method 29 | echo "Running with sparsegpt pruning method" 30 | run_python_command "sparsegpt" "unstructured" "out/llama_13b/unstructured/sparsegpt/" 31 | run_python_command "sparsegpt" "2:4" "out/llama_13b/2-4/sparsegpt/" 32 | run_python_command "sparsegpt" "4:8" "out/llama_13b/4-8/sparsegpt/" 33 | echo "Finished sparsegpt pruning method" 34 | 35 | # llama-13b with magnitude pruning method 36 | echo "Running with magnitude pruning method" 37 | run_python_command "magnitude" "unstructured" "out/llama_13b/unstructured/magnitude/" 38 | run_python_command "magnitude" "2:4" "out/llama_13b/2-4/magnitude/" 39 | run_python_command "magnitude" "4:8" "out/llama_13b/4-8/magnitude/" 40 | echo "Finished magnitude pruning method" -------------------------------------------------------------------------------- /scripts/llama_30b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set common variables 4 | model="decapoda-research/llama-30b-hf" 5 | sparsity_ratio=0.5 6 | cuda_devices="0,1" 7 | 8 | # Set CUDA device visibility 9 | export CUDA_VISIBLE_DEVICES=$cuda_devices 10 | 11 | # Define function to run python command 12 | 
run_python_command () { 13 | python main.py \ 14 | --model $model \ 15 | --prune_method $1 \ 16 | --sparsity_ratio $sparsity_ratio \ 17 | --sparsity_type $2 \ 18 | --save $3 19 | } 20 | 21 | # llama-30b with wanda pruning method 22 | echo "Running with wanda pruning method" 23 | run_python_command "wanda" "unstructured" "out/llama_30b/unstructured/wanda/" 24 | run_python_command "wanda" "2:4" "out/llama_30b/2-4/wanda/" 25 | run_python_command "wanda" "4:8" "out/llama_30b/4-8/wanda/" 26 | echo "Finished wanda pruning method" 27 | 28 | # llama-30b with sparsegpt pruning method 29 | echo "Running with sparsegpt pruning method" 30 | run_python_command "sparsegpt" "unstructured" "out/llama_30b/unstructured/sparsegpt/" 31 | run_python_command "sparsegpt" "2:4" "out/llama_30b/2-4/sparsegpt/" 32 | run_python_command "sparsegpt" "4:8" "out/llama_30b/4-8/sparsegpt/" 33 | echo "Finished sparsegpt pruning method" 34 | 35 | # llama-30b with magnitude pruning method 36 | echo "Running with magnitude pruning method" 37 | run_python_command "magnitude" "unstructured" "out/llama_30b/unstructured/magnitude/" 38 | run_python_command "magnitude" "2:4" "out/llama_30b/2-4/magnitude/" 39 | run_python_command "magnitude" "4:8" "out/llama_30b/4-8/magnitude/" 40 | echo "Finished magnitude pruning method" -------------------------------------------------------------------------------- /scripts/llama_65b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set common variables 4 | model="decapoda-research/llama-65b-hf" 5 | sparsity_ratio=0.5 6 | 7 | # Define function to run python command 8 | run_python_command () { 9 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 10 | --model $model \ 11 | --prune_method $2 \ 12 | --sparsity_ratio $sparsity_ratio \ 13 | --sparsity_type $3 \ 14 | --save $4 15 | } 16 | 17 | # llama-65b with wanda pruning method 18 | echo "Running with wanda pruning method" 19 | run_python_command "0,1,2,3,4" "wanda" "unstructured" "out/llama_65b/unstructured/wanda/" 20 | run_python_command "0,1,2,3,4" "wanda" "2:4" "out/llama_65b/2-4/wanda/" 21 | run_python_command "0,1,2,3,4" "wanda" "4:8" "out/llama_65b/4-8/wanda/" 22 | echo "Finished wanda pruning method" 23 | 24 | # llama-65b with sparsegpt pruning method 25 | echo "Running with sparsegpt pruning method" 26 | run_python_command "0,1,2,3,4" "sparsegpt" "unstructured" "out/llama_65b/unstructured/sparsegpt/" 27 | run_python_command "0,1,2,3,4" "sparsegpt" "2:4" "out/llama_65b/2-4/sparsegpt/" 28 | run_python_command "0,1,2,3,4" "sparsegpt" "4:8" "out/llama_65b/4-8/sparsegpt/" 29 | echo "Finished sparsegpt pruning method" 30 | 31 | # llama-65b with magnitude pruning method 32 | echo "Running with magnitude pruning method" 33 | run_python_command "0,1,2,3" "magnitude" "unstructured" "out/llama_65b/unstructured/magnitude/" 34 | run_python_command "0,1,2,3" "magnitude" "2:4" "out/llama_65b/2-4/magnitude/" 35 | run_python_command "0,1,2,3" "magnitude" "4:8" "out/llama_65b/4-8/magnitude/" 36 | echo "Finished magnitude pruning method" -------------------------------------------------------------------------------- /scripts/llama_7b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set common variables 4 | model="decapoda-research/llama-7b-hf" 5 | sparsity_ratio=0.5 6 | cuda_device=0 7 | 8 | # Set CUDA device visibility 9 | export CUDA_VISIBLE_DEVICES=$cuda_device 10 | 11 | # Define function to run python command 12 | run_python_command () { 13 | 
python main.py \ 14 | --model $model \ 15 | --prune_method $1 \ 16 | --sparsity_ratio $sparsity_ratio \ 17 | --sparsity_type $2 \ 18 | --save $3 19 | } 20 | 21 | # llama-7b with wanda pruning method 22 | echo "Running with wanda pruning method" 23 | run_python_command "wanda" "unstructured" "out/llama_7b/unstructured/wanda/" 24 | run_python_command "wanda" "2:4" "out/llama_7b/2-4/wanda/" 25 | run_python_command "wanda" "4:8" "out/llama_7b/4-8/wanda/" 26 | echo "Finished wanda pruning method" 27 | 28 | # llama-7b with sparsegpt pruning method 29 | echo "Running with sparsegpt pruning method" 30 | run_python_command "sparsegpt" "unstructured" "out/llama_7b/unstructured/sparsegpt/" 31 | run_python_command "sparsegpt" "2:4" "out/llama_7b/2-4/sparsegpt/" 32 | run_python_command "sparsegpt" "4:8" "out/llama_7b/4-8/sparsegpt/" 33 | echo "Finished sparsegpt pruning method" 34 | 35 | # llama-7b with magnitude pruning method 36 | echo "Running with magnitude pruning method" 37 | run_python_command "magnitude" "unstructured" "out/llama_7b/unstructured/magnitude/" 38 | run_python_command "magnitude" "2:4" "out/llama_7b/2-4/magnitude/" 39 | run_python_command "magnitude" "4:8" "out/llama_7b/4-8/magnitude/" 40 | echo "Finished magnitude pruning method" --------------------------------------------------------------------------------