├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── dense_ft
│   ├── README.md
│   ├── sparse_trainer.py
│   └── trainer.py
├── image_classifiers
│   ├── README.md
│   ├── datasets.py
│   ├── download_weights.sh
│   ├── engine.py
│   ├── layerwrapper.py
│   ├── main.py
│   ├── models
│   │   ├── convnext.py
│   │   ├── deit.py
│   │   ├── mlp_mixer.py
│   │   ├── swin_transformer.py
│   │   └── vision_transformer.py
│   ├── optim_factory.py
│   ├── prune_utils.py
│   └── utils.py
├── lib
│   ├── ablate.py
│   ├── data.py
│   ├── eval.py
│   ├── layerwrapper.py
│   ├── prune.py
│   ├── prune_opt.py
│   └── sparsegpt.py
├── lora_ft
│   ├── README.md
│   ├── evaluate_ppl.py
│   ├── finetune_lm.py
│   └── script.sh
├── main.py
├── main_opt.py
└── scripts
    ├── ablate_weight_update.sh
    ├── llama_13b.sh
    ├── llama_30b.sh
    ├── llama_65b.sh
    └── llama_7b.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | /*__pycache__/
2 | .env
3 | *.pyc
4 | *.DS_Store
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | Step 1: Create a new conda environment:
3 | ```
4 | conda create -n prune_llm python=3.9
5 | conda activate prune_llm
6 | ```
7 | Step 2: Install relevant packages
8 | ```
9 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
10 | pip install transformers==4.28.0 datasets==2.11.0 wandb sentencepiece
11 | pip install accelerate==0.18.0
12 | ```
13 | There are known [issues](https://github.com/huggingface/transformers/issues/22222) with the transformers library loading the LLaMA tokenizer correctly. Please follow the suggestions in that thread to resolve them.
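14 |
15 | As a quick optional sanity check (this sketch assumes the environment above and access to the Hugging Face hub), you can verify that the tokenizer loads and round-trips text:
16 | ```python
17 | from transformers import LlamaTokenizer
18 |
19 | tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
20 | print(tokenizer.decode(tokenizer("Hello, world!")["input_ids"]))
21 | ```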
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 CMU Locus Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pruning LLMs by Weights and Activations
2 | Official PyTorch implementation of **Wanda** (Pruning by **W**eights **and a**ctivations), as presented in our paper:
3 |
4 | **A Simple and Effective Pruning Approach for Large Language Models**
5 | *Mingjie Sun\*, Zhuang Liu\*, Anna Bair, J. Zico Kolter* (* indicates equal contribution)
6 | Carnegie Mellon University, Meta AI Research and Bosch Center for AI
7 | [Paper](https://arxiv.org/abs/2306.11695) - [Project page](https://eric-mingjie.github.io/wanda/home.html)
8 |
9 | ```bibtex
10 | @article{sun2023wanda,
11 | title={A Simple and Effective Pruning Approach for Large Language Models},
12 | author={Sun, Mingjie and Liu, Zhuang and Bair, Anna and Kolter, J. Zico},
13 | year={2023},
14 | journal={arXiv preprint arXiv:2306.11695}
15 | }
16 | ```
17 |
18 | ---
19 |
20 |
22 |
23 |
24 | Compared to magnitude pruning, which removes weights solely based on their magnitudes, our pruning approach **Wanda** removes weights on a *per-output* basis, scoring each weight by the product of its magnitude and the norm of the corresponding input activation.
25 |
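26 | As a minimal sketch of this criterion for a single linear layer (illustrative only; the actual implementation lives in [lib/prune.py](lib/prune.py)), where `act_norm` is assumed to hold the per-input-channel L2 norms of activations collected on calibration data:
27 | ```python
28 | import torch
29 |
30 | def wanda_prune_linear(weight: torch.Tensor, act_norm: torch.Tensor, sparsity: float):
31 |     """weight: (out_features, in_features); act_norm: (in_features,)."""
32 |     metric = weight.abs() * act_norm                      # score S = |W| * ||X||, broadcast over output rows
33 |     num_prune = int(weight.shape[1] * sparsity)           # weights to remove within each output row
34 |     idx = torch.sort(metric, dim=1).indices[:, :num_prune]
35 |     mask = torch.zeros_like(weight, dtype=torch.bool).scatter_(1, idx, True)
36 |     return weight.masked_fill(mask, 0.0), mask
37 | ```
38 |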
26 | ## Update
27 | - [x] (9.22.2023) Add [support](https://github.com/locuslab/wanda#pruning-llama-2) for LLaMA-2.
28 | - [x] (9.22.2023) Add [code](https://github.com/locuslab/wanda#ablation-on-obs-weight-update) to reproduce the ablation study on OBS weight update in the paper.
29 | - [x] (10.6.2023) Add new [support](https://github.com/locuslab/wanda#ablation-on-obs-weight-update) for the weight update analysis in the ablation study. Feel free to try it out!
30 | - [x] (10.6.2023) Add [support](https://github.com/locuslab/wanda#zero-shot-evaluation) for zero-shot evaluation.
31 | - [x] (10.20.2023) Add code for pruning OPT models.
32 | - [x] (10.23.2023) Add code for [LoRA fine-tuning](lora_ft).
33 |
34 | ## Setup
35 | Installation instructions can be found in [INSTALL.md](INSTALL.md).
36 |
37 | ## Usage
38 | The [scripts](scripts) directory contains all the bash commands to replicate the main results (Table 2) in our paper.
39 |
40 | Below is an example command for pruning LLaMA-7B with Wanda, to achieve unstructured 50% sparsity.
41 | ```sh
42 | python main.py \
43 | --model decapoda-research/llama-7b-hf \
44 | --prune_method wanda \
45 | --sparsity_ratio 0.5 \
46 | --sparsity_type unstructured \
47 | --save out/llama_7b/unstructured/wanda/
48 | ```
49 | We provide a quick overview of the arguments:
50 | - `--model`: The identifier for the LLaMA model on the Hugging Face model hub.
51 | - `--cache_dir`: Directory for loading or storing LLM weights. The default is `llm_weights`.
52 | - `--prune_method`: We have implemented three pruning methods, namely [`magnitude`, `wanda`, `sparsegpt`].
53 | - `--sparsity_ratio`: Denotes the percentage of weights to be pruned.
54 | - `--sparsity_type`: Specifies the type of sparsity [`unstructured`, `2:4`, `4:8`].
55 | - `--use_variant`: Whether to use the Wanda variant, default is `False`.
56 | - `--save`: Specifies the directory where the result will be stored.
57 |
58 | For structured N:M sparsity, set the argument `--sparsity_type` to "2:4" or "4:8". An illustrative command is provided below:
59 | ```sh
60 | python main.py \
61 | --model decapoda-research/llama-7b-hf \
62 | --prune_method wanda \
63 | --sparsity_ratio 0.5 \
64 | --sparsity_type 2:4 \
65 | --save out/llama_7b/2-4/wanda/
66 | ```
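67 |
68 | For N:M sparsity, the same per-weight score is used, but the comparison group is every M consecutive input weights: for 2:4, the 2 lowest-scoring weights in each group of 4 are pruned. A hedged sketch (illustrative only; see [lib/prune.py](lib/prune.py) for the actual implementation):
69 | ```python
70 | import torch
71 |
72 | def nm_prune_mask(metric: torch.Tensor, n: int = 2, m: int = 4) -> torch.Tensor:
73 |     """metric: (rows, cols) with cols divisible by m; True marks the n weights pruned in every group of m."""
74 |     rows, cols = metric.shape
75 |     groups = metric.reshape(rows, cols // m, m)
76 |     idx = torch.sort(groups, dim=-1).indices[..., :n]     # n lowest-scoring weights per group
77 |     mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, idx, True)
78 |     return mask.reshape(rows, cols)
79 | ```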
67 |
68 | ### Pruning LLaMA-2
69 | For [LLaMA-2](https://ai.meta.com/llama/) models, replace `--model` with `meta-llama/Llama-2-7b-hf` (take `7b` as an example):
70 | ```sh
71 | python main.py \
72 | --model meta-llama/Llama-2-7b-hf \
73 | --prune_method wanda \
74 | --sparsity_ratio 0.5 \
75 | --sparsity_type unstructured \
76 | --save out/llama2_7b/unstructured/wanda/
77 | ```
78 | LLaMA-2 results, reported as WikiText perplexity (LLaMA-2-34b was not released as of 9.22.2023):
79 | | sparsity | method | llama2-7b | llama2-13b | llama2-70b |
80 | |------|------------------|----------|------------|------------|
81 | |-| dense | 5.12 | 4.57 | 3.12 |
82 | |unstructured 50%| magnitude | 14.89 | 6.37 | 4.98 |
83 | |unstructured 50%| sparsegpt | 6.51 | 5.63 | **3.98** |
84 | |unstructured 50%| wanda | **6.42** | **5.56** | **3.98** |
85 | |4:8| magnitude | 16.48 | 6.76 | 5.58 |
86 | |4:8| sparsegpt | 8.12 | 6.60 | 4.59 |
87 | |4:8| wanda | **7.97** | **6.55** | **4.47** |
88 | |2:4| magnitude | 54.59 | 8.33 | 6.33 |
89 | |2:4| sparsegpt | **10.17** | 8.32 | 5.40 |
90 | |2:4| wanda | 11.02 | **8.27** | **5.16** |
91 |
92 | ### Ablation on OBS weight update
93 | To reproduce the analysis on weight update, we provide our implementation for this ablation. All commands can be found in [this script](scripts/ablate_weight_update.sh).
94 | ```sh
95 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
96 | do
97 | CUDA_VISIBLE_DEVICES=0 python main.py \
98 | --model decapoda-research/llama-7b-hf \
99 | --sparsity_ratio 0.5 \
100 | --sparsity_type unstructured \
101 | --prune_method ${method} \
102 | --save out/llama_7b_ablation/unstructured/
103 | done
104 | ```
105 | Here `ablate_{mag/wanda}_{seq/iter}` means that we use magnitude pruning or Wanda to obtain the pruned mask at each layer, and then apply the weight update procedure in either a sequential or an iterative manner, every 128 input channels. For details, please see Section 5 of our [paper](https://arxiv.org/abs/2306.11695).
106 |
107 | ### Zero-Shot Evaluation
108 | For evaluating zero-shot tasks, we modify the [EleutherAI LM Harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/master) framework so that it can evaluate pruned LLM models. We provide the modified repo at [this link](https://drive.google.com/file/d/1zugbLyGZKsH1L19L9biHLfaGGFnEc7XL/view?usp=sharing). Make sure to download, extract, and install this custom `lm_eval` package from source.
109 |
110 | For reproducibility, we used [commit `df3da98`](https://github.com/EleutherAI/lm-evaluation-harness/tree/df3da98c5405deafd519c2ddca52bb7c3fe36bef) on the main branch. All tasks were evaluated with task version 0, except for BoolQ, which uses task version 1.
111 |
112 | At a high level, our modification adds two arguments, `pretrained_model` and `tokenizer`, to this [function](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/evaluator.py#L17). We can then call the `simple_evaluate` API from our [codebase](https://github.com/locuslab/wanda/blob/main/lib/eval.py#L148) to evaluate sparse pruned LLMs. To evaluate zero-shot tasks in addition to WikiText perplexity, pass the `--eval_zero_shot` argument.
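113 |
114 | A hedged sketch of how the patched `simple_evaluate` can be invoked (only `pretrained_model` and `tokenizer` are added by our modification; the remaining arguments follow the upstream harness API at that commit, and `model` / `tokenizer` below are the pruned model and its tokenizer already in memory):
115 | ```python
116 | from lm_eval import evaluator
117 |
118 | results = evaluator.simple_evaluate(
119 |     model="hf-causal-experimental",
120 |     model_args="pretrained=decapoda-research/llama-7b-hf",
121 |     tasks=["boolq", "rte", "hellaswag", "winogrande", "arc_easy", "arc_challenge", "openbookqa"],
122 |     num_fewshot=0,
123 |     no_cache=True,
124 |     pretrained_model=model,    # pass the pruned model object directly
125 |     tokenizer=tokenizer,
126 | )
127 | print(results["results"])
128 | ```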
113 |
114 | ### Speedup Evaluation
115 | The pruning speed for each method is measured as the cumulative time spent on pruning each layer, excluding the forward passes.
116 |
117 | For inference speedup with structured sparsity, we refer the reader to this [tutorial](https://pytorch.org/tutorials/prototype/semi_structured_sparse.html), as semi-structured (2:4) sparsity is supported in `PyTorch >= 2.1`. You can switch between the CUTLASS and cuSPARSELt kernels [here](https://github.com/pytorch/pytorch/blob/v2.1.0/torch/sparse/semi_structured.py#L55).
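118 |
119 | A minimal sketch adapted from that tutorial (assumes `PyTorch >= 2.1` and a GPU with 2:4 sparse tensor cores; the fixed mask below stands in for a Wanda-generated 2:4 mask):
120 | ```python
121 | import torch
122 | from torch.sparse import to_sparse_semi_structured
123 |
124 | linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
125 | mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()   # enforce a 2:4 pattern
126 | linear.weight = torch.nn.Parameter(mask * linear.weight)
127 | x = torch.rand(3072, 10240).half().cuda()
128 |
129 | with torch.inference_mode():
130 |     dense_out = linear(x)
131 |     # convert to the semi-structured representation backed by the CUTLASS / cuSPARSELt kernels
132 |     linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
133 |     sparse_out = linear(x)
134 | ```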
118 |
119 | Lastly, for pruning image classifiers, see the [image_classifiers](image_classifiers) directory for details.
120 |
121 | ## Acknowledgement
122 | This repository is built upon the [SparseGPT](https://github.com/IST-DASLab/sparsegpt) repository.
123 |
124 | ## License
125 | This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
126 |
127 | ## Questions
128 | Feel free to discuss papers/code with us through issues/emails!
129 |
130 | mingjies at cs.cmu.edu
131 | liuzhuangthu at gmail.com
--------------------------------------------------------------------------------
/dense_ft/README.md:
--------------------------------------------------------------------------------
1 | ## Dense Fine-tuning
2 |
3 | We provide a sparse `Trainer` in [sparse_trainer.py](sparse_trainer.py), which can be used as a drop-in replacement for the Hugging Face `Trainer`. In the `training_step` function, once the gradients have been computed, `mask_grad()` zeros out the gradients corresponding to the pruned weights, so that pruned weights stay at zero during fine-tuning.
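4 |
5 | A hedged usage sketch (the `model`, `train_dataset`, `data_collator`, and `tokenizer` objects, as well as the `TrainingArguments` values, are placeholders rather than the repo's defaults):
6 | ```python
7 | from transformers import TrainingArguments
8 | from sparse_trainer import SparseTrainer
9 |
10 | trainer = SparseTrainer(
11 |     model=model,                 # an already-pruned LlamaForCausalLM
12 |     args=TrainingArguments(output_dir="dense_ft_out", per_device_train_batch_size=1, num_train_epochs=1),
13 |     train_dataset=train_dataset,
14 |     data_collator=data_collator,
15 |     tokenizer=tokenizer,
16 | )
17 | trainer.train()                  # gradients of pruned weights are zeroed after every backward pass
18 | ```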
--------------------------------------------------------------------------------
/dense_ft/sparse_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers.trainer import Trainer
4 | from transformers.modeling_utils import unwrap_model
5 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
6 | from transformers.utils import is_apex_available, is_peft_available
7 |
8 | if is_apex_available():
9 |     from apex import amp
10 | if is_peft_available():
11 |     from peft import PeftModel
12 |
5 | def find_layers(module, layers=[nn.Linear], name=''):
6 | if type(module) in layers:
7 | return {name: module}
8 | res = {}
9 | for name1, child in module.named_children():
10 | res.update(find_layers(
11 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
12 | ))
13 | return res
14 |
15 | def fix_grad_nan_inf(model):
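16 |     """Zero out any gradient tensor that contains NaN or Inf entries (can occur with fp16 fine-tuning)."""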
16 | layers = model.model.layers
17 | count = 0
18 | total_params = 0
19 | for m in model.parameters():
20 | if m.requires_grad:
21 | if torch.isnan(m.grad).any() or torch.isinf(m.grad).any():
22 | m.grad.zero_()
23 |
24 |
25 | def mask_grad(model):
26 | layers = model.model.layers
27 | count = 0
28 | total_params = 0
29 | for i in range(len(layers)):
30 | layer = layers[i]
31 | subset = find_layers(layer)
32 |
33 | sub_count = 0
34 | sub_params = 0
35 | for name in subset:
36 | W = subset[name].weight.data
37 |             mask = (W == 0)                     # positions of already-pruned weights
38 |             subset[name].weight.grad[mask] = 0  # keep pruned weights frozen at zero
39 |
40 | def check_sparsity(model):
41 | use_cache = model.config.use_cache
42 | model.config.use_cache = False
43 |
44 | layers = model.model.layers
45 | count = 0
46 | total_params = 0
47 | for i in range(len(layers)):
48 | layer = layers[i]
49 | subset = find_layers(layer)
50 |
51 | sub_count = 0
52 | sub_params = 0
53 | for name in subset:
54 | W = subset[name].weight.data
55 | count += (W==0).sum().item()
56 | total_params += W.numel()
57 |
58 | sub_count += (W==0).sum().item()
59 | sub_params += W.numel()
60 |
61 | # print(f"layer {i} sparsity {float(sub_count)/sub_params:.4f}")
62 |
63 | model.config.use_cache = use_cache
64 | return float(count)/total_params
65 |
66 | class SparseTrainer(Trainer):
67 | def __init__(self, model= None, args= None, data_collator= None, train_dataset= None, eval_dataset= None,
68 | tokenizer= None, model_init= None, compute_metrics= None, callbacks= None, optimizers= (None, None),
69 | preprocess_logits_for_metrics= None
70 | ):
71 | super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks,
72 | optimizers, preprocess_logits_for_metrics)
73 | self.counter = 0
74 |
75 | def training_step(self, model, inputs):
76 | """
77 | Perform a training step on a batch of inputs.
78 |
79 | Subclass and override to inject custom behavior.
80 |
81 | Args:
82 | model (`nn.Module`):
83 | The model to train.
84 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
85 | The inputs and targets of the model.
86 |
87 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
88 | argument `labels`. Check your model's documentation for all accepted arguments.
89 |
90 | Return:
91 | `torch.Tensor`: The tensor with training loss on this batch.
92 |
93 | access optimizer through: self.optimizer.optimizer.param_groups[0]
94 | """
95 | self.counter += 1
96 | model.train()
97 | inputs = self._prepare_inputs(inputs)
98 |
99 | with self.compute_loss_context_manager():
100 | loss = self.compute_loss(model, inputs)
101 |
102 | if self.args.n_gpu > 1:
103 | loss = loss.mean() # mean() to average on multi-gpu parallel training
104 |
105 | if self.do_grad_scaling: ## False
106 | self.scaler.scale(loss).backward()
107 | elif self.use_apex: ## False
108 | with amp.scale_loss(loss, self.optimizer) as scaled_loss:
109 | scaled_loss.backward()
110 | else:
111 | self.accelerator.backward(loss)
112 | # pass
113 |
114 | mask_grad(model) ### mask the gradients
115 |
116 | return loss.detach() / self.args.gradient_accumulation_steps
117 |
118 | def compute_loss(self, model, inputs, return_outputs=False):
119 | """
120 | How the loss is computed by Trainer. By default, all models return the loss in the first element.
121 |
122 | Subclass and override for custom behavior.
123 |
124 | ## model type: transformers.models.llama.modeling_llama.LlamaForCausalLM
125 | ## outputs[0]: a single scalar
126 | ## outputs[1]: shape (bs, 2048, 32000)
127 |
128 | ## inputs["input_ids"] shape: (bs, 2048)
129 |         ## inputs["attention_mask"] shape: (bs, 2048)
130 | """
131 | if self.label_smoother is not None and "labels" in inputs:
132 | labels = inputs.pop("labels")
133 | else:
134 | labels = None
135 |
136 | outputs = model(**inputs)
137 | # Save past state if it exists
138 | # TODO: this needs to be fixed and made cleaner later.
139 | if self.args.past_index >= 0:
140 | self._past = outputs[self.args.past_index]
141 |
142 | if labels is not None:
143 | if is_peft_available() and isinstance(model, PeftModel):
144 | model_name = unwrap_model(model.base_model)._get_name()
145 | else:
146 | model_name = unwrap_model(model)._get_name()
147 | if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
148 | loss = self.label_smoother(outputs, labels, shift_labels=True)
149 | else:
150 | loss = self.label_smoother(outputs, labels)
151 | else:
152 | if isinstance(outputs, dict) and "loss" not in outputs:
153 | raise ValueError(
154 | "The model did not return a loss from the inputs, only the following keys: "
155 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
156 | )
157 | # We don't use .loss here since the model may return tuples instead of ModelOutput.
158 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
159 |
160 | return (loss, outputs) if return_outputs else loss
--------------------------------------------------------------------------------
/image_classifiers/README.md:
--------------------------------------------------------------------------------
1 | # Pruning Image Classifiers
2 | Here we provide the code for pruning ConvNeXt and ViT. This part is built on the [dropout](https://github.com/facebookresearch/dropout) repository.
3 |
4 | ## Environment
5 | We additionally install `timm` for loading pretrained image classifiers.
6 | ```
7 | pip install timm==0.4.12
8 | ```
9 |
10 | ## Download Weights
11 | Run the script [download_weights.sh](download_weights.sh) to download pretrained weights for ConvNeXt-B, DeiT-B and ViT-L, which we used in the paper.
12 |
13 | ## Usage
14 | Here is the command for pruning ConvNeXt/ViT models:
15 | ```
16 | python main.py --model [ARCH] \
17 | --data_path [PATH to ImageNet] \
18 | --resume [PATH to the pretrained weights] \
19 | --prune_metric wanda \
20 | --prune_granularity row \
21 | --sparsity 0.5
22 | ```
23 | where:
24 | - `--model`: network architecture, choices [`convnext_base`, `deit_base_patch16_224`, `vit_large_patch16_224`].
25 | - `--resume`: model path to downloaded pretrained weights.
26 | - `--prune_metric`: [`magnitude`, `wanda`].
27 | - `--prune_granularity`: the comparison group for pruning scores, choices [`layer`, `row`]; see the sketch below.
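28 |
29 | As a rough sketch (illustrative only, not the exact code in `prune_utils.py`), the two granularities differ in the group within which weight scores are compared:
30 | ```python
31 | import torch
32 |
33 | def mask_lowest(metric: torch.Tensor, sparsity: float, granularity: str) -> torch.Tensor:
34 |     """metric: (rows, cols) score matrix (e.g. |W| or |W| * activation norm); True marks weights to prune."""
35 |     if granularity == "layer":                      # compare scores across the whole layer
36 |         k = int(metric.numel() * sparsity)
37 |         thresh = torch.kthvalue(metric.flatten(), k).values
38 |         return metric <= thresh
39 |     elif granularity == "row":                      # compare scores within each output row (Wanda's default)
40 |         k = int(metric.shape[1] * sparsity)
41 |         idx = torch.sort(metric, dim=1).indices[:, :k]
42 |         return torch.zeros_like(metric, dtype=torch.bool).scatter_(1, idx, True)
43 |     raise ValueError(f"unknown granularity: {granularity}")
44 | ```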
--------------------------------------------------------------------------------
/image_classifiers/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 | from torchvision import datasets, transforms
10 |
11 | from timm.data.constants import \
12 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
13 | from timm.data import create_transform
14 |
15 | def build_dataset(is_train, args):
16 | transform = build_transform(is_train, args)
17 |
18 | print("Transform = ")
19 | if isinstance(transform, tuple):
20 | for trans in transform:
21 | print(" - - - - - - - - - - ")
22 | for t in trans.transforms:
23 | print(t)
24 | else:
25 | for t in transform.transforms:
26 | print(t)
27 | print("---------------------------")
28 |
29 | if args.data_set == 'CIFAR':
30 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
31 | nb_classes = 100
32 | elif args.data_set == 'IMNET':
33 | print("reading from datapath", args.data_path)
34 | root = os.path.join(args.data_path, 'train' if is_train else 'val_dirs')
35 | dataset = datasets.ImageFolder(root, transform=transform)
36 | nb_classes = 1000
37 | elif args.data_set == "image_folder":
38 | root = args.data_path if is_train else args.eval_data_path
39 | dataset = datasets.ImageFolder(root, transform=transform)
40 | nb_classes = args.nb_classes
41 | assert len(dataset.class_to_idx) == nb_classes
42 | else:
43 | raise NotImplementedError()
44 | print("Number of the class = %d" % nb_classes)
45 |
46 | return dataset, nb_classes
47 |
48 |
49 | def build_transform(is_train, args):
50 | resize_im = args.input_size > 32
51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
54 |
55 | if is_train:
56 | # this should always dispatch to transforms_imagenet_train
57 | transform = create_transform(
58 | input_size=args.input_size,
59 | is_training=True,
60 | color_jitter=args.color_jitter,
61 | auto_augment=args.aa,
62 | interpolation=args.train_interpolation,
63 | re_prob=args.reprob,
64 | re_mode=args.remode,
65 | re_count=args.recount,
66 | mean=mean,
67 | std=std,
68 | )
69 | if not resize_im:
70 | transform.transforms[0] = transforms.RandomCrop(
71 | args.input_size, padding=4)
72 | return transform
73 |
74 | t = []
75 | if resize_im:
76 | # warping (no cropping) when evaluated at 384 or larger
77 | if args.input_size >= 384:
78 | t.append(
79 | transforms.Resize((args.input_size, args.input_size),
80 | interpolation=transforms.InterpolationMode.BICUBIC),
81 | )
82 | print(f"Warping {args.input_size} size input images...")
83 | else:
84 | if args.crop_pct is None:
85 | args.crop_pct = 224 / 256
86 | size = int(args.input_size / args.crop_pct)
87 | t.append(
88 | # to maintain same ratio w.r.t. 224 images
89 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
90 | )
91 | t.append(transforms.CenterCrop(args.input_size))
92 |
93 | t.append(transforms.ToTensor())
94 | t.append(transforms.Normalize(mean, std))
95 | return transforms.Compose(t)
96 |
--------------------------------------------------------------------------------
/image_classifiers/download_weights.sh:
--------------------------------------------------------------------------------
1 | mkdir -p model_weights/vit/
2 | mkdir -p model_weights/convnext
3 | mkdir -p model_weights/deit/
4 |
5 | cd model_weights/vit
6 | wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth
7 |
8 | cd ../convnext/
9 | wget https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
10 |
11 | cd ../deit
12 | wget https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth
13 |
14 | cd ../..
--------------------------------------------------------------------------------
/image_classifiers/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import math
9 | from typing import Iterable, Optional
10 | import torch
11 | from timm.data import Mixup
12 | from timm.utils import accuracy, ModelEma
13 |
14 | import utils
15 | import torch.nn as nn
16 |
17 | from prune_utils import check_sparsity
18 |
19 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
20 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
23 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, schedules={},
24 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
25 | model.train(True)
26 | metric_logger = utils.MetricLogger(delimiter=" ")
27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
28 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
29 | header = 'Epoch: [{}]'.format(epoch)
30 | print_freq = 10
31 |
32 | optimizer.zero_grad()
33 |
34 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
35 | # if data_iter_step > 10:
36 | # break
37 |
38 | step = data_iter_step // update_freq
39 | if step >= num_training_steps_per_epoch:
40 | continue
41 | it = start_steps + step # global training iteration
42 | # Update LR & WD for the first acc
43 | if data_iter_step % update_freq == 0:
44 | if lr_schedule_values is not None or wd_schedule_values is not None:
45 | for i, param_group in enumerate(optimizer.param_groups):
46 | if lr_schedule_values is not None:
47 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
48 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
49 | param_group["weight_decay"] = wd_schedule_values[it]
50 | if 'dp' in schedules:
51 | model.module.update_drop_path(schedules['dp'][it])
52 | if 'do' in schedules:
53 | model.module.update_dropout(schedules['do'][it])
54 |
55 | samples = samples.to(device, non_blocking=True)
56 | targets = targets.to(device, non_blocking=True)
57 |
58 | if mixup_fn is not None:
59 | samples, targets = mixup_fn(samples, targets)
60 |
61 | if use_amp:
62 | with torch.cuda.amp.autocast():
63 | output = model(samples)
64 | loss = criterion(output, targets)
65 | else: # full precision
66 | output = model(samples)
67 | loss = criterion(output, targets)
68 |
69 | loss_value = loss.item()
70 |
71 | if not math.isfinite(loss_value): # this could trigger if using AMP
72 | print("Loss is {}, stopping training".format(loss_value))
73 | assert math.isfinite(loss_value)
74 |
75 | if use_amp:
76 | # this attribute is added by timm on one optimizer (adahessian)
77 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
78 | loss /= update_freq
79 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
80 | parameters=model.parameters(), create_graph=is_second_order,
81 | update_grad=(data_iter_step + 1) % update_freq == 0)
82 |
83 | if (data_iter_step + 1) % update_freq == 0:
84 | optimizer.zero_grad()
85 | if model_ema is not None:
86 | model_ema.update(model)
87 | else: # full precision
88 | loss /= update_freq
89 | loss.backward()
90 |
91 | ###############################################################################
92 |             ## mask the gradients of pruned (zero) weights so they stay zero during fine-tuning
93 | for k, m in enumerate(model.modules()):
94 | if isinstance(m, nn.Linear):
95 | weight_copy = m.weight.data.abs().clone()
96 | mask = weight_copy.gt(0).float().cuda()
97 | m.weight.grad.data.mul_(mask)
98 | ################################################################################
99 |
100 | if (data_iter_step + 1) % update_freq == 0:
101 | optimizer.step()
102 | optimizer.zero_grad()
103 | if model_ema is not None:
104 | model_ema.update(model)
105 |
106 | torch.cuda.synchronize()
107 |
108 | if mixup_fn is None:
109 | class_acc = (output.max(-1)[-1] == targets).float().mean()
110 | else:
111 | class_acc = None
112 | metric_logger.update(loss=loss_value)
113 | metric_logger.update(class_acc=class_acc)
114 | min_lr = 10.
115 | max_lr = 0.
116 | for group in optimizer.param_groups:
117 | min_lr = min(min_lr, group["lr"])
118 | max_lr = max(max_lr, group["lr"])
119 |
120 | metric_logger.update(lr=max_lr)
121 | metric_logger.update(min_lr=min_lr)
122 | weight_decay_value = None
123 | for group in optimizer.param_groups:
124 | if group["weight_decay"] > 0:
125 | weight_decay_value = group["weight_decay"]
126 | metric_logger.update(weight_decay=weight_decay_value)
127 |
128 | if 'dp' in schedules:
129 | metric_logger.update(drop_path=model.module.drop_path)
130 |
131 | if 'do' in schedules:
132 | metric_logger.update(dropout=model.module.drop_rate)
133 |
134 | if use_amp:
135 | metric_logger.update(grad_norm=grad_norm)
136 |
137 | if log_writer is not None:
138 | log_writer.update(loss=loss_value, head="loss")
139 | log_writer.update(class_acc=class_acc, head="loss")
140 | log_writer.update(lr=max_lr, head="opt")
141 | log_writer.update(min_lr=min_lr, head="opt")
142 | log_writer.update(weight_decay=weight_decay_value, head="opt")
143 | if use_amp:
144 | log_writer.update(grad_norm=grad_norm, head="opt")
145 | log_writer.set_step()
146 |
147 | if wandb_logger:
148 | wandb_logger._wandb.log({
149 | 'Rank-0 Batch Wise/train_loss': loss_value,
150 | 'Rank-0 Batch Wise/train_max_lr': max_lr,
151 | 'Rank-0 Batch Wise/train_min_lr': min_lr
152 | }, commit=False)
153 | if class_acc:
154 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
155 | if use_amp:
156 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
157 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
158 |
159 | # gather the stats from all processes
160 | metric_logger.synchronize_between_processes()
161 | print("Averaged stats:", metric_logger)
162 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
163 |
164 | @torch.no_grad()
165 | def evaluate(data_loader, model, device, use_amp=False):
166 | criterion = torch.nn.CrossEntropyLoss()
167 |
168 | metric_logger = utils.MetricLogger(delimiter=" ")
169 | header = 'Test:'
170 |
171 | # switch to evaluation mode
172 | model.eval()
173 | for batch in metric_logger.log_every(data_loader, 10, header):
174 | images = batch[0]
175 | target = batch[-1]
176 |
177 | images = images.to(device, non_blocking=True)
178 | target = target.to(device, non_blocking=True)
179 |
180 | # compute output
181 | if use_amp:
182 | with torch.cuda.amp.autocast():
183 | output = model(images)
184 | loss = criterion(output, target)
185 | else:
186 | output = model(images)
187 | loss = criterion(output, target)
188 |
189 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
190 |
191 | batch_size = images.shape[0]
192 | metric_logger.update(loss=loss.item())
193 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
194 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
195 | # gather the stats from all processes
196 | metric_logger.synchronize_between_processes()
197 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
198 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
199 |
200 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/image_classifiers/layerwrapper.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | torch.backends.cuda.matmul.allow_tf32 = False
8 | torch.backends.cudnn.allow_tf32 = False
9 | DEBUG = False  # when True, free() also drops the cached inp1/out1 activations
10 |
11 | class WrappedLayer:
12 |
13 | def __init__(self, layer, layer_id=0, layer_name="none", p_norm=2):
14 | self.layer = layer
15 | self.dev = self.layer.weight.device
16 | self.rows = layer.weight.data.shape[0]
17 | self.columns = layer.weight.data.shape[1]
18 |
19 | self.scaler_row = torch.zeros((self.columns), device=self.dev)
20 | self.nsamples = 0
21 |
22 | self.layer_id = layer_id
23 | self.layer_name = layer_name
24 | self.p_norm = p_norm
25 |
26 | def add_batch(self, inp, out):
27 | assert inp.shape[-1] == self.columns
28 | inp = inp.reshape((-1,self.columns))
29 | tmp = inp.shape[0]
30 | inp = inp.t()
31 |
32 | if self.p_norm == 2:
33 | self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / tmp
34 | elif self.p_norm == 1:
35 | self.scaler_row += torch.norm(inp, p=1, dim=1) ** 1 / tmp
36 |
37 | if torch.isinf(self.scaler_row).sum() > 0:
38 | print("encountered torch.isinf error")
39 | raise ValueError
40 |
41 | def prune(self, W_mask):
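42 |         # self.inp1 / self.out1 are the layer's cached calibration input/output, set beforehand by the
43 |         # calling pruning code (see prune_utils.py); after masking, the bias is re-fit to absorb the mean output shift.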
42 | self.layer.weight.data[W_mask] = 0
43 | out_ = self.layer(self.inp1)
44 |
45 | dist = (self.out1 - out_).squeeze(dim=0)
46 |
47 | bias = torch.mean(dist, dim=0)
48 | self.layer.bias = nn.Parameter(bias)
49 |
50 | def free(self):
51 | if DEBUG:
52 | self.inp1 = None
53 | self.out1 = None
54 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/image_classifiers/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import argparse
9 | import datetime
10 | import numpy as np
11 | import time
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | import json
16 | import os
17 |
18 | from pathlib import Path
19 |
20 | from timm.data.mixup import Mixup
21 | from timm.models import create_model
22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
23 | from timm.utils import ModelEma
24 | from optim_factory import create_optimizer, LayerDecayValueAssigner
25 |
26 | from datasets import build_dataset
27 | from engine import train_one_epoch, evaluate
28 |
29 | import utils
30 |
31 | import models.convnext
32 | import models.vision_transformer
33 | import models.swin_transformer
34 | import models.mlp_mixer
35 | import models.deit
36 |
37 | from prune_utils import prune_convnext, prune_deit, prune_vit, check_sparsity
38 |
39 | def str2bool(v):
40 | """
41 | Converts string to bool type; enables command line
42 | arguments in the format of '--arg1 true --arg2 false'
43 | """
44 | if isinstance(v, bool):
45 | return v
46 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
47 | return True
48 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
49 | return False
50 | else:
51 | raise argparse.ArgumentTypeError('Boolean value expected.')
52 |
53 | def get_args_parser():
54 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False)
55 | parser.add_argument('--batch_size', default=256, type=int,
56 | help='Per GPU batch size')
57 | parser.add_argument('--epochs', default=300, type=int)
58 | parser.add_argument('--update_freq', default=1, type=int,
59 | help='gradient accumulation steps')
60 |
61 | # Model parameters
62 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL',
63 | help='Name of model to train')
64 | parser.add_argument('--input_size', default=224, type=int,
65 | help='image input size')
66 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float,
67 | help="Layer scale initial values")
68 |
69 | ########################## settings specific to this project ##########################
70 |
71 | # dropout and stochastic depth drop rate; set at most one to non-zero
72 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT',
73 |                         help='Dropout rate (default: 0.0)')
74 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',
75 | help='Drop path rate (default: 0.0)')
76 |
77 | # early / late dropout and stochastic depth settings
78 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode')
79 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'],
80 | help='drop schedule for early dropout / s.d. only')
81 | parser.add_argument('--cutoff_epoch', type=int, default=0,
82 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts')
83 |
84 | #######################################################################################
85 |
86 | # EMA related parameters
87 | parser.add_argument('--model_ema', type=str2bool, default=False)
88 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
89 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
90 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
91 |
92 | # Optimization parameters
93 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
94 |                         help='Optimizer (default: "adamw")')
95 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
96 | help='Optimizer Epsilon (default: 1e-8)')
97 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
98 | help='Optimizer Betas (default: None, use opt default)')
99 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
100 | help='Clip gradient norm (default: None, no clipping)')
101 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
102 | help='SGD momentum (default: 0.9)')
103 | parser.add_argument('--weight_decay', type=float, default=0.05,
104 | help='weight decay (default: 0.05)')
105 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
106 | weight decay. We use a cosine schedule for WD and using a larger decay by
107 | the end of training improves performance for ViTs.""")
108 |
109 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR',
110 | help='learning rate (default: 4e-3), with total batch size 4096')
111 | parser.add_argument('--layer_decay', type=float, default=1.0)
112 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
113 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
114 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N',
115 | help='epochs to warmup LR, if scheduler supports')
116 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
117 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
118 |
119 | # Augmentation parameters
120 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
121 | help='Color jitter factor (default: 0.4)')
122 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
123 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
124 | parser.add_argument('--smoothing', type=float, default=0.1,
125 | help='Label smoothing (default: 0.1)')
126 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
127 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
128 |
129 | # Evaluation parameters
130 | parser.add_argument('--crop_pct', type=float, default=None)
131 |
132 | # * Random Erase params
133 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
134 | help='Random erase prob (default: 0.25)')
135 | parser.add_argument('--remode', type=str, default='pixel',
136 | help='Random erase mode (default: "pixel")')
137 | parser.add_argument('--recount', type=int, default=1,
138 | help='Random erase count (default: 1)')
139 | parser.add_argument('--resplit', type=str2bool, default=False,
140 | help='Do not random erase first (clean) augmentation split')
141 |
142 | # * Mixup params
143 | parser.add_argument('--mixup', type=float, default=0.8,
144 | help='mixup alpha, mixup enabled if > 0.')
145 | parser.add_argument('--cutmix', type=float, default=1.0,
146 | help='cutmix alpha, cutmix enabled if > 0.')
147 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
148 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
149 | parser.add_argument('--mixup_prob', type=float, default=1.0,
150 | help='Probability of performing mixup or cutmix when either/both is enabled')
151 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
152 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
153 | parser.add_argument('--mixup_mode', type=str, default='batch',
154 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
155 |
156 | # * Finetuning params
157 | parser.add_argument('--finetune', default='',
158 | help='finetune from checkpoint')
159 | parser.add_argument('--head_init_scale', default=1.0, type=float,
160 | help='classifier head initial scale, typically adjusted in fine-tuning')
161 | parser.add_argument('--model_key', default='model|module', type=str,
162 | help='which key to load from saved state dict, usually model or model_ema')
163 | parser.add_argument('--model_prefix', default='', type=str)
164 |
165 | # Dataset parameters
166 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
167 | help='dataset path')
168 | parser.add_argument('--eval_data_path', default=None, type=str,
169 | help='dataset path for evaluation')
170 | parser.add_argument('--nb_classes', default=1000, type=int,
171 | help='number of the classification types')
172 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
173 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'],
174 |                         type=str, help='dataset type')
175 | parser.add_argument('--output_dir', default='',
176 | help='path where to save, empty for no saving')
177 | parser.add_argument('--device', default='cuda',
178 | help='device to use for training / testing')
179 | parser.add_argument('--seed', default=0, type=int)
180 |
181 | parser.add_argument('--resume', default='',
182 | help='resume from checkpoint')
183 | parser.add_argument('--auto_resume', type=str2bool, default=True)
184 | parser.add_argument('--save_ckpt', type=str2bool, default=True)
185 | parser.add_argument('--save_ckpt_freq', default=1, type=int)
186 | parser.add_argument('--save_ckpt_num', default=3, type=int)
187 |
188 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
189 | help='start epoch')
190 | parser.add_argument('--eval', type=str2bool, default=False,
191 | help='Perform evaluation only')
192 | parser.add_argument('--dist_eval', type=str2bool, default=True,
193 | help='Enabling distributed evaluation')
194 | parser.add_argument('--disable_eval', type=str2bool, default=False,
195 | help='Disabling evaluation during training')
196 | parser.add_argument('--num_workers', default=10, type=int)
197 | parser.add_argument('--pin_mem', type=str2bool, default=True,
198 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
199 |
200 | # distributed training parameters
201 | parser.add_argument('--world_size', default=1, type=int,
202 | help='number of distributed processes')
203 | parser.add_argument('--local_rank', default=-1, type=int)
204 | parser.add_argument('--dist_on_itp', type=str2bool, default=False)
205 | parser.add_argument('--dist_url', default='env://',
206 | help='url used to set up distributed training')
207 |
208 | parser.add_argument('--use_amp', type=str2bool, default=False,
209 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
210 |
211 | # Weights and Biases arguments
212 | parser.add_argument('--enable_wandb', type=str2bool, default=False,
213 | help="enable logging to Weights and Biases")
214 | parser.add_argument('--project', default='convnext', type=str,
215 | help="The name of the W&B project where you're sending the new run.")
216 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
217 | help="Save model checkpoints as W&B Artifacts.")
218 |
219 | # arguments for pruning
220 | parser.add_argument("--nsamples", type=int, default=4096)
221 | parser.add_argument("--sparsity", type=float, default=0.)
222 | parser.add_argument("--prune_metric", type=str, choices=["magnitude", "wanda"])
223 | parser.add_argument("--prune_granularity", type=str)
224 | parser.add_argument("--blocksize", type=int, default=1)
225 |
226 | return parser
227 |
228 | def main(args):
229 | utils.init_distributed_mode(args)
230 | print(args)
231 | device = torch.device(args.device)
232 |
233 | # fix the seed for reproducibility
234 | seed = args.seed + utils.get_rank()
235 | torch.manual_seed(seed)
236 | np.random.seed(seed)
237 | cudnn.benchmark = True
238 |
239 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
240 | if args.disable_eval:
241 | args.dist_eval = False
242 | dataset_val = None
243 | else:
244 | dataset_val, _ = build_dataset(is_train=False, args=args)
245 |
246 | num_tasks = utils.get_world_size()
247 | global_rank = utils.get_rank()
248 |
249 | sampler_train = torch.utils.data.DistributedSampler(
250 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
251 | )
252 | print("Sampler_train = %s" % str(sampler_train))
253 | if args.dist_eval:
254 | if len(dataset_val) % num_tasks != 0:
255 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
256 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
257 | 'equal num of samples per-process.')
258 | sampler_val = torch.utils.data.DistributedSampler(
259 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
260 | else:
261 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
262 |
263 | if global_rank == 0 and args.enable_wandb:
264 | wandb_logger = utils.WandbLogger(args)
265 | else:
266 | wandb_logger = None
267 |
268 |
269 | data_loader_train = torch.utils.data.DataLoader(
270 | dataset_train, sampler=sampler_train,
271 | batch_size=args.batch_size,
272 | num_workers=args.num_workers,
273 | pin_memory=args.pin_mem,
274 | drop_last=True,
275 | )
276 |
277 | if dataset_val is not None:
278 | data_loader_val = torch.utils.data.DataLoader(
279 | dataset_val, sampler=sampler_val,
280 | batch_size=int(1.5 * args.batch_size),
281 | num_workers=args.num_workers,
282 | pin_memory=args.pin_mem,
283 | drop_last=False
284 | )
285 | else:
286 | data_loader_val = None
287 |
288 | model = utils.build_model(args, pretrained=False)
289 | model.cuda()
290 | if args.distributed:
291 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
292 | model_without_ddp = model.module
293 |
294 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
295 | print('number of params:', n_parameters)
296 |
297 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
298 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
299 |
300 | # At most one of dropout and stochastic depth should be enabled.
301 | assert(args.dropout == 0 or args.drop_path == 0)
302 | # ConvNeXt does not support dropout.
303 | assert(args.dropout == 0 if args.model.startswith("convnext") else True)
304 |
305 |
306 | if "convnext" in args.model:
307 | checkpoint = torch.load(args.resume, map_location='cpu')
308 | model.load_state_dict(checkpoint["model"])
309 | elif "vit" in args.model:
310 | checkpoint = torch.load(args.resume, map_location='cpu')
311 | model.load_state_dict(checkpoint)
312 | elif "deit" in args.model:
313 | checkpoint = torch.load(args.resume, map_location='cpu')
314 | model.load_state_dict(checkpoint["model"])
315 |
316 | ################################################################################
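317 |     # Build a fixed calibration set of args.nsamples training images; their activations provide the input norms for the Wanda metric.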
317 | np.random.seed(0)
318 | calibration_ids = np.random.choice(len(dataset_train), args.nsamples)
319 | calib_data = []
320 | for i in calibration_ids:
321 | calib_data.append(dataset_train[i][0].unsqueeze(dim=0))
322 | calib_data = torch.cat(calib_data, dim=0).to(device)
323 |
324 | tick = time.time()
325 | if args.sparsity != 0:
326 | with torch.no_grad():
327 | if "convnext" in args.model:
328 | prune_convnext(args, model, calib_data, device)
329 | elif "vit" in args.model:
330 | prune_vit(args, model, calib_data, device)
331 | elif "deit" in args.model:
332 | prune_vit(args, model, calib_data, device)
333 | ################################################################################
334 |
335 | actual_sparsity = check_sparsity(model)
336 | print(f"actual sparsity {actual_sparsity}")
337 |
338 | print(f"Eval only mode")
339 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
340 | val_acc = test_stats["acc1"]
341 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
342 | return
343 |
344 | if __name__ == '__main__':
345 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()])
346 | args = parser.parse_args()
347 | if args.output_dir:
348 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
349 | main(args)
--------------------------------------------------------------------------------
/image_classifiers/models/convnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from timm.models.layers import trunc_normal_, DropPath
12 | from timm.models.registry import register_model
13 |
14 | class Block(nn.Module):
15 | r""" ConvNeXt Block. There are two equivalent implementations:
16 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
17 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
18 | We use (2) as we find it slightly faster in PyTorch
19 |
20 | Args:
21 | dim (int): Number of input channels.
22 | drop_path (float): Stochastic depth rate. Default: 0.0
23 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
24 | """
25 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, drop_rate=0.):
26 | super().__init__()
27 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
28 | self.norm = LayerNorm(dim, eps=1e-6)
29 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
30 | self.act = nn.GELU()
31 | self.pwconv2 = nn.Linear(4 * dim, dim)
32 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
33 | requires_grad=True) if layer_scale_init_value > 0 else None
34 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
35 | self.dropout = nn.Dropout(drop_rate)
36 |
37 | def forward(self, x):
38 | input = x
39 | x = self.dwconv(x)
40 | x = self.dropout(x)
41 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
42 | x = self.norm(x)
43 | x = self.pwconv1(x)
44 | x = self.act(x)
45 | x = self.dropout(x)
46 | x = self.pwconv2(x)
47 | x = self.dropout(x)
48 | if self.gamma is not None:
49 | x = self.gamma * x
50 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
51 | x = input + self.drop_path(x)
52 | return x
53 |
54 |
55 | class ConvNeXt(nn.Module):
56 | r""" ConvNeXt
57 | A PyTorch impl of : `A ConvNet for the 2020s` -
58 | https://arxiv.org/pdf/2201.03545.pdf
59 |
60 | Args:
61 | in_chans (int): Number of input image channels. Default: 3
62 | num_classes (int): Number of classes for classification head. Default: 1000
63 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
64 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
65 | drop_path_rate (float): Stochastic depth rate. Default: 0.
66 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
67 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
68 | drop_rate (float): Dropout rate
69 | """
70 | def __init__(self, in_chans=3, num_classes=1000,
71 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
72 | layer_scale_init_value=1e-6, head_init_scale=1., drop_rate=0.
73 | ):
74 | super().__init__()
75 |
76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
77 | stem = nn.Sequential(
78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
80 | )
81 | self.downsample_layers.append(stem)
82 | for i in range(3):
83 | downsample_layer = nn.Sequential(
84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
85 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
86 | )
87 | self.downsample_layers.append(downsample_layer)
88 |
89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
90 | self.depths = depths
91 | self.drop_path = drop_path_rate
92 | self.drop_rate = drop_rate
93 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
94 | cur = 0
95 | for i in range(4):
96 | stage = nn.Sequential(
97 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
98 | layer_scale_init_value=layer_scale_init_value, drop_rate=drop_rate) for j in range(depths[i])]
99 | )
100 | self.stages.append(stage)
101 | cur += depths[i]
102 |
103 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
104 | self.head = nn.Linear(dims[-1], num_classes)
105 |
106 | self.apply(self._init_weights)
107 | self.head.weight.data.mul_(head_init_scale)
108 | self.head.bias.data.mul_(head_init_scale)
109 |
110 | def _init_weights(self, m):
111 | if isinstance(m, (nn.Conv2d, nn.Linear)):
112 | trunc_normal_(m.weight, std=.02)
113 | nn.init.constant_(m.bias, 0)
114 |
115 | def forward_features(self, x):
116 | for i in range(4):
117 | x = self.downsample_layers[i](x)
118 | x = self.stages[i](x)
119 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
120 |
121 | def forward(self, x):
122 | x = self.forward_features(x)
123 | x = self.head(x)
124 | return x
125 |
126 | def update_drop_path(self, drop_path_rate):
127 | self.drop_path = drop_path_rate
128 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
129 | cur = 0
130 | for i in range(4):
131 | for j in range(self.depths[i]):
132 | self.stages[i][j].drop_path.drop_prob = dp_rates[cur + j]
133 | cur += self.depths[i]
134 |
135 | class LayerNorm(nn.Module):
136 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
137 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
138 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
139 | with shape (batch_size, channels, height, width).
140 | """
141 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
142 | super().__init__()
143 | self.weight = nn.Parameter(torch.ones(normalized_shape))
144 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
145 | self.eps = eps
146 | self.data_format = data_format
147 | if self.data_format not in ["channels_last", "channels_first"]:
148 | raise NotImplementedError
149 | self.normalized_shape = (normalized_shape, )
150 |
151 | def forward(self, x):
152 | if self.data_format == "channels_last":
153 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
154 | elif self.data_format == "channels_first":
155 | u = x.mean(1, keepdim=True)
156 | s = (x - u).pow(2).mean(1, keepdim=True)
157 | x = (x - u) / torch.sqrt(s + self.eps)
158 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
159 | return x
160 |
161 | model_urls = {
162 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
163 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
164 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
165 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
166 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
167 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
168 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
169 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
170 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
171 | }
172 |
173 | @register_model
174 | def convnext_atto(pretrained=False, **kwargs):
175 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
176 | return model
177 |
178 | @register_model
179 | def convnext_mini(pretrained=False, **kwargs):
180 | model = ConvNeXt(depths=[2, 2, 4, 2], dims=[48, 96, 192, 384], **kwargs)
181 | return model
182 |
183 | @register_model
184 | def convnext_femto(pretrained=False, **kwargs):
185 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
186 | return model
187 |
188 | @register_model
189 | def convnext_pico(pretrained=False, **kwargs):
190 | # timm pico variant
191 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
192 | return model
193 |
194 | @register_model
195 | def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
196 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
197 | return model
198 |
199 | @register_model
200 | def convnext_small(pretrained=False, in_22k=False, **kwargs):
201 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
202 | return model
203 |
204 | @register_model
205 | def convnext_base(pretrained=False, in_22k=False, **kwargs):
206 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
207 | return model
208 |
209 | @register_model
210 | def convnext_large(pretrained=False, in_22k=False, **kwargs):
211 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
212 | return model
213 |
214 | @register_model
215 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
216 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
217 | return model
218 |
--------------------------------------------------------------------------------
/image_classifiers/models/deit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.nn as nn
5 | from functools import partial
6 |
7 | from timm.models.vision_transformer import VisionTransformer, _cfg
8 | from timm.models.registry import register_model
9 | from timm.models.layers import trunc_normal_
10 |
11 |
12 | __all__ = [
13 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
14 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
15 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
16 | 'deit_base_distilled_patch16_384',
17 | ]
18 |
19 |
20 | class DistilledVisionTransformer(VisionTransformer):
21 | def __init__(self, *args, **kwargs):
22 | super().__init__(*args, **kwargs)
23 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
24 | num_patches = self.patch_embed.num_patches
25 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
26 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
27 |
28 | trunc_normal_(self.dist_token, std=.02)
29 | trunc_normal_(self.pos_embed, std=.02)
30 | self.head_dist.apply(self._init_weights)
31 |
32 | def forward_features(self, x):
33 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
34 | # with slight modifications to add the dist_token
35 | B = x.shape[0]
36 | x = self.patch_embed(x)
37 |
38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
39 | dist_token = self.dist_token.expand(B, -1, -1)
40 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
41 |
42 | x = x + self.pos_embed
43 | x = self.pos_drop(x)
44 |
45 | for blk in self.blocks:
46 | x = blk(x)
47 |
48 | x = self.norm(x)
49 | return x[:, 0], x[:, 1]
50 |
51 | def forward(self, x):
52 | x, x_dist = self.forward_features(x)
53 | x = self.head(x)
54 | x_dist = self.head_dist(x_dist)
55 | if self.training:
56 | return x, x_dist
57 | else:
58 | # during inference, return the average of both classifier predictions
59 | return (x + x_dist) / 2
60 |
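# Shape note (added for clarity, not part of the original file): because both a class
# token and a distillation token are prepended, the token sequence has num_patches + 2
# entries, matching the pos_embed defined in __init__; x[:, 0] feeds self.head and
# x[:, 1] feeds self.head_dist, and the two logits are averaged at inference time.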
61 |
62 | @register_model
63 | def deit_tiny_patch16_224(pretrained=False, **kwargs):
64 | model = VisionTransformer(
65 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
66 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
67 | model.default_cfg = _cfg()
68 | if pretrained:
69 | checkpoint = torch.hub.load_state_dict_from_url(
70 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
71 | map_location="cpu", check_hash=True
72 | )
73 | model.load_state_dict(checkpoint["model"])
74 | return model
75 |
76 |
77 | @register_model
78 | def deit_small_patch16_224(pretrained=False, **kwargs):
79 | model = VisionTransformer(
80 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
81 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
82 | model.default_cfg = _cfg()
83 | if pretrained:
84 | checkpoint = torch.hub.load_state_dict_from_url(
85 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
86 | map_location="cpu", check_hash=True
87 | )
88 | model.load_state_dict(checkpoint["model"])
89 | return model
90 |
91 |
92 | @register_model
93 | def deit_base_patch16_224(pretrained=False, **kwargs):
94 | model = VisionTransformer(
95 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
96 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
97 | model.default_cfg = _cfg()
98 | if pretrained:
99 | checkpoint = torch.hub.load_state_dict_from_url(
100 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
101 | map_location="cpu", check_hash=True
102 | )
103 | model.load_state_dict(checkpoint["model"])
104 | return model
105 |
106 |
107 | @register_model
108 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
109 | model = DistilledVisionTransformer(
110 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
111 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
112 | model.default_cfg = _cfg()
113 | if pretrained:
114 | checkpoint = torch.hub.load_state_dict_from_url(
115 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
116 | map_location="cpu", check_hash=True
117 | )
118 | model.load_state_dict(checkpoint["model"])
119 | return model
120 |
121 |
122 | @register_model
123 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
124 | model = DistilledVisionTransformer(
125 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
126 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
127 | model.default_cfg = _cfg()
128 | if pretrained:
129 | checkpoint = torch.hub.load_state_dict_from_url(
130 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
131 | map_location="cpu", check_hash=True
132 | )
133 | model.load_state_dict(checkpoint["model"])
134 | return model
135 |
136 |
137 | @register_model
138 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
139 | model = DistilledVisionTransformer(
140 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
141 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
142 | model.default_cfg = _cfg()
143 | if pretrained:
144 | checkpoint = torch.hub.load_state_dict_from_url(
145 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
146 | map_location="cpu", check_hash=True
147 | )
148 | model.load_state_dict(checkpoint["model"])
149 | return model
150 |
151 |
152 | @register_model
153 | def deit_base_patch16_384(pretrained=False, **kwargs):
154 | model = VisionTransformer(
155 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
156 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
157 | model.default_cfg = _cfg()
158 | if pretrained:
159 | checkpoint = torch.hub.load_state_dict_from_url(
160 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
161 | map_location="cpu", check_hash=True
162 | )
163 | model.load_state_dict(checkpoint["model"])
164 | return model
165 |
166 |
167 | @register_model
168 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
169 | model = DistilledVisionTransformer(
170 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
171 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
172 | model.default_cfg = _cfg()
173 | if pretrained:
174 | checkpoint = torch.hub.load_state_dict_from_url(
175 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
176 | map_location="cpu", check_hash=True
177 | )
178 | model.load_state_dict(checkpoint["model"])
179 | return model
--------------------------------------------------------------------------------
/image_classifiers/models/mlp_mixer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import math
9 | from copy import deepcopy
10 | from functools import partial
11 |
12 | import torch
13 | import torch.nn as nn
14 |
15 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16 | from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
17 | from timm.models.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
18 | from timm.models.registry import register_model
19 |
20 |
21 | def _cfg(url='', **kwargs):
22 | return {
23 | 'url': url,
24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
25 | 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
26 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
27 | 'first_conv': 'stem.proj', 'classifier': 'head',
28 | **kwargs
29 | }
30 |
31 |
32 | default_cfgs = dict(
33 | mixer_s32_224=_cfg(),
34 | mixer_s16_224=_cfg(),
35 | mixer_b32_224=_cfg(),
36 | mixer_b16_224=_cfg(
37 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
38 | ),
39 | mixer_b16_224_in21k=_cfg(
40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
41 | num_classes=21843
42 | ),
43 | mixer_l32_224=_cfg(),
44 | mixer_l16_224=_cfg(
45 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
46 | ),
47 | mixer_l16_224_in21k=_cfg(
48 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
49 | num_classes=21843
50 | ),
51 |
52 | # Mixer ImageNet-21K-P pretraining
53 | mixer_b16_224_miil_in21k=_cfg(
54 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth',
55 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
56 | ),
57 | mixer_b16_224_miil=_cfg(
58 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth',
59 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
60 | ),
61 |
62 | gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
63 | gmixer_24_224=_cfg(
64 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
65 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
66 |
67 | resmlp_12_224=_cfg(
68 | url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
69 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
70 | resmlp_24_224=_cfg(
71 | url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
72 | #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
73 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
74 | resmlp_36_224=_cfg(
75 | url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
76 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
77 | resmlp_big_24_224=_cfg(
78 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
79 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
80 |
81 | resmlp_12_distilled_224=_cfg(
82 | url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
83 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
84 | resmlp_24_distilled_224=_cfg(
85 | url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
86 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
87 | resmlp_36_distilled_224=_cfg(
88 | url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
89 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
90 | resmlp_big_24_distilled_224=_cfg(
91 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
92 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
93 |
94 | resmlp_big_24_224_in22ft1k=_cfg(
95 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
96 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
97 |
98 | gmlp_ti16_224=_cfg(),
99 | gmlp_s16_224=_cfg(
100 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
101 | ),
102 | gmlp_b16_224=_cfg(),
103 | )
104 |
105 |
106 | class MixerBlock(nn.Module):
107 | """ Residual Block w/ token mixing and channel MLPs
108 | Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
109 | """
110 | def __init__(
111 | self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
112 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
113 | super().__init__()
114 | tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
115 | self.norm1 = norm_layer(dim)
116 | self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
117 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
118 | self.norm2 = norm_layer(dim)
119 | self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
120 |
121 | def forward(self, x):
122 | x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
123 | x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
124 | return x
125 |
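# Shape note (added for clarity, not part of the original file): for x of shape
# (B, num_patches, dim), mlp_tokens mixes information across patches by operating on
# the transposed (B, dim, num_patches) view of norm1(x) and transposing back, while
# mlp_channels mixes across channels directly on the (B, num_patches, dim) layout.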
126 |
127 | class MlpMixer(nn.Module):
128 |
129 | def __init__(
130 | self,
131 | num_classes=1000,
132 | img_size=224,
133 | in_chans=3,
134 | patch_size=16,
135 | num_blocks=8,
136 | embed_dim=512,
137 | mlp_ratio=(0.5, 4.0),
138 | block_layer=MixerBlock,
139 | mlp_layer=Mlp,
140 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
141 | act_layer=nn.GELU,
142 | drop_rate=0.,
143 | drop_path_rate=0.,
144 | nlhb=False,
145 | stem_norm=False,
146 | ):
147 | super().__init__()
148 | self.num_classes = num_classes
149 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
150 |
151 | self.stem = PatchEmbed(
152 | img_size=img_size, patch_size=patch_size, in_chans=in_chans,
153 | embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
154 | # FIXME drop_path (stochastic depth scaling rule or all the same?)
155 | self.drop_path = drop_path_rate
156 | self.drop_rate = drop_rate
157 | self.num_blocks = num_blocks
158 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
159 | self.blocks = nn.Sequential(*[
160 | block_layer(
161 | embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
162 | act_layer=act_layer, drop=drop_rate, drop_path=dpr[i])
163 | for i in range(num_blocks)])
164 | self.norm = norm_layer(embed_dim)
165 | self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
166 |
167 | self.init_weights(nlhb=nlhb)
168 |
169 | def init_weights(self, nlhb=False):
170 | head_bias = -math.log(self.num_classes) if nlhb else 0.
171 | named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
172 |
173 | def get_classifier(self):
174 | return self.head
175 |
176 | def reset_classifier(self, num_classes, global_pool=''):
177 | self.num_classes = num_classes
178 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
179 |
180 | def forward_features(self, x):
181 | x = self.stem(x)
182 | x = self.blocks(x)
183 | x = self.norm(x)
184 | x = x.mean(dim=1)
185 | return x
186 |
187 | def forward(self, x):
188 | x = self.forward_features(x)
189 | x = self.head(x)
190 | return x
191 |
192 | def update_drop_path(self, drop_path_rate):
193 | self.drop_path = drop_path_rate
194 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)]
195 | cur = 0
196 | for block in self.blocks:
197 | block.drop_path.drop_prob = dpr[cur]
198 | cur += 1
199 | assert cur == self.num_blocks
200 |
201 | def update_dropout(self, drop_rate):
202 | self.drop_rate = drop_rate
203 | for module in self.modules():
204 | if isinstance(module, nn.Dropout):
205 | module.p = drop_rate
206 |
207 |
208 | def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
209 | """ Mixer weight initialization (trying to match Flax defaults)
210 | """
211 | if isinstance(module, nn.Linear):
212 | if name.startswith('head'):
213 | nn.init.zeros_(module.weight)
214 | nn.init.constant_(module.bias, head_bias)
215 | else:
216 | if flax:
217 | # Flax defaults
218 | lecun_normal_(module.weight)
219 | if module.bias is not None:
220 | nn.init.zeros_(module.bias)
221 | else:
222 | # like MLP init in vit (my original init)
223 | nn.init.xavier_uniform_(module.weight)
224 | if module.bias is not None:
225 | if 'mlp' in name:
226 | nn.init.normal_(module.bias, std=1e-6)
227 | else:
228 | nn.init.zeros_(module.bias)
229 | elif isinstance(module, nn.Conv2d):
230 | lecun_normal_(module.weight)
231 | if module.bias is not None:
232 | nn.init.zeros_(module.bias)
233 | elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
234 | nn.init.ones_(module.weight)
235 | nn.init.zeros_(module.bias)
236 | elif hasattr(module, 'init_weights'):
237 | # NOTE if a parent module contains init_weights method, it can override the init of the
238 | # child modules as this will be called in depth-first order.
239 | module.init_weights()
240 |
241 |
242 | def checkpoint_filter_fn(state_dict, model):
243 | """ Remap checkpoints if needed """
244 | if 'patch_embed.proj.weight' in state_dict:
245 | # Remap FB ResMlp models -> timm
246 | out_dict = {}
247 | for k, v in state_dict.items():
248 | k = k.replace('patch_embed.', 'stem.')
249 | k = k.replace('attn.', 'linear_tokens.')
250 | k = k.replace('mlp.', 'mlp_channels.')
251 | k = k.replace('gamma_', 'ls')
252 | if k.endswith('.alpha') or k.endswith('.beta'):
253 | v = v.reshape(1, 1, -1)
254 | out_dict[k] = v
255 | return out_dict
256 | return state_dict
257 |
258 |
259 | def _create_mixer(variant, pretrained=False, **kwargs):
260 | if kwargs.get('features_only', None):
261 | raise RuntimeError('features_only not implemented for MLP-Mixer models.')
262 |
263 | model = build_model_with_cfg(
264 | MlpMixer, variant, pretrained,
265 | default_cfg=default_cfgs[variant],
266 | pretrained_filter_fn=checkpoint_filter_fn,
267 | **kwargs)
268 | return model
269 |
270 | @register_model
271 | def mixer_t32(pretrained=False, **kwargs):
272 | """ Mixer-S/32 224x224
273 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
274 | """
275 | model = MlpMixer(patch_size=32, num_blocks=8, embed_dim=256, **kwargs)
276 | return model
277 |
278 |
279 | @register_model
280 | def mixer_s32(pretrained=False, **kwargs):
281 | """ Mixer-S/32 224x224
282 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
283 | """
284 | model = MlpMixer(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
285 | return model
286 |
287 |
288 | @register_model
289 | def mixer_s16(pretrained=False, **kwargs):
290 | """ Mixer-S/16 224x224
291 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
292 | """
293 | model = MlpMixer(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
294 | return model
295 |
296 |
297 | @register_model
298 | def mixer_b32(pretrained=False, **kwargs):
299 | """ Mixer-B/32 224x224
300 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
301 | """
302 | model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
303 | model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
304 | return model
305 |
306 |
307 | @register_model
308 | def mixer_b16(pretrained=False, **kwargs):
309 | """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
310 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
311 | """
312 | model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
313 | model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
314 | return model
315 |
316 |
--------------------------------------------------------------------------------
/image_classifiers/models/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | import torch.nn as nn
10 | from functools import partial
11 |
12 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13 | from timm.models.helpers import load_pretrained
14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15 | from timm.models.resnet import resnet26d, resnet50d
16 | from timm.models.registry import register_model
17 |
18 |
19 | def _cfg(url='', **kwargs):
20 | return {
21 | 'url': url,
22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
23 | 'crop_pct': .9, 'interpolation': 'bicubic',
24 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
25 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
26 | **kwargs
27 | }
28 |
29 |
30 | default_cfgs = {
31 | # patch models
32 | 'vit_small_patch16_224': _cfg(
33 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
34 | ),
35 | 'vit_base_patch16_224': _cfg(
36 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
37 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
38 | ),
39 | 'vit_base_patch16_384': _cfg(
40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
41 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
42 | 'vit_base_patch32_384': _cfg(
43 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
44 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
45 | 'vit_large_patch16_224': _cfg(
46 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
47 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
48 | 'vit_large_patch16_384': _cfg(
49 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
50 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
51 | 'vit_large_patch32_384': _cfg(
52 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
53 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
54 | 'vit_huge_patch16_224': _cfg(),
55 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
56 | # hybrid models
57 | 'vit_small_resnet26d_224': _cfg(),
58 | 'vit_small_resnet50d_s3_224': _cfg(),
59 | 'vit_base_resnet26d_224': _cfg(),
60 | 'vit_base_resnet50d_224': _cfg(),
61 | }
62 |
63 |
64 | class Mlp(nn.Module):
65 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
66 | super().__init__()
67 | out_features = out_features or in_features
68 | hidden_features = hidden_features or in_features
69 | self.fc1 = nn.Linear(in_features, hidden_features)
70 | self.act = act_layer()
71 | self.fc2 = nn.Linear(hidden_features, out_features)
72 | self.drop = nn.Dropout(drop)
73 |
74 | def forward(self, x):
75 | x = self.fc1(x)
76 | x = self.act(x)
77 | x = self.drop(x)
78 | x = self.fc2(x)
79 | x = self.drop(x)
80 | return x
81 |
82 |
83 | class Attention(nn.Module):
84 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
85 | super().__init__()
86 | self.num_heads = num_heads
87 | head_dim = dim // num_heads
88 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
89 | self.scale = qk_scale or head_dim ** -0.5
90 |
91 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
92 | self.attn_drop = nn.Dropout(attn_drop)
93 | self.proj = nn.Linear(dim, dim)
94 | self.proj_drop = nn.Dropout(proj_drop)
95 |
96 | def forward(self, x):
97 | B, N, C = x.shape
98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
99 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
100 |
101 | attn = (q @ k.transpose(-2, -1)) * self.scale
102 | attn = attn.softmax(dim=-1)
103 | attn = self.attn_drop(attn)
104 |
105 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
106 | x = self.proj(x)
107 | x = self.proj_drop(x)
108 | return x
109 |
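# Shape note (added for clarity, not part of the original file): self.qkv maps (B, N, C)
# to (B, N, 3C); the reshape/permute above yields a (3, B, num_heads, N, C // num_heads)
# tensor whose first axis is unpacked into q, k, v, so attention is computed per head
# over the N tokens and the result is merged back to (B, N, C) before self.proj.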
110 |
111 | class Block(nn.Module):
112 |
113 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
114 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
115 | super().__init__()
116 | self.norm1 = norm_layer(dim)
117 | self.attn = Attention(
118 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
119 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
120 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
121 | self.norm2 = norm_layer(dim)
122 | mlp_hidden_dim = int(dim * mlp_ratio)
123 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
124 |
125 | def forward(self, x):
126 | x = x + self.drop_path(self.attn(self.norm1(x)))
127 | x = x + self.drop_path(self.mlp(self.norm2(x)))
128 | return x
129 |
130 |
131 | class PatchEmbed(nn.Module):
132 | """ Image to Patch Embedding
133 | """
134 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
135 | super().__init__()
136 | img_size = to_2tuple(img_size)
137 | patch_size = to_2tuple(patch_size)
138 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
139 | self.img_size = img_size
140 | self.patch_size = patch_size
141 | self.num_patches = num_patches
142 |
143 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
144 |
145 | def forward(self, x):
146 | B, C, H, W = x.shape
147 | # FIXME look at relaxing size constraints
148 | assert H == self.img_size[0] and W == self.img_size[1], \
149 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
150 | x = self.proj(x).flatten(2).transpose(1, 2)
151 | return x
152 |
153 |
154 | class HybridEmbed(nn.Module):
155 | """ CNN Feature Map Embedding
156 | Extract feature map from CNN, flatten, project to embedding dim.
157 | """
158 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
159 | super().__init__()
160 | assert isinstance(backbone, nn.Module)
161 | img_size = to_2tuple(img_size)
162 | self.img_size = img_size
163 | self.backbone = backbone
164 | if feature_size is None:
165 | with torch.no_grad():
166 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
167 | # map for all networks, the feature metadata has reliable channel and stride info, but using
168 | # stride to calc feature dim requires info about padding of each stage that isn't captured.
169 | training = backbone.training
170 | if training:
171 | backbone.eval()
172 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
173 | feature_size = o.shape[-2:]
174 | feature_dim = o.shape[1]
175 | backbone.train(training)
176 | else:
177 | feature_size = to_2tuple(feature_size)
178 | feature_dim = self.backbone.feature_info.channels()[-1]
179 | self.num_patches = feature_size[0] * feature_size[1]
180 | self.proj = nn.Linear(feature_dim, embed_dim)
181 |
182 | def forward(self, x):
183 | x = self.backbone(x)[-1]
184 | x = x.flatten(2).transpose(1, 2)
185 | x = self.proj(x)
186 | return x
187 |
188 |
189 | class VisionTransformer(nn.Module):
190 | """ Vision Transformer with support for patch or hybrid CNN input stage
191 | """
192 |
193 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
194 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
195 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
196 | super().__init__()
197 | self.num_classes = num_classes
198 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
199 |         # NOTE: the following two lines (added on top of the base timm implementation) store drop_rate and tie attn_drop_rate to it
200 |         self.drop_rate = drop_rate
201 |         attn_drop_rate = drop_rate
202 | if hybrid_backbone is not None:
203 | self.patch_embed = HybridEmbed(
204 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
205 | else:
206 | self.patch_embed = PatchEmbed(
207 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
208 | num_patches = self.patch_embed.num_patches
209 |
210 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
211 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
212 | self.pos_drop = nn.Dropout(p=drop_rate)
213 | self.depth = depth
214 |
215 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
216 | self.blocks = nn.ModuleList([
217 | Block(
218 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
219 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
220 | for i in range(depth)])
221 | self.norm = norm_layer(embed_dim)
222 |
223 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
224 | #self.repr = nn.Linear(embed_dim, representation_size)
225 | #self.repr_act = nn.Tanh()
226 |
227 | # Classifier head
228 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
229 |
230 | trunc_normal_(self.pos_embed, std=.02)
231 | trunc_normal_(self.cls_token, std=.02)
232 | self.apply(self._init_weights)
233 |
234 | def _init_weights(self, m):
235 | if isinstance(m, nn.Linear):
236 | trunc_normal_(m.weight, std=.02)
237 | if isinstance(m, nn.Linear) and m.bias is not None:
238 | nn.init.constant_(m.bias, 0)
239 | elif isinstance(m, nn.LayerNorm):
240 | nn.init.constant_(m.bias, 0)
241 | nn.init.constant_(m.weight, 1.0)
242 |
243 | @torch.jit.ignore
244 | def no_weight_decay(self):
245 | return {'pos_embed', 'cls_token'}
246 |
247 | def get_classifier(self):
248 | return self.head
249 |
250 | def reset_classifier(self, num_classes, global_pool=''):
251 | self.num_classes = num_classes
252 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
253 |
254 | def forward_features(self, x):
255 | B = x.shape[0]
256 | x = self.patch_embed(x)
257 |
258 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
259 | x = torch.cat((cls_tokens, x), dim=1)
260 | x = x + self.pos_embed
261 | x = self.pos_drop(x)
262 |
263 | for blk in self.blocks:
264 | x = blk(x)
265 |
266 | x = self.norm(x)
267 | return x[:, 0]
268 |
269 | def forward(self, x):
270 | x = self.forward_features(x)
271 | x = self.head(x)
272 | return x
273 |
274 | def update_drop_path(self, drop_path_rate):
275 | self.drop_path = drop_path_rate
276 |         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]
277 | for i in range(self.depth):
278 | self.blocks[i].drop_path.drop_prob = dp_rates[i]
279 |
280 | def update_dropout(self, drop_rate):
281 | self.drop_rate = drop_rate
282 | for module in self.modules():
283 | if isinstance(module, nn.Dropout):
284 | module.p = drop_rate
285 |
286 |
287 | def _conv_filter(state_dict, patch_size=16):
288 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
289 | out_dict = {}
290 | for k, v in state_dict.items():
291 | if 'patch_embed.proj.weight' in k:
292 | v = v.reshape((v.shape[0], 3, patch_size, patch_size))
293 | out_dict[k] = v
294 | return out_dict
295 |
296 | @register_model
297 | def vit_tiny(pretrained=False, **kwargs):
298 | model = VisionTransformer(
299 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
300 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
301 | return model
302 |
303 | @register_model
304 | def vit_small(pretrained=False, **kwargs):
305 | model = VisionTransformer(
306 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
307 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
308 | return model
309 |
310 | @register_model
311 | def vit_base(pretrained=False, **kwargs):
312 | model = VisionTransformer(
313 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
314 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
315 | return model
316 |
317 | @register_model
318 | def vit_large(pretrained=False, **kwargs):
319 | model = VisionTransformer(
320 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
321 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
322 | return model
--------------------------------------------------------------------------------
/image_classifiers/optim_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | from torch import optim as optim
10 |
11 | from timm.optim.adafactor import Adafactor
12 | from timm.optim.adahessian import Adahessian
13 | from timm.optim.adamp import AdamP
14 | from timm.optim.lookahead import Lookahead
15 | from timm.optim.nadam import Nadam
16 | # from timm.optim.novograd import NovoGrad
17 | from timm.optim.nvnovograd import NvNovoGrad
18 | from timm.optim.radam import RAdam
19 | from timm.optim.rmsprop_tf import RMSpropTF
20 | from timm.optim.sgdp import SGDP
21 |
22 | import json
23 |
24 | try:
25 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
26 | has_apex = True
27 | except ImportError:
28 | has_apex = False
29 |
30 |
31 | def get_num_layer_for_convnext(var_name):
32 | """
33 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three
34 | consecutive blocks, including possible neighboring downsample layers;
35 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
36 | """
37 | num_max_layer = 12
38 | if var_name.startswith("downsample_layers"):
39 | stage_id = int(var_name.split('.')[1])
40 | if stage_id == 0:
41 | layer_id = 0
42 | elif stage_id == 1 or stage_id == 2:
43 | layer_id = stage_id + 1
44 | elif stage_id == 3:
45 | layer_id = 12
46 | return layer_id
47 |
48 | elif var_name.startswith("stages"):
49 | stage_id = int(var_name.split('.')[1])
50 | block_id = int(var_name.split('.')[2])
51 | if stage_id == 0 or stage_id == 1:
52 | layer_id = stage_id + 1
53 | elif stage_id == 2:
54 | layer_id = 3 + block_id // 3
55 | elif stage_id == 3:
56 | layer_id = 12
57 | return layer_id
58 | else:
59 | return num_max_layer + 1
60 |
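# Worked example (added for clarity, not part of the original file): for the ConvNeXt
# [3, 3, 27, 3] stage layout, the mapping above gives, e.g.:
#   "downsample_layers.0.<...>" -> layer_id 0
#   "stages.0.2.<...>"          -> layer_id 1
#   "stages.2.7.<...>"          -> layer_id 3 + 7 // 3 = 5
#   "stages.3.1.<...>"          -> layer_id 12
#   any other parameter name    -> layer_id 13 (num_max_layer + 1)
# so a LayerDecayValueAssigner is expected to hold one lr scale value per layer_id.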
61 | class LayerDecayValueAssigner(object):
62 | def __init__(self, values):
63 | self.values = values
64 |
65 | def get_scale(self, layer_id):
66 | return self.values[layer_id]
67 |
68 | def get_layer_id(self, var_name):
69 | return get_num_layer_for_convnext(var_name)
70 |
71 |
72 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
73 | parameter_group_names = {}
74 | parameter_group_vars = {}
75 |
76 | for name, param in model.named_parameters():
77 | if not param.requires_grad:
78 | continue # frozen weights
79 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
80 | group_name = "no_decay"
81 | this_weight_decay = 0.
82 | else:
83 | group_name = "decay"
84 | this_weight_decay = weight_decay
85 | if get_num_layer is not None:
86 | layer_id = get_num_layer(name)
87 | group_name = "layer_%d_%s" % (layer_id, group_name)
88 | else:
89 | layer_id = None
90 |
91 | if group_name not in parameter_group_names:
92 | if get_layer_scale is not None:
93 | scale = get_layer_scale(layer_id)
94 | else:
95 | scale = 1.
96 |
97 | parameter_group_names[group_name] = {
98 | "weight_decay": this_weight_decay,
99 | "params": [],
100 | "lr_scale": scale
101 | }
102 | parameter_group_vars[group_name] = {
103 | "weight_decay": this_weight_decay,
104 | "params": [],
105 | "lr_scale": scale
106 | }
107 |
108 | parameter_group_vars[group_name]["params"].append(param)
109 | parameter_group_names[group_name]["params"].append(name)
110 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
111 | return list(parameter_group_vars.values())
112 |
113 |
114 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
115 | opt_lower = args.opt.lower()
116 | weight_decay = args.weight_decay
117 | # if weight_decay and filter_bias_and_bn:
118 | if filter_bias_and_bn:
119 | skip = {}
120 | if skip_list is not None:
121 | skip = skip_list
122 | elif hasattr(model, 'no_weight_decay'):
123 | skip = model.no_weight_decay()
124 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
125 | weight_decay = 0.
126 | else:
127 | parameters = model.parameters()
128 |
129 | if 'fused' in opt_lower:
130 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
131 |
132 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
133 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
134 | opt_args['eps'] = args.opt_eps
135 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
136 | opt_args['betas'] = args.opt_betas
137 |
138 | opt_split = opt_lower.split('_')
139 | opt_lower = opt_split[-1]
140 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
141 | opt_args.pop('eps', None)
142 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
143 | elif opt_lower == 'momentum':
144 | opt_args.pop('eps', None)
145 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
146 | elif opt_lower == 'adam':
147 | optimizer = optim.Adam(parameters, **opt_args)
148 | elif opt_lower == 'adamw':
149 | optimizer = optim.AdamW(parameters, **opt_args)
150 | elif opt_lower == 'nadam':
151 | optimizer = Nadam(parameters, **opt_args)
152 | elif opt_lower == 'radam':
153 | optimizer = RAdam(parameters, **opt_args)
154 | elif opt_lower == 'adamp':
155 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
156 | elif opt_lower == 'sgdp':
157 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
158 | elif opt_lower == 'adadelta':
159 | optimizer = optim.Adadelta(parameters, **opt_args)
160 | elif opt_lower == 'adafactor':
161 | if not args.lr:
162 | opt_args['lr'] = None
163 | optimizer = Adafactor(parameters, **opt_args)
164 | elif opt_lower == 'adahessian':
165 | optimizer = Adahessian(parameters, **opt_args)
166 | elif opt_lower == 'rmsprop':
167 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
168 | elif opt_lower == 'rmsproptf':
169 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
170 | elif opt_lower == 'novograd':
171 |         optimizer = NovoGrad(parameters, **opt_args)  # NOTE: the NovoGrad import is commented out above; restore it before selecting this option
172 | elif opt_lower == 'nvnovograd':
173 | optimizer = NvNovoGrad(parameters, **opt_args)
174 | elif opt_lower == 'fusedsgd':
175 | opt_args.pop('eps', None)
176 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
177 | elif opt_lower == 'fusedmomentum':
178 | opt_args.pop('eps', None)
179 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
180 | elif opt_lower == 'fusedadam':
181 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
182 | elif opt_lower == 'fusedadamw':
183 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
184 | elif opt_lower == 'fusedlamb':
185 | optimizer = FusedLAMB(parameters, **opt_args)
186 | elif opt_lower == 'fusednovograd':
187 | opt_args.setdefault('betas', (0.95, 0.98))
188 | optimizer = FusedNovoGrad(parameters, **opt_args)
189 | else:
190 |         assert False, "Invalid optimizer"
191 |
192 | if len(opt_split) > 1:
193 | if opt_split[0] == 'lookahead':
194 | optimizer = Lookahead(optimizer)
195 |
196 | return optimizer
197 |
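# Usage note (added for clarity, not part of the original file): args.opt is parsed by
# the '_'-split above, so an underscore-separated prefix selects a wrapper. For example,
# assuming an args namespace with opt="lookahead_adamw", lr=4e-3, weight_decay=0.05,
# create_optimizer builds an AdamW optimizer and wraps it in timm's Lookahead; a plain
# opt="adamw" skips the wrapping.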
--------------------------------------------------------------------------------
/image_classifiers/prune_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from layerwrapper import WrappedLayer
4 |
5 | def find_layers(module, layers=[nn.Linear], name=''):
6 | if type(module) in layers:
7 | return {name: module}
8 | res = {}
9 | for name1, child in module.named_children():
10 | res.update(find_layers(
11 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
12 | ))
13 | return res
14 |
15 | def check_sparsity(model):
16 | subset = find_layers(model, layers=[nn.Linear])
17 | zero_cnt = 0
18 | fc_params = 0
19 | for name in subset:
20 | W = subset[name].weight.data
21 |         if W.shape[0] == 1000:  # skip the classification head (1000 output classes)
22 | continue
23 | zero_cnt += (W==0).sum().item()
24 | fc_params += W.numel()
25 | return float(zero_cnt) / fc_params
26 |
27 | def compute_mask(W_metric, prune_granularity, sparsity):
28 | if prune_granularity == "layer":
29 | thres = torch.sort(W_metric.flatten().cuda())[0][int(W_metric.numel() * sparsity)].cpu()
30 | W_mask = (W_metric <= thres)
31 | return W_mask
32 | elif prune_granularity == "row":
33 | W_mask = (torch.zeros_like(W_metric)==1)
34 | sort_res = torch.sort(W_metric, dim=-1, stable=True)
35 |
36 | indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)]
37 | W_mask.scatter_(1, indices, True)
38 | return W_mask
39 |
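# Worked example (added for clarity, not part of the original file): with sparsity=0.5
# and "row" granularity, each row of W_metric keeps its largest half. For instance, for
#   W_metric = torch.tensor([[4.0, 1.0, 3.0, 2.0],
#                            [0.5, 2.0, 1.0, 4.0]])
# compute_mask(W_metric, "row", 0.5) marks the two smallest entries per row
# (columns 1 and 3 in row 0, columns 0 and 2 in row 1) as True, and the caller zeroes
# exactly those weights. "layer" granularity instead applies a single threshold over
# the whole flattened matrix.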
40 | def prune_deit(args, model, calib_data, device):
41 | inps = calib_data
42 | bs = inps.shape[0]
43 | require_forward = (args.prune_metric in ["wanda"])
44 |
45 | metric_stats = []
46 | for blk in model.blocks:
47 | subset = find_layers(blk)
48 | res_per_layer = {}
49 | for name in subset:
50 | res_per_layer[name] = torch.abs(subset[name].weight.data)
51 | metric_stats.append(res_per_layer)
52 |
53 | thresh = None
54 | #####################################
55 | inps = model.patch_embed(inps)
56 |
57 | cls_tokens = model.cls_token.expand(bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
58 | dist_token = model.dist_token.expand(bs, -1, -1)
59 | inps = torch.cat((cls_tokens, dist_token, inps), dim=1)
60 |
61 | inps = inps + model.pos_embed
62 | inps = model.pos_drop(inps)
63 |
64 | for block_id, blk in enumerate(model.blocks):
65 | subset = find_layers(blk)
66 |
67 | if require_forward:
68 | wrapped_layers = {}
69 | for name in subset:
70 | wrapped_layers[name] = WrappedLayer(subset[name])
71 |
72 | def add_batch(name):
73 | def tmp(_, inp, out):
74 | wrapped_layers[name].add_batch(inp[0].data, out.data)
75 | return tmp
76 |
77 | handles = []
78 | for name in wrapped_layers:
79 | handles.append(subset[name].register_forward_hook(add_batch(name)))
80 |
81 | if bs > 256:
82 | tmp_res = []
83 | for i1 in range(0, bs, 256):
84 | j1 = min(i1+256, bs)
85 | tmp_res.append(blk(inps[i1:j1]))
86 | inps = torch.cat(tmp_res, dim=0)
87 | else:
88 | inps = blk(inps)
89 |
90 | for h in handles:
91 | h.remove()
92 |
93 | ################# pruning ###################
94 | for name in subset:
95 | if args.prune_metric == "wanda":
96 | metric_stats[block_id][name] *= torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
97 |
98 | W_mask = compute_mask(metric_stats[block_id][name], args.prune_granularity, args.sparsity)
99 |
100 | subset[name].weight.data[W_mask] = 0
101 |
102 | def prune_vit(args, model, calib_data, device):
103 | inps = calib_data
104 | bs = inps.shape[0]
105 | require_forward = (args.prune_metric in ["wanda"])
106 |
107 | metric_stats = []
108 | for blk in model.blocks:
109 | subset = find_layers(blk)
110 | res_per_layer = {}
111 | for name in subset:
112 | res_per_layer[name] = torch.abs(subset[name].weight.data)
113 | metric_stats.append(res_per_layer)
114 |
115 | thresh = None
116 | #####################################
117 | inps = model.patch_embed(inps)
118 |
119 | cls_tokens = model.cls_token.expand(bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
120 | inps = torch.cat((cls_tokens, inps), dim=1)
121 | inps = inps + model.pos_embed
122 | inps = model.pos_drop(inps)
123 |
124 | for block_id, blk in enumerate(model.blocks):
125 | print(f"block {block_id}")
126 | subset = find_layers(blk)
127 |
128 | if require_forward:
129 | wrapped_layers = {}
130 | for name in subset:
131 | wrapped_layers[name] = WrappedLayer(subset[name])
132 |
133 | def add_batch(name):
134 | def tmp(_, inp, out):
135 | wrapped_layers[name].add_batch(inp[0].data, out.data)
136 | return tmp
137 |
138 | handles = []
139 | for name in wrapped_layers:
140 | handles.append(subset[name].register_forward_hook(add_batch(name)))
141 |
142 | if bs > 256:
143 | tmp_res = []
144 | for i1 in range(0, bs, 256):
145 | j1 = min(i1+256, bs)
146 | tmp_res.append(blk(inps[i1:j1]))
147 | inps = torch.cat(tmp_res, dim=0)
148 | else:
149 | inps = blk(inps)
150 |
151 | for h in handles:
152 | h.remove()
153 |
154 | ################# pruning ###################
155 | for name in subset:
156 | if args.prune_metric == "wanda":
157 | metric_stats[block_id][name] *= torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
158 |
159 | W_mask = compute_mask(metric_stats[block_id][name], args.prune_granularity, args.sparsity)
160 |
161 | subset[name].weight.data[W_mask] = 0
162 | ##############################################
163 |
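# Note (added for clarity, not part of the original file): when args.prune_metric ==
# "wanda", the score pruned against in prune_deit/prune_vit/prune_convnext is
# |W_ij| * sqrt(scaler_row_j), i.e. weight magnitude scaled by a per-input-feature
# activation statistic accumulated by WrappedLayer during the calibration forward pass,
# which is the weights-and-activations criterion this repository is built around.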
164 | def prune_convnext(args, model, calib_data, device):
165 | inps = calib_data
166 | bs = inps.shape[0]
167 | require_forward = (args.prune_metric in ["wanda"])
168 |
169 | ##############################################################
170 | metric_stats = []
171 | for block_id in range(4):
172 | subset = find_layers(model.stages[block_id])
173 | res_per_layer = {}
174 | for name in subset:
175 | res_per_layer[name] = torch.abs(subset[name].weight.data)
176 | metric_stats.append(res_per_layer)
177 | ##############################################################
178 |
179 | thresh = None
180 | for block_id in range(4):
181 | print(f"block {block_id}")
182 | subset = find_layers(model.stages[block_id])
183 |
184 | if require_forward:
185 | layer = model.downsample_layers[block_id]
186 | if bs > 1024:
187 | tmp_res = []
188 | for i1 in range(0, bs, 512):
189 | j1 = min(i1+512, bs)
190 | tmp_res.append(layer(inps[i1:j1]))
191 | inps = torch.cat(tmp_res, dim=0)
192 | else:
193 | inps = layer(inps)
194 |
195 | wrapped_layers = {}
196 | for name in subset:
197 | wrapped_layers[name] = WrappedLayer(subset[name])
198 |
199 | def add_batch(name):
200 | def tmp(_, inp, out):
201 | wrapped_layers[name].add_batch(inp[0].data, out.data)
202 | return tmp
203 |
204 | handles = []
205 | for name in wrapped_layers:
206 | handles.append(subset[name].register_forward_hook(add_batch(name)))
207 | layer = model.stages[block_id]
208 | if bs > 1024:
209 | tmp_res = []
210 | for i1 in range(0, bs, 512):
211 | j1 = min(i1+512, bs)
212 | tmp_res.append(layer(inps[i1:j1]))
213 | inps = torch.cat(tmp_res, dim=0)
214 | else:
215 | inps = layer(inps)
216 | for h in handles:
217 | h.remove()
218 |
219 | ################# pruning ###################
220 | for name in subset:
221 | if args.prune_metric == "wanda":
222 | metric_stats[block_id][name] *= torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
223 |
224 | W_mask = compute_mask(metric_stats[block_id][name], args.prune_granularity, args.sparsity)
225 |
226 | subset[name].weight.data[W_mask] = 0
227 | ##############################################
--------------------------------------------------------------------------------
/image_classifiers/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 | import math
10 | import time
11 | from collections import defaultdict, deque
12 | import datetime
13 | import numpy as np
14 | from timm.utils import get_state_dict
15 |
16 | from pathlib import Path
17 | from timm.models import create_model
18 | import torch
19 | import torch.distributed as dist
20 | from torch._six import inf
21 |
22 | class SmoothedValue(object):
23 | """Track a series of values and provide access to smoothed values over a
24 | window or the global series average.
25 | """
26 |
27 | def __init__(self, window_size=20, fmt=None):
28 | if fmt is None:
29 | fmt = "{median:.4f} ({global_avg:.4f})"
30 | self.deque = deque(maxlen=window_size)
31 | self.total = 0.0
32 | self.count = 0
33 | self.fmt = fmt
34 |
35 | def update(self, value, n=1):
36 | self.deque.append(value)
37 | self.count += n
38 | self.total += value * n
39 |
40 | def synchronize_between_processes(self):
41 | """
42 | Warning: does not synchronize the deque!
43 | """
44 | if not is_dist_avail_and_initialized():
45 | return
46 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
47 | dist.barrier()
48 | dist.all_reduce(t)
49 | t = t.tolist()
50 | self.count = int(t[0])
51 | self.total = t[1]
52 |
53 | @property
54 | def median(self):
55 | d = torch.tensor(list(self.deque))
56 | return d.median().item()
57 |
58 | @property
59 | def avg(self):
60 | d = torch.tensor(list(self.deque), dtype=torch.float32)
61 | return d.mean().item()
62 |
63 | @property
64 | def global_avg(self):
65 | return self.total / self.count
66 |
67 | @property
68 | def max(self):
69 | return max(self.deque)
70 |
71 | @property
72 | def value(self):
73 | return self.deque[-1]
74 |
75 | def __str__(self):
76 | return self.fmt.format(
77 | median=self.median,
78 | avg=self.avg,
79 | global_avg=self.global_avg,
80 | max=self.max,
81 | value=self.value)
82 |
83 |
84 | class MetricLogger(object):
85 | def __init__(self, delimiter="\t"):
86 | self.meters = defaultdict(SmoothedValue)
87 | self.delimiter = delimiter
88 |
89 | def update(self, **kwargs):
90 | for k, v in kwargs.items():
91 | if v is None:
92 | continue
93 | if isinstance(v, torch.Tensor):
94 | v = v.item()
95 | assert isinstance(v, (float, int))
96 | self.meters[k].update(v)
97 |
98 | def __getattr__(self, attr):
99 | if attr in self.meters:
100 | return self.meters[attr]
101 | if attr in self.__dict__:
102 | return self.__dict__[attr]
103 | raise AttributeError("'{}' object has no attribute '{}'".format(
104 | type(self).__name__, attr))
105 |
106 | def __str__(self):
107 | loss_str = []
108 | for name, meter in self.meters.items():
109 | loss_str.append(
110 | "{}: {}".format(name, str(meter))
111 | )
112 | return self.delimiter.join(loss_str)
113 |
114 | def synchronize_between_processes(self):
115 | for meter in self.meters.values():
116 | meter.synchronize_between_processes()
117 |
118 | def add_meter(self, name, meter):
119 | self.meters[name] = meter
120 |
121 | def log_every(self, iterable, print_freq, header=None):
122 | i = 0
123 | if not header:
124 | header = ''
125 | start_time = time.time()
126 | end = time.time()
127 | iter_time = SmoothedValue(fmt='{avg:.4f}')
128 | data_time = SmoothedValue(fmt='{avg:.4f}')
129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
130 | log_msg = [
131 | header,
132 | '[{0' + space_fmt + '}/{1}]',
133 | 'eta: {eta}',
134 | '{meters}',
135 | 'time: {time}',
136 | 'data: {data}'
137 | ]
138 | if torch.cuda.is_available():
139 | log_msg.append('max mem: {memory:.0f}')
140 | log_msg = self.delimiter.join(log_msg)
141 | MB = 1024.0 * 1024.0
142 | for obj in iterable:
143 | data_time.update(time.time() - end)
144 | yield obj
145 | iter_time.update(time.time() - end)
146 | if i % print_freq == 0 or i == len(iterable) - 1:
147 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149 | if torch.cuda.is_available():
150 | print(log_msg.format(
151 | i, len(iterable), eta=eta_string,
152 | meters=str(self),
153 | time=str(iter_time), data=str(data_time),
154 | memory=torch.cuda.max_memory_allocated() / MB))
155 | else:
156 | print(log_msg.format(
157 | i, len(iterable), eta=eta_string,
158 | meters=str(self),
159 | time=str(iter_time), data=str(data_time)))
160 | i += 1
161 | end = time.time()
162 | total_time = time.time() - start_time
163 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
164 | print('{} Total time: {} ({:.4f} s / it)'.format(
165 | header, total_time_str, total_time / len(iterable)))
166 |
167 |
168 | class TensorboardLogger(object):
169 | def __init__(self, log_dir):
170 |         self.writer = SummaryWriter(logdir=log_dir)  # NOTE: requires SummaryWriter (e.g. from tensorboardX), which is not imported at the top of this file
171 | self.step = 0
172 |
173 | def set_step(self, step=None):
174 | if step is not None:
175 | self.step = step
176 | else:
177 | self.step += 1
178 |
179 | def update(self, head='scalar', step=None, **kwargs):
180 | for k, v in kwargs.items():
181 | if v is None:
182 | continue
183 | if isinstance(v, torch.Tensor):
184 | v = v.item()
185 | assert isinstance(v, (float, int))
186 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
187 |
188 | def flush(self):
189 | self.writer.flush()
190 |
191 |
192 | class WandbLogger(object):
193 | def __init__(self, args):
194 | self.args = args
195 |
196 | try:
197 | import wandb
198 | self._wandb = wandb
199 | except ImportError:
200 | raise ImportError(
201 | "To use the Weights and Biases Logger please install wandb."
202 | "Run `pip install wandb` to install it."
203 | )
204 |
205 | # Initialize a W&B run
206 | if self._wandb.run is None:
207 | self._wandb.init(
208 | project=args.project,
209 | config=args
210 | )
211 |
212 | def log_epoch_metrics(self, metrics, commit=True):
213 | """
214 | Log train/test metrics onto W&B.
215 | """
216 | # Log number of model parameters as W&B summary
217 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
218 | metrics.pop('n_parameters', None)
219 |
220 | # Log current epoch
221 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
222 | metrics.pop('epoch')
223 |
224 | for k, v in metrics.items():
225 | if 'train' in k:
226 | self._wandb.log({f'Global Train/{k}': v}, commit=False)
227 | elif 'test' in k:
228 | self._wandb.log({f'Global Test/{k}': v}, commit=False)
229 |
230 | self._wandb.log({})
231 |
232 | def log_checkpoints(self):
233 | output_dir = self.args.output_dir
234 | model_artifact = self._wandb.Artifact(
235 | self._wandb.run.id + "_model", type="model"
236 | )
237 |
238 | model_artifact.add_dir(output_dir)
239 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
240 |
241 | def set_steps(self):
242 | # Set global training step
243 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
244 | # Set epoch-wise step
245 | self._wandb.define_metric('Global Train/*', step_metric='epoch')
246 | self._wandb.define_metric('Global Test/*', step_metric='epoch')
247 |
248 |
249 | def setup_for_distributed(is_master):
250 | """
251 | This function disables printing when not in master process
252 | """
253 | import builtins as __builtin__
254 | builtin_print = __builtin__.print
255 |
256 | def print(*args, **kwargs):
257 | force = kwargs.pop('force', False)
258 | if is_master or force:
259 | builtin_print(*args, **kwargs)
260 |
261 | __builtin__.print = print
262 |
263 |
264 | def is_dist_avail_and_initialized():
265 | if not dist.is_available():
266 | return False
267 | if not dist.is_initialized():
268 | return False
269 | return True
270 |
271 |
272 | def get_world_size():
273 | if not is_dist_avail_and_initialized():
274 | return 1
275 | return dist.get_world_size()
276 |
277 |
278 | def get_rank():
279 | if not is_dist_avail_and_initialized():
280 | return 0
281 | return dist.get_rank()
282 |
283 |
284 | def is_main_process():
285 | return get_rank() == 0
286 |
287 |
288 | def save_on_master(*args, **kwargs):
289 | if is_main_process():
290 | torch.save(*args, **kwargs)
291 |
292 |
293 | def init_distributed_mode(args):
294 |
295 | if args.dist_on_itp:
296 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
297 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
298 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
299 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
300 | os.environ['LOCAL_RANK'] = str(args.gpu)
301 | os.environ['RANK'] = str(args.rank)
302 | os.environ['WORLD_SIZE'] = str(args.world_size)
303 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
304 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
305 | args.rank = int(os.environ["RANK"])
306 | args.world_size = int(os.environ['WORLD_SIZE'])
307 | args.gpu = int(os.environ['LOCAL_RANK'])
308 | elif 'SLURM_PROCID' in os.environ:
309 | args.rank = int(os.environ['SLURM_PROCID'])
310 | args.gpu = args.rank % torch.cuda.device_count()
311 |
312 | os.environ['RANK'] = str(args.rank)
313 | os.environ['LOCAL_RANK'] = str(args.gpu)
314 | os.environ['WORLD_SIZE'] = str(args.world_size)
315 | else:
316 | print('Not using distributed mode')
317 | args.distributed = False
318 | return
319 |
320 | args.distributed = True
321 |
322 | torch.cuda.set_device(args.gpu)
323 | args.dist_backend = 'nccl'
324 | print('| distributed init (rank {}): {}, gpu {}'.format(
325 | args.rank, args.dist_url, args.gpu), flush=True)
326 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
327 | world_size=args.world_size, rank=args.rank)
328 | torch.distributed.barrier()
329 | setup_for_distributed(args.rank == 0)
330 |
331 |
332 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
333 | missing_keys = []
334 | unexpected_keys = []
335 | error_msgs = []
336 | # copy state_dict so _load_from_state_dict can modify it
337 | metadata = getattr(state_dict, '_metadata', None)
338 | state_dict = state_dict.copy()
339 | if metadata is not None:
340 | state_dict._metadata = metadata
341 |
342 | def load(module, prefix=''):
343 | local_metadata = {} if metadata is None else metadata.get(
344 | prefix[:-1], {})
345 | module._load_from_state_dict(
346 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
347 | for name, child in module._modules.items():
348 | if child is not None:
349 | load(child, prefix + name + '.')
350 |
351 | load(model, prefix=prefix)
352 |
353 | warn_missing_keys = []
354 | ignore_missing_keys = []
355 | for key in missing_keys:
356 | keep_flag = True
357 | for ignore_key in ignore_missing.split('|'):
358 | if ignore_key in key:
359 | keep_flag = False
360 | break
361 | if keep_flag:
362 | warn_missing_keys.append(key)
363 | else:
364 | ignore_missing_keys.append(key)
365 |
366 | missing_keys = warn_missing_keys
367 |
368 | if len(missing_keys) > 0:
369 | print("Weights of {} not initialized from pretrained model: {}".format(
370 | model.__class__.__name__, missing_keys))
371 | if len(unexpected_keys) > 0:
372 | print("Weights from pretrained model not used in {}: {}".format(
373 | model.__class__.__name__, unexpected_keys))
374 | if len(ignore_missing_keys) > 0:
375 | print("Ignored weights of {} not initialized from pretrained model: {}".format(
376 | model.__class__.__name__, ignore_missing_keys))
377 | if len(error_msgs) > 0:
378 | print('\n'.join(error_msgs))
379 |
380 |
381 | class NativeScalerWithGradNormCount:
382 | state_dict_key = "amp_scaler"
383 |
384 | def __init__(self):
385 | self._scaler = torch.cuda.amp.GradScaler()
386 |
387 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
388 | self._scaler.scale(loss).backward(create_graph=create_graph)
389 |
390 | ########################################################
391 |         ## Mask gradients of pruned (zero-valued) weights so sparse fine-tuning keeps them at zero
392 | for param in parameters:
393 | weight_copy = param.data.abs().clone()
394 | mask = weight_copy.gt(0).float().cuda()
395 |             sparsity = mask.sum() / mask.numel()  # fraction of weights that remain non-zero
396 |             if sparsity > 0.3:
397 |                 # a non-trivial fraction of weights survives: zero the gradients of pruned entries
398 | param.grad.data.mul_(mask)
399 | ########################################################
400 |
401 | if update_grad:
402 | if clip_grad is not None:
403 | assert parameters is not None
404 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
405 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
406 | else:
407 | self._scaler.unscale_(optimizer)
408 | norm = get_grad_norm_(parameters)
409 | self._scaler.step(optimizer)
410 | self._scaler.update()
411 | else:
412 | norm = None
413 | return norm
414 |
415 | def state_dict(self):
416 | return self._scaler.state_dict()
417 |
418 | def load_state_dict(self, state_dict):
419 | self._scaler.load_state_dict(state_dict)
420 |
421 |
422 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
423 | if isinstance(parameters, torch.Tensor):
424 | parameters = [parameters]
425 | parameters = [p for p in parameters if p.grad is not None]
426 | norm_type = float(norm_type)
427 | if len(parameters) == 0:
428 | return torch.tensor(0.)
429 | device = parameters[0].grad.device
430 | if norm_type == inf:
431 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
432 | else:
433 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
434 | return total_norm
435 |
436 |
437 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
438 | start_warmup_value=0, warmup_steps=-1):
439 | warmup_schedule = np.array([])
440 | warmup_iters = warmup_epochs * niter_per_ep
441 | if warmup_steps > 0:
442 | warmup_iters = warmup_steps
443 | print("Set warmup steps = %d" % warmup_iters)
444 | if warmup_epochs > 0:
445 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
446 |
447 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
448 | schedule = np.array(
449 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
450 |
451 | schedule = np.concatenate((warmup_schedule, schedule))
452 |
453 | assert len(schedule) == epochs * niter_per_ep
454 | return schedule
455 |
456 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
457 | output_dir = Path(args.output_dir)
458 | epoch_name = str(epoch)
459 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
460 | for checkpoint_path in checkpoint_paths:
461 | to_save = {
462 | 'model': model_without_ddp.state_dict(),
463 | 'optimizer': optimizer.state_dict(),
464 | 'epoch': epoch,
465 | 'scaler': loss_scaler.state_dict(),
466 | 'args': args,
467 | }
468 |
469 | if model_ema is not None:
470 | to_save['model_ema'] = get_state_dict(model_ema)
471 |
472 | save_on_master(to_save, checkpoint_path)
473 |
474 | if is_main_process() and isinstance(epoch, int):
475 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
476 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
477 | if os.path.exists(old_ckpt):
478 | os.remove(old_ckpt)
479 |
480 |
481 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
482 | output_dir = Path(args.output_dir)
483 | if args.auto_resume and len(args.resume) == 0:
484 | import glob
485 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
486 | latest_ckpt = -1
487 | for ckpt in all_checkpoints:
488 | t = ckpt.split('-')[-1].split('.')[0]
489 | if t.isdigit():
490 | latest_ckpt = max(int(t), latest_ckpt)
491 | if latest_ckpt >= 0:
492 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
493 | print("Auto resume checkpoint: %s" % args.resume)
494 |
495 | if args.resume:
496 | if args.resume.startswith('https'):
497 | checkpoint = torch.hub.load_state_dict_from_url(
498 | args.resume, map_location='cpu', check_hash=True)
499 | else:
500 | checkpoint = torch.load(args.resume, map_location='cpu')
501 | model_without_ddp.load_state_dict(checkpoint['model'])
502 | print("Resume checkpoint %s" % args.resume)
503 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
504 | optimizer.load_state_dict(checkpoint['optimizer'])
505 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema'
506 | args.start_epoch = checkpoint['epoch'] + 1
507 | else:
508 | assert args.eval, 'Does not support resuming with checkpoint-best'
509 | if hasattr(args, 'model_ema') and args.model_ema:
510 | if 'model_ema' in checkpoint.keys():
511 | model_ema.ema.load_state_dict(checkpoint['model_ema'])
512 | else:
513 | model_ema.ema.load_state_dict(checkpoint['model'])
514 | if 'scaler' in checkpoint:
515 | loss_scaler.load_state_dict(checkpoint['scaler'])
516 | print("With optim & sched!")
517 |
518 | def reg_scheduler(base_value, final_value, epochs, niter_per_ep, early_epochs=0, early_value=None,
519 | mode='linear', early_mode='regular'):
520 | early_schedule = np.array([])
521 | early_iters = early_epochs * niter_per_ep
522 | if early_value is None:
523 | early_value = final_value
524 | if early_epochs > 0:
525 | print(f"Set early value to {early_mode} {early_value}")
526 | if early_mode == 'regular':
527 | early_schedule = np.array([early_value] * early_iters)
528 | elif early_mode == 'linear':
529 | early_schedule = np.linspace(early_value, base_value, early_iters)
530 | elif early_mode == 'cosine':
531 | early_schedule = np.array(
532 | [base_value + 0.5 * (early_value - base_value) * (1 + math.cos(math.pi * i / early_iters)) for i in np.arange(early_iters)])
533 | regular_epochs = epochs - early_epochs
534 | iters = np.arange(regular_epochs * niter_per_ep)
535 | schedule = np.linspace(base_value, final_value, len(iters))
536 | schedule = np.concatenate((early_schedule, schedule))
537 |
538 | assert len(schedule) == epochs * niter_per_ep
539 | return schedule
540 |
541 | def build_model(args, pretrained=False):
542 | if args.model.startswith("convnext"):
543 | model = create_model(
544 | args.model,
545 | pretrained=pretrained,
546 | num_classes=args.nb_classes,
547 | layer_scale_init_value=args.layer_scale_init_value,
548 | head_init_scale=args.head_init_scale,
549 | drop_path_rate=args.drop_path,
550 | drop_rate=args.dropout,
551 | )
552 | else:
553 | model = create_model(
554 | args.model,
555 | pretrained=pretrained,
556 | num_classes=args.nb_classes,
557 | drop_path_rate=args.drop_path,
558 | drop_rate =args.dropout
559 | )
560 | return model
--------------------------------------------------------------------------------
/lib/ablate.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | torch.backends.cuda.matmul.allow_tf32 = False
9 | torch.backends.cudnn.allow_tf32 = False
10 |
11 | class AblateGPT:
12 |
13 | def __init__(self, layer):
14 | self.layer = layer
15 | self.dev = self.layer.weight.device
16 | W = layer.weight.data.clone()
17 | if isinstance(self.layer, nn.Conv2d):
18 | W = W.flatten(1)
19 | if isinstance(self.layer, transformers.Conv1D):
20 | W = W.t()
21 | self.rows = W.shape[0]
22 | self.columns = W.shape[1]
23 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
24 | self.nsamples = 0
25 |
26 | self.scaler_row = torch.zeros((self.columns), device=self.dev)
27 |
28 | def add_batch(self, inp, out):
29 | if len(inp.shape) == 2:
30 | inp = inp.unsqueeze(0)
31 | tmp = inp.shape[0]
32 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
33 | if len(inp.shape) == 3:
34 | inp = inp.reshape((-1, inp.shape[-1]))
35 | inp = inp.t()
36 | self.H *= self.nsamples / (self.nsamples + tmp)
37 |
38 | self.scaler_row *= self.nsamples / (self.nsamples+tmp)
39 |
40 | self.nsamples += tmp
41 | inp = math.sqrt(2 / self.nsamples) * inp.float()
42 | self.H += inp.matmul(inp.t())
43 | self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples
44 |
45 | def get_wanda_mask(self, sparsity, prunen, prunem):
46 | W_metric = torch.abs(self.layer.weight.data) * torch.sqrt(self.scaler_row.reshape((1,-1)))
47 | W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False
48 | if prunen != 0:
49 | for ii in range(W_metric.shape[1]):
50 | if ii % prunem == 0:
51 | tmp = W_metric[:,ii:(ii+prunem)].float()
52 | W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True)
53 | else:
54 | sort_res = torch.sort(W_metric, dim=-1, stable=True)
55 | indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)]
56 | W_mask.scatter_(1, indices, True)
57 |
58 | return W_mask
59 |
60 | def get_mag_mask(self, sparsity, prunen, prunem):
61 | W = self.layer.weight.data
62 | W_metric = torch.abs(W)
63 | if prunen != 0:
64 | W_mask = (torch.zeros_like(W)==1)
65 | for ii in range(W_metric.shape[1]):
66 | if ii % prunem == 0:
67 | tmp = W_metric[:,ii:(ii+prunem)].float()
68 | W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True)
69 | else:
70 | thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*sparsity)].cpu()
71 | W_mask = (W_metric<=thresh)
72 |
73 | return W_mask
74 |
75 | def fasterprune(
76 | self, args, sparsity, mask=None, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
77 | ):
78 | W = self.layer.weight.data.clone()
79 | if isinstance(self.layer, nn.Conv2d):
80 | W = W.flatten(1)
81 | if isinstance(self.layer, transformers.Conv1D):
82 | W = W.t()
83 | W = W.float()
84 |
85 | tick = time.time()
86 |
87 | H = self.H
88 | del self.H
89 | dead = torch.diag(H) == 0
90 | H[dead, dead] = 1
91 | W[:, dead] = 0
92 |
93 | Losses = torch.zeros(self.rows, device=self.dev)
94 |
95 | damp = percdamp * torch.mean(torch.diag(H))
96 | diag = torch.arange(self.columns, device=self.dev)
97 | H[diag, diag] += damp
98 | H = torch.linalg.cholesky(H)
99 | H = torch.cholesky_inverse(H)
100 | H = torch.linalg.cholesky(H, upper=True)
101 | Hinv = H
102 |
103 | for i1 in range(0, self.columns, blocksize):
104 | i2 = min(i1 + blocksize, self.columns)
105 | count = i2 - i1
106 |
107 | W1 = W[:, i1:i2].clone()
108 | Q1 = torch.zeros_like(W1)
109 | Err1 = torch.zeros_like(W1)
110 | Losses1 = torch.zeros_like(W1)
111 | Hinv1 = Hinv[i1:i2, i1:i2]
112 |
113 | if prune_n == 0 or mask is not None:
114 | if mask is not None:
115 | mask1 = mask[:, i1:i2]
116 | else:
117 | # tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
118 | if "wanda" in args.prune_method:
119 | tmp = torch.abs(W1) * torch.sqrt(self.scaler_row[i1:i2].reshape((1,-1)))
120 | elif "mag" in args.prune_method:
121 | tmp = torch.abs(W1)
122 | thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
123 | mask1 = tmp <= thresh
124 | else:
125 | mask1 = torch.zeros_like(W1) == 1
126 |
127 | for i in range(count):
128 | w = W1[:, i]
129 | d = Hinv1[i, i]
130 |
131 | if prune_n != 0 and i % prune_m == 0 and mask is None:
132 | # tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2
133 | if "wanda" in args.prune_method:
134 | tmp = torch.abs(W1[:, i:(i+prune_m)]) * torch.sqrt(self.scaler_row[(i+i1):(i+i1+prune_m)].reshape((1,-1)))
135 | elif "mag" in args.prune_method:
136 | tmp = torch.abs(W1[:, i:(i+prune_m)])
137 | mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
138 |
139 | q = w.clone()
140 | q[mask1[:, i]] = 0
141 |
142 | Q1[:, i] = q
143 | Losses1[:, i] = (w - q) ** 2 / d ** 2
144 |
145 | err1 = (w - q) / d
146 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
147 | Err1[:, i] = err1
148 |
149 | W[:, i1:i2] = Q1
150 | Losses += torch.sum(Losses1, 1) / 2
151 |
152 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
153 |
154 | torch.cuda.synchronize()
155 | if isinstance(self.layer, transformers.Conv1D):
156 | W = W.t()
157 | self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
158 |
159 | def free(self):
160 | self.H = None
161 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/lib/data.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
2 |
3 | import numpy as np
4 | import random
5 | import torch
6 | from datasets import load_dataset
7 |
8 | # Set seed for reproducibility
9 | def set_seed(seed):
10 | np.random.seed(seed)
11 | torch.random.manual_seed(seed)
12 |
13 | # Wrapper for tokenized input IDs
14 | class TokenizerWrapper:
15 | def __init__(self, input_ids):
16 | self.input_ids = input_ids
17 |
18 | # Load and process wikitext2 dataset
19 | def get_wikitext2(nsamples, seed, seqlen, tokenizer):
20 | # Load train and test datasets
21 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
22 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
23 |
24 | # Encode datasets
25 | trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
26 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
27 |
28 | # Generate samples from training set
29 | random.seed(seed)
30 | trainloader = []
31 | for _ in range(nsamples):
32 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
33 | j = i + seqlen
34 | inp = trainenc.input_ids[:, i:j]
35 | tar = inp.clone()
36 | tar[:, :-1] = -100
37 | trainloader.append((inp, tar))
38 | return trainloader, testenc
39 |
40 | # Load and process c4 dataset
41 | def get_c4(nsamples, seed, seqlen, tokenizer):
42 | # Load train and validation datasets
43 | traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
44 | valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
45 |
46 | # Generate samples from training set
47 | random.seed(seed)
48 | trainloader = []
49 | for _ in range(nsamples):
50 | while True:
51 | i = random.randint(0, len(traindata) - 1)
52 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
53 | if trainenc.input_ids.shape[1] > seqlen:
54 | break
55 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
56 | j = i + seqlen
57 | inp = trainenc.input_ids[:, i:j]
58 | tar = inp.clone()
59 | tar[:, :-1] = -100
60 | trainloader.append((inp, tar))
61 |
62 | # Prepare validation dataset
63 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
64 | valenc = valenc.input_ids[:, :(256 * seqlen)]
65 | valenc = TokenizerWrapper(valenc)
66 | return trainloader, valenc
67 |
68 | # Function to select the appropriate loader based on dataset name
69 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
70 | if 'wikitext2' in name:
71 | return get_wikitext2(nsamples, seed, seqlen, tokenizer)
72 | if "c4" in name:
73 | return get_c4(nsamples, seed, seqlen, tokenizer)
--------------------------------------------------------------------------------
/lib/eval.py:
--------------------------------------------------------------------------------
1 | # Import necessary modules
2 | import time
3 | import torch
4 | import torch.nn as nn
5 |
6 | # Import get_loaders function from data module within the same directory
7 | from .data import get_loaders
8 |
9 | from collections import defaultdict
10 | import fnmatch
11 |
12 |
13 | # Function to evaluate perplexity (ppl) on a specified model and tokenizer
14 | def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
15 | # Set dataset
16 | dataset = "wikitext2"
17 |
18 | # Print status
19 | print(f"evaluating on {dataset}")
20 |
21 | # Get the test loader
22 | _, testloader = get_loaders(
23 | dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer
24 | )
25 |
26 | # Evaluate ppl in no grad context to avoid updating the model
27 | with torch.no_grad():
28 | ppl_test = eval_ppl_wikitext(model, testloader, 1, device)
29 | return ppl_test
30 |
31 | # Function to evaluate perplexity (ppl) on the sampled training sequences
32 | def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None):
33 | # Get input IDs
34 | # testenc = testenc.input_ids
35 |
36 | # Calculate number of samples
37 | # nsamples = testenc.numel() // model.seqlen
38 | nsamples = len(trainloader)
39 |
40 | # List to store negative log likelihoods
41 | nlls = []
42 | print(f"nsamples {nsamples}")
43 |
44 | # Loop through each batch
45 | for i in range(0,nsamples,bs):
46 | if i % 50 == 0:
47 | print(f"sample {i}")
48 |
49 | # Calculate end index
50 | j = min(i+bs, nsamples)
51 |
52 | # Prepare inputs and move to device
53 | # inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device)
54 | inputs = trainloader[i][0].to(device)
55 | inputs = inputs.reshape(j-i, model.seqlen)
56 |
57 | # Forward pass through the model
58 | lm_logits = model(inputs).logits
59 |
60 | # Shift logits and labels for next token prediction
61 | shift_logits = lm_logits[:, :-1, :].contiguous()
62 | shift_labels = inputs[:, 1:]
63 |
64 | # Compute loss
65 | loss_fct = nn.CrossEntropyLoss()
66 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
67 |
68 | # Calculate negative log likelihood
69 | neg_log_likelihood = loss.float() * model.seqlen * (j-i)
70 |
71 | # Append to list of negative log likelihoods
72 | nlls.append(neg_log_likelihood)
73 |
74 | # Compute perplexity
75 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
76 |
77 | # Empty CUDA cache to save memory
78 | torch.cuda.empty_cache()
79 |
80 | return ppl.item()
81 |
82 | # Function to evaluate perplexity (ppl) specifically on the wikitext dataset
83 | def eval_ppl_wikitext(model, testenc, bs=1, device=None):
84 | # Get input IDs
85 | testenc = testenc.input_ids
86 |
87 | # Calculate number of samples
88 | nsamples = testenc.numel() // model.seqlen
89 |
90 | # List to store negative log likelihoods
91 | nlls = []
92 | print(f"nsamples {nsamples}")
93 |
94 | # Loop through each batch
95 | for i in range(0,nsamples,bs):
96 | if i % 50 == 0:
97 | print(f"sample {i}")
98 |
99 | # Calculate end index
100 | j = min(i+bs, nsamples)
101 |
102 | # Prepare inputs and move to device
103 | inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device)
104 | inputs = inputs.reshape(j-i, model.seqlen)
105 |
106 | # Forward pass through the model
107 | lm_logits = model(inputs).logits
108 |
109 | # Shift logits and labels for next token prediction
110 | shift_logits = lm_logits[:, :-1, :].contiguous()
111 | shift_labels = inputs[:, 1:]
112 |
113 | # Compute loss
114 | loss_fct = nn.CrossEntropyLoss()
115 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
116 |
117 | # Calculate negative log likelihood
118 | neg_log_likelihood = loss.float() * model.seqlen * (j-i)
119 |
120 | # Append to list of negative log likelihoods
121 | nlls.append(neg_log_likelihood)
122 |
123 | # Compute perplexity
124 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
125 |
126 | # Empty CUDA cache to save memory
127 | torch.cuda.empty_cache()
128 |
129 | return ppl.item()
130 |
131 |
132 | def eval_zero_shot(model_name, model, tokenizer, task_list=["boolq","rte","hellaswag","winogrande","arc_challenge","arc_easy","openbookqa"],
133 | num_fewshot=0, use_accelerate=False, add_special_tokens=False):
134 | from lm_eval import tasks, evaluator
135 | def pattern_match(patterns, source_list):
136 | task_names = set()
137 | for pattern in patterns:
138 | for matching in fnmatch.filter(source_list, pattern):
139 | task_names.add(matching)
140 | return list(task_names)
141 | task_names = pattern_match(task_list, tasks.ALL_TASKS)
142 | model_args = f"pretrained={model_name},cache_dir=./llm_weights"
143 | limit = None
144 | if "70b" in model_name or "65b" in model_name:
145 | limit = 2000
146 | if use_accelerate:
147 | model_args = f"pretrained={model_name},cache_dir=./llm_weights,use_accelerate=True"
148 | results = evaluator.simple_evaluate(
149 | model="hf-causal-experimental",
150 | model_args=model_args,
151 | tasks=task_names,
152 | num_fewshot=num_fewshot,
153 | batch_size=None,
154 | device=None,
155 | no_cache=True,
156 | limit=limit,
157 | description_dict={},
158 | decontamination_ngrams_path=None,
159 | check_integrity=False,
160 | pretrained_model=model,
161 | tokenizer=tokenizer,
162 | add_special_tokens=add_special_tokens
163 | )
164 |
165 | return results
--------------------------------------------------------------------------------
/lib/layerwrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # Define WrappedGPT class
5 | class WrappedGPT:
6 | """
7 | This class wraps a GPT layer for specific operations.
8 | """
9 |
10 | def __init__(self, layer, layer_id=0, layer_name="none"):
11 | self.layer = layer
12 | self.dev = self.layer.weight.device
13 | self.rows = layer.weight.data.shape[0]
14 | self.columns = layer.weight.data.shape[1]
15 |
16 | self.scaler_row = torch.zeros((self.columns), device=self.dev)
17 | self.nsamples = 0
18 |
19 | self.layer_id = layer_id
20 | self.layer_name = layer_name
21 |
22 | def add_batch(self, inp, out):
23 | if len(inp.shape) == 2:
24 | inp = inp.unsqueeze(0)
25 | tmp = inp.shape[0]
26 | if isinstance(self.layer, nn.Linear):
27 | if len(inp.shape) == 3:
28 | inp = inp.reshape((-1, inp.shape[-1]))
29 | inp = inp.t()
30 |
31 | self.scaler_row *= self.nsamples / (self.nsamples+tmp)
32 | self.nsamples += tmp
33 |
34 | inp = inp.type(torch.float32)
35 | self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples
--------------------------------------------------------------------------------
/lib/prune.py:
--------------------------------------------------------------------------------
1 | import time
2 | import heapq
3 | import torch
4 | import torch.nn as nn
5 | from .sparsegpt import SparseGPT
6 | from .layerwrapper import WrappedGPT
7 | from .data import get_loaders
8 |
9 | from .ablate import AblateGPT
10 |
11 | def find_layers(module, layers=[nn.Linear], name=''):
12 | """
13 | Recursively find the layers of a certain type in a module.
14 |
15 | Args:
16 | module (nn.Module): PyTorch module.
17 | layers (list): List of layer types to find.
18 | name (str): Name of the module.
19 |
20 | Returns:
21 | dict: Dictionary of layers of the given type(s) within the module.
22 | """
23 | if type(module) in layers:
24 | return {name: module}
25 | res = {}
26 | for name1, child in module.named_children():
27 | res.update(find_layers(
28 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
29 | ))
30 | return res
31 |
32 | def check_sparsity(model):
33 | use_cache = model.config.use_cache
34 | model.config.use_cache = False
35 |
36 | layers = model.model.layers
37 | count = 0
38 | total_params = 0
39 | for i in range(len(layers)):
40 | layer = layers[i]
41 | subset = find_layers(layer)
42 |
43 | sub_count = 0
44 | sub_params = 0
45 | for name in subset:
46 | W = subset[name].weight.data
47 | count += (W==0).sum().item()
48 | total_params += W.numel()
49 |
50 | sub_count += (W==0).sum().item()
51 | sub_params += W.numel()
52 |
53 | print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}")
54 |
55 | model.config.use_cache = use_cache
56 | return float(count)/total_params
57 |
58 | def prepare_calibration_input(model, dataloader, device):
59 | use_cache = model.config.use_cache
60 | model.config.use_cache = False
61 | layers = model.model.layers
62 |
63 | # dev = model.hf_device_map["model.embed_tokens"]
64 | if "model.embed_tokens" in model.hf_device_map:
65 | device = model.hf_device_map["model.embed_tokens"]
66 |
67 | dtype = next(iter(model.parameters())).dtype
68 | inps = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
69 | inps.requires_grad = False
70 | cache = {'i': 0, 'attention_mask': None, "position_ids": None}
71 |
72 | class Catcher(nn.Module):
73 | def __init__(self, module):
74 | super().__init__()
75 | self.module = module
76 | def forward(self, inp, **kwargs):
77 | inps[cache['i']] = inp
78 | cache['i'] += 1
79 | cache['attention_mask'] = kwargs['attention_mask']
80 | cache['position_ids'] = kwargs['position_ids']
81 | raise ValueError
82 | layers[0] = Catcher(layers[0])
83 | for batch in dataloader:
84 | try:
85 | model(batch[0].to(device))
86 | except ValueError:
87 | pass
88 | layers[0] = layers[0].module
89 |
90 | outs = torch.zeros_like(inps)
91 | attention_mask = cache['attention_mask']
92 | position_ids = cache['position_ids']
93 | model.config.use_cache = use_cache
94 |
95 | return inps, outs, attention_mask, position_ids
96 |
97 | def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
98 | thres_cumsum = sum_before * alpha
99 | sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1))
100 | thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1)
101 | W_mask = (W_metric <= thres)
102 | cur_sparsity = (W_mask==True).sum() / W_mask.numel()
103 | return W_mask, cur_sparsity
104 |
105 | def prune_magnitude(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
106 | layers = model.model.layers
107 |
108 | for i in range(len(layers)):
109 | layer = layers[i]
110 | subset = find_layers(layer)
111 |
112 | for name in subset:
113 | W = subset[name].weight.data
114 | W_metric = torch.abs(W)
115 | if prune_n != 0:
116 | W_mask = (torch.zeros_like(W)==1)
117 | for ii in range(W_metric.shape[1]):
118 | if ii % prune_m == 0:
119 | tmp = W_metric[:,ii:(ii+prune_m)].float()
120 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
121 | else:
122 | thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*args.sparsity_ratio)].cpu()
123 | W_mask = (W_metric<=thresh)
124 |
125 | W[W_mask] = 0
126 |
127 | def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
128 | use_cache = model.config.use_cache
129 | model.config.use_cache = False
130 |
131 | print("loading calibdation data")
132 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
133 | print("dataset loading complete")
134 | with torch.no_grad():
135 | inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)
136 |
137 | layers = model.model.layers
138 | for i in range(len(layers)):
139 | layer = layers[i]
140 | subset = find_layers(layer)
141 |
142 | if f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
143 | dev = model.hf_device_map[f"model.layers.{i}"]
144 | inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev)
145 |
146 | wrapped_layers = {}
147 | for name in subset:
148 | wrapped_layers[name] = WrappedGPT(subset[name])
149 |
150 | def add_batch(name):
151 | def tmp(_, inp, out):
152 | wrapped_layers[name].add_batch(inp[0].data, out.data)
153 | return tmp
154 |
155 | handles = []
156 | for name in wrapped_layers:
157 | handles.append(subset[name].register_forward_hook(add_batch(name)))
158 | for j in range(args.nsamples):
159 | with torch.no_grad():
160 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
161 | for h in handles:
162 | h.remove()
163 |
164 | for name in subset:
165 | print(f"pruning layer {i} name {name}")
166 | W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
167 |
168 | W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False
169 | if prune_n != 0:
170 | # structured n:m sparsity
171 | for ii in range(W_metric.shape[1]):
172 | if ii % prune_m == 0:
173 | tmp = W_metric[:,ii:(ii+prune_m)].float()
174 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
175 | else:
176 | sort_res = torch.sort(W_metric, dim=-1, stable=True)
177 |
178 | if args.use_variant:
179 | # wanda variant
180 | tmp_metric = torch.cumsum(sort_res[0], dim=1)
181 | sum_before = W_metric.sum(dim=1)
182 |
183 | alpha = 0.4
184 | alpha_hist = [0., 0.8]
185 | W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
186 | while (torch.abs(cur_sparsity - args.sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001):
187 | if cur_sparsity > args.sparsity_ratio:
188 | alpha_new = (alpha + alpha_hist[0]) / 2.0
189 | alpha_hist[1] = alpha
190 | else:
191 | alpha_new = (alpha + alpha_hist[1]) / 2.0
192 | alpha_hist[0] = alpha
193 |
194 | alpha = alpha_new
195 | W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
196 | print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}")
197 | else:
198 | # unstructured pruning
199 | indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)]
200 | W_mask.scatter_(1, indices, True)
201 |
202 | subset[name].weight.data[W_mask] = 0 ## set weights to zero
203 |
204 | for j in range(args.nsamples):
205 | with torch.no_grad():
206 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
207 | inps, outs = outs, inps
208 |
209 | model.config.use_cache = use_cache
210 | torch.cuda.empty_cache()
211 |
212 |
213 | @torch.no_grad()
214 | def prune_sparsegpt(args, model, tokenizer, dev, prune_n=0, prune_m=0):
215 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
216 | print('Starting ...')
217 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
218 |
219 | use_cache = model.config.use_cache
220 | model.config.use_cache = False
221 | layers = model.model.layers
222 |
223 | if "model.embed_tokens" in model.hf_device_map:
224 | dev = model.hf_device_map["model.embed_tokens"]
225 |
226 | dtype = next(iter(model.parameters())).dtype
227 | inps = torch.zeros(
228 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
229 | )
230 | cache = {'i': 0, 'attention_mask': None, "position_ids": None}
231 |
232 | class Catcher(nn.Module):
233 | def __init__(self, module):
234 | super().__init__()
235 | self.module = module
236 | def forward(self, inp, **kwargs):
237 | inps[cache['i']] = inp
238 | cache['i'] += 1
239 | cache['attention_mask'] = kwargs['attention_mask']
240 | cache['position_ids'] = kwargs['position_ids']
241 | raise ValueError
242 | layers[0] = Catcher(layers[0])
243 | for batch in dataloader:
244 | try:
245 | model(batch[0].to(dev))
246 | except ValueError:
247 | pass
248 | layers[0] = layers[0].module
249 | torch.cuda.empty_cache()
250 |
251 | outs = torch.zeros_like(inps)
252 | attention_mask = cache['attention_mask']
253 | position_ids = cache['position_ids']
254 |
255 | print('Ready.')
256 |
257 | for i in range(len(layers)):
258 | layer = layers[i]
259 | if f"model.layers.{i}" in model.hf_device_map:
260 | dev = model.hf_device_map[f"model.layers.{i}"]
261 | print(f"layer {i} device {dev}")
262 | inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev)
263 |
264 | subset = find_layers(layer)
265 |
266 | gpts = {}
267 | for name in subset:
268 | gpts[name] = SparseGPT(subset[name])
269 |
270 | def add_batch(name):
271 | def tmp(_, inp, out):
272 | gpts[name].add_batch(inp[0].data, out.data)
273 | return tmp
274 |
275 | handles = []
276 | for name in gpts:
277 | handles.append(subset[name].register_forward_hook(add_batch(name)))
278 |
279 | for j in range(args.nsamples):
280 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
281 | for h in handles:
282 | h.remove()
283 |
284 | for name in gpts:
285 | print(i, name)
286 | print('Pruning ...')
287 |
288 | gpts[name].fasterprune(args.sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
289 | gpts[name].free()
290 |
291 | for j in range(args.nsamples):
292 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
293 |
294 | layers[i] = layer
295 | torch.cuda.empty_cache()
296 |
297 | inps, outs = outs, inps
298 |
299 | model.config.use_cache = use_cache
300 | torch.cuda.empty_cache()
301 |
302 |
303 |
304 | @torch.no_grad()
305 | def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0):
306 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
307 | print('Starting ...')
308 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
309 |
310 | use_cache = model.config.use_cache
311 | model.config.use_cache = False
312 | layers = model.model.layers
313 |
314 | if "model.embed_tokens" in model.hf_device_map:
315 | dev = model.hf_device_map["model.embed_tokens"]
316 |
317 | dtype = next(iter(model.parameters())).dtype
318 | inps = torch.zeros(
319 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
320 | )
321 | cache = {'i': 0, 'attention_mask': None, "position_ids": None}
322 |
323 | class Catcher(nn.Module):
324 | def __init__(self, module):
325 | super().__init__()
326 | self.module = module
327 | def forward(self, inp, **kwargs):
328 | inps[cache['i']] = inp
329 | cache['i'] += 1
330 | cache['attention_mask'] = kwargs['attention_mask']
331 | cache['position_ids'] = kwargs['position_ids']
332 | raise ValueError
333 | layers[0] = Catcher(layers[0])
334 | for batch in dataloader:
335 | try:
336 | model(batch[0].to(dev))
337 | except ValueError:
338 | pass
339 | layers[0] = layers[0].module
340 | torch.cuda.empty_cache()
341 |
342 | outs = torch.zeros_like(inps)
343 | attention_mask = cache['attention_mask']
344 | position_ids = cache['position_ids']
345 |
346 | print('Ready.')
347 |
348 | for i in range(len(layers)):
349 | layer = layers[i]
350 | if f"model.layers.{i}" in model.hf_device_map:
351 | dev = model.hf_device_map[f"model.layers.{i}"]
352 | print(f"layer {i} device {dev}")
353 | inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev)
354 |
355 | subset = find_layers(layer)
356 |
357 | gpts = {}
358 | for name in subset:
359 | gpts[name] = AblateGPT(subset[name])
360 |
361 | def add_batch(name):
362 | def tmp(_, inp, out):
363 | gpts[name].add_batch(inp[0].data, out.data)
364 | return tmp
365 |
366 | handles = []
367 | for name in gpts:
368 | handles.append(subset[name].register_forward_hook(add_batch(name)))
369 |
370 | for j in range(args.nsamples):
371 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
372 | for h in handles:
373 | h.remove()
374 |
375 | for name in gpts:
376 | print(i, name)
377 | print('Pruning ...')
378 |
379 | if args.prune_method == "ablate_wanda_seq":
380 | prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m)
381 | elif args.prune_method == "ablate_mag_seq":
382 | prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m)
383 | elif "iter" in args.prune_method:
384 | prune_mask = None
385 |
386 | gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
387 | gpts[name].free()
388 |
389 | for j in range(args.nsamples):
390 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
391 |
392 | layers[i] = layer
393 | torch.cuda.empty_cache()
394 |
395 | inps, outs = outs, inps
396 |
397 | model.config.use_cache = use_cache
398 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/lib/prune_opt.py:
--------------------------------------------------------------------------------
1 | import time
2 | import heapq
3 | import torch
4 | import torch.nn as nn
5 | from .sparsegpt import SparseGPT
6 | from .layerwrapper import WrappedGPT
7 | from .data import get_loaders
8 |
9 | from .ablate import AblateGPT
10 |
11 | def find_layers(module, layers=[nn.Linear], name=''):
12 | """
13 | Recursively find the layers of a certain type in a module.
14 |
15 | Args:
16 | module (nn.Module): PyTorch module.
17 | layers (list): List of layer types to find.
18 | name (str): Name of the module.
19 |
20 | Returns:
21 | dict: Dictionary of layers of the given type(s) within the module.
22 | """
23 | if type(module) in layers:
24 | return {name: module}
25 | res = {}
26 | for name1, child in module.named_children():
27 | res.update(find_layers(
28 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
29 | ))
30 | return res
31 |
32 | def check_sparsity(model):
33 | use_cache = model.config.use_cache
34 | model.config.use_cache = False
35 |
36 | layers = model.model.decoder.layers
37 | count = 0
38 | total_params = 0
39 | for i in range(len(layers)):
40 | layer = layers[i]
41 | subset = find_layers(layer)
42 |
43 | sub_count = 0
44 | sub_params = 0
45 | for name in subset:
46 | W = subset[name].weight.data
47 | count += (W==0).sum().item()
48 | total_params += W.numel()
49 |
50 | sub_count += (W==0).sum().item()
51 | sub_params += W.numel()
52 |
53 | print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}")
54 |
55 | model.config.use_cache = use_cache
56 | return float(count)/total_params
57 |
58 | def prepare_calibration_input(model, dataloader, device):
59 | use_cache = model.config.use_cache
60 | model.config.use_cache = False
61 | layers = model.model.decoder.layers
62 |
63 | if "model.embed_tokens" in model.hf_device_map:
64 | device = model.hf_device_map["model.embed_tokens"]
65 |
66 | dtype = next(iter(model.parameters())).dtype
67 | inps = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
68 | inps.requires_grad = False
69 | cache = {'i': 0, 'attention_mask': None, "position_ids": None}
70 |
71 | class Catcher(nn.Module):
72 | def __init__(self, module):
73 | super().__init__()
74 | self.module = module
75 | def forward(self, inp, **kwargs):
76 | inps[cache['i']] = inp
77 | cache['i'] += 1
78 | cache['attention_mask'] = kwargs['attention_mask']
79 | raise ValueError
80 | layers[0] = Catcher(layers[0])
81 | for batch in dataloader:
82 | try:
83 | model(batch[0].to(device))
84 | except ValueError:
85 | pass
86 | layers[0] = layers[0].module
87 |
88 | outs = torch.zeros_like(inps)
89 | attention_mask = cache['attention_mask']
90 | model.config.use_cache = use_cache
91 |
92 | return inps, outs, attention_mask
93 |
94 | def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
95 | thres_cumsum = sum_before * alpha
96 | sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1))
97 | thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1)
98 | W_mask = (W_metric <= thres)
99 | cur_sparsity = (W_mask==True).sum() / W_mask.numel()
100 | return W_mask, cur_sparsity
101 |
102 | def prune_magnitude(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
103 | layers = model.model.decoder.layers
104 |
105 | for i in range(len(layers)):
106 | layer = layers[i]
107 | subset = find_layers(layer)
108 |
109 | for name in subset:
110 | W = subset[name].weight.data
111 | W_metric = torch.abs(W)
112 | if prune_n != 0:
113 | W_mask = (torch.zeros_like(W)==1)
114 | for ii in range(W_metric.shape[1]):
115 | if ii % prune_m == 0:
116 | tmp = W_metric[:,ii:(ii+prune_m)].float()
117 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
118 | else:
119 | thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*args.sparsity_ratio)].cpu()
120 | W_mask = (W_metric<=thresh)
121 |
122 | W[W_mask] = 0
123 |
124 | def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
125 | use_cache = model.config.use_cache
126 | model.config.use_cache = False
127 |
128 | print("loading calibdation data")
129 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
130 | print("dataset loading complete")
131 | with torch.no_grad():
132 | inps, outs, attention_mask = prepare_calibration_input(model, dataloader, device)
133 |
134 | layers = model.model.decoder.layers
135 | for i in range(len(layers)):
136 | layer = layers[i]
137 | subset = find_layers(layer)
138 |
139 | if f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
140 | dev = model.hf_device_map[f"model.layers.{i}"]
141 | inps, outs, attention_mask = inps.to(dev), outs.to(dev), attention_mask.to(dev)
142 |
143 | wrapped_layers = {}
144 | for name in subset:
145 | wrapped_layers[name] = WrappedGPT(subset[name])
146 |
147 | def add_batch(name):
148 | def tmp(_, inp, out):
149 | wrapped_layers[name].add_batch(inp[0].data, out.data)
150 | return tmp
151 |
152 | handles = []
153 | for name in wrapped_layers:
154 | handles.append(subset[name].register_forward_hook(add_batch(name)))
155 | for j in range(args.nsamples):
156 | with torch.no_grad():
157 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
158 | for h in handles:
159 | h.remove()
160 |
161 | for name in subset:
162 | print(f"pruning layer {i} name {name}")
163 | W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
164 |
165 | W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False
166 | if prune_n != 0:
167 | # structured n:m sparsity
168 | for ii in range(W_metric.shape[1]):
169 | if ii % prune_m == 0:
170 | tmp = W_metric[:,ii:(ii+prune_m)].float()
171 | W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
172 | else:
173 | sort_res = torch.sort(W_metric, dim=-1, stable=True)
174 |
175 | # unstructured pruning
176 | indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)]
177 | W_mask.scatter_(1, indices, True)
178 |
179 | subset[name].weight.data[W_mask] = 0 ## set weights to zero
180 |
181 | for j in range(args.nsamples):
182 | with torch.no_grad():
183 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
184 | inps, outs = outs, inps
185 |
186 | model.config.use_cache = use_cache
187 | torch.cuda.empty_cache()
188 |
189 | @torch.no_grad()
190 | def prune_sparsegpt(args, model, tokenizer, dev, prune_n=0, prune_m=0):
191 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
192 | print('Starting ...')
193 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
194 |
195 | use_cache = model.config.use_cache
196 | model.config.use_cache = False
197 | layers = model.model.decoder.layers
198 |
199 | if "model.embed_tokens" in model.hf_device_map:
200 | dev = model.hf_device_map["model.embed_tokens"]
201 |
202 | dtype = next(iter(model.parameters())).dtype
203 | inps = torch.zeros(
204 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
205 | )
206 | cache = {'i': 0, 'attention_mask': None, "position_ids": None}
207 |
208 | class Catcher(nn.Module):
209 | def __init__(self, module):
210 | super().__init__()
211 | self.module = module
212 | def forward(self, inp, **kwargs):
213 | inps[cache['i']] = inp
214 | cache['i'] += 1
215 | cache['attention_mask'] = kwargs['attention_mask']
216 | # cache['position_ids'] = kwargs['position_ids']
217 | raise ValueError
218 | layers[0] = Catcher(layers[0])
219 | for batch in dataloader:
220 | try:
221 | model(batch[0].to(dev))
222 | except ValueError:
223 | pass
224 | layers[0] = layers[0].module
225 | torch.cuda.empty_cache()
226 |
227 | outs = torch.zeros_like(inps)
228 | attention_mask = cache['attention_mask']
229 |
230 | print('Ready.')
231 |
232 | for i in range(len(layers)):
233 | layer = layers[i]
234 | if f"model.layers.{i}" in model.hf_device_map:
235 | dev = model.hf_device_map[f"model.layers.{i}"]
236 | print(f"layer {i} device {dev}")
237 | inps, outs, attention_mask = inps.to(dev), outs.to(dev), attention_mask.to(dev)
238 |
239 | subset = find_layers(layer)
240 |
241 | gpts = {}
242 | for name in subset:
243 | gpts[name] = SparseGPT(subset[name])
244 |
245 | def add_batch(name):
246 | def tmp(_, inp, out):
247 | gpts[name].add_batch(inp[0].data, out.data)
248 | return tmp
249 |
250 | handles = []
251 | for name in gpts:
252 | handles.append(subset[name].register_forward_hook(add_batch(name)))
253 |
254 | for j in range(args.nsamples):
255 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
256 | for h in handles:
257 | h.remove()
258 |
259 | for name in gpts:
260 | print(i, name)
261 | print('Pruning ...')
262 |
263 | gpts[name].fasterprune(args.sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
264 | gpts[name].free()
265 |
266 | for j in range(args.nsamples):
267 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
268 |
269 | layers[i] = layer
270 | torch.cuda.empty_cache()
271 |
272 | inps, outs = outs, inps
273 |
274 | model.config.use_cache = use_cache
275 | torch.cuda.empty_cache()
276 |
277 | @torch.no_grad()
278 | def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0):
279 | ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
280 | print('Starting ...')
281 | dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
282 |
283 | use_cache = model.config.use_cache
284 | model.config.use_cache = False
285 | layers = model.model.decoder.layers
286 |
287 | if "model.embed_tokens" in model.hf_device_map:
288 | dev = model.hf_device_map["model.embed_tokens"]
289 |
290 | dtype = next(iter(model.parameters())).dtype
291 | inps = torch.zeros(
292 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
293 | )
294 | cache = {'i': 0, 'attention_mask': None, "position_ids": None}
295 |
296 | class Catcher(nn.Module):
297 | def __init__(self, module):
298 | super().__init__()
299 | self.module = module
300 | def forward(self, inp, **kwargs):
301 | inps[cache['i']] = inp
302 | cache['i'] += 1
303 | cache['attention_mask'] = kwargs['attention_mask']
304 | # cache['position_ids'] = kwargs['position_ids']
305 | raise ValueError
306 | layers[0] = Catcher(layers[0])
307 | for batch in dataloader:
308 | try:
309 | model(batch[0].to(dev))
310 | except ValueError:
311 | pass
312 | layers[0] = layers[0].module
313 | torch.cuda.empty_cache()
314 |
315 | outs = torch.zeros_like(inps)
316 | attention_mask = cache['attention_mask']
317 | # position_ids = cache['position_ids']
318 |
319 | print('Ready.')
320 |
321 | for i in range(len(layers)):
322 | layer = layers[i]
323 | if f"model.layers.{i}" in model.hf_device_map:
324 | dev = model.hf_device_map[f"model.layers.{i}"]
325 | print(f"layer {i} device {dev}")
326 | inps, outs, attention_mask = inps.to(dev), outs.to(dev), attention_mask.to(dev)
327 |
328 | subset = find_layers(layer)
329 |
330 | gpts = {}
331 | for name in subset:
332 | gpts[name] = AblateGPT(subset[name])
333 |
334 | def add_batch(name):
335 | def tmp(_, inp, out):
336 | gpts[name].add_batch(inp[0].data, out.data)
337 | return tmp
338 |
339 | handles = []
340 | for name in gpts:
341 | handles.append(subset[name].register_forward_hook(add_batch(name)))
342 |
343 | for j in range(args.nsamples):
344 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
345 | for h in handles:
346 | h.remove()
347 |
348 | for name in gpts:
349 | print(i, name)
350 | print('Pruning ...')
351 |
352 | if args.prune_method == "ablate_wanda_seq":
353 | prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m)
354 | elif args.prune_method == "ablate_mag_seq":
355 | prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m)
356 | elif "iter" in args.prune_method:
357 | prune_mask = None
358 |
359 | gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask,
360 | prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
361 | gpts[name].free()
362 |
363 | for j in range(args.nsamples):
364 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
365 |
366 | layers[i] = layer
367 | torch.cuda.empty_cache()
368 |
369 | inps, outs = outs, inps
370 |
371 | model.config.use_cache = use_cache
372 | torch.cuda.empty_cache()
373 |
--------------------------------------------------------------------------------
/lib/sparsegpt.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | torch.backends.cuda.matmul.allow_tf32 = False
9 | torch.backends.cudnn.allow_tf32 = False
10 |
11 | ## SparseGPT: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
12 | class SparseGPT:
13 |
14 | def __init__(self, layer):
15 | self.layer = layer
16 | self.dev = self.layer.weight.device
17 | W = layer.weight.data.clone()
18 | if isinstance(self.layer, nn.Conv2d):
19 | W = W.flatten(1)
20 | if isinstance(self.layer, transformers.Conv1D):
21 | W = W.t()
22 | self.rows = W.shape[0]
23 | self.columns = W.shape[1]
24 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
25 | self.nsamples = 0
26 |
27 | def add_batch(self, inp, out):
28 | if len(inp.shape) == 2:
29 | inp = inp.unsqueeze(0)
30 | tmp = inp.shape[0]
31 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
32 | if len(inp.shape) == 3:
33 | inp = inp.reshape((-1, inp.shape[-1]))
34 | inp = inp.t()
35 | self.H *= self.nsamples / (self.nsamples + tmp)
36 | self.nsamples += tmp
37 | inp = math.sqrt(2 / self.nsamples) * inp.float()
38 | self.H += inp.matmul(inp.t())
39 |
40 | def fasterprune(
41 | self, sparsity, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
42 | ):
43 | W = self.layer.weight.data.clone()
44 | if isinstance(self.layer, nn.Conv2d):
45 | W = W.flatten(1)
46 | if isinstance(self.layer, transformers.Conv1D):
47 | W = W.t()
48 | W = W.float()
49 |
50 | tick = time.time()
51 |
52 | H = self.H
53 | del self.H
54 | dead = torch.diag(H) == 0
55 | H[dead, dead] = 1
56 | W[:, dead] = 0
57 |
58 | Losses = torch.zeros(self.rows, device=self.dev)
59 |
60 | damp = percdamp * torch.mean(torch.diag(H))
61 | diag = torch.arange(self.columns, device=self.dev)
62 | H[diag, diag] += damp
63 | H = torch.linalg.cholesky(H)
64 | H = torch.cholesky_inverse(H)
65 | H = torch.linalg.cholesky(H, upper=True)
66 | Hinv = H
67 |
68 | mask = None
69 |
70 | for i1 in range(0, self.columns, blocksize):
71 | i2 = min(i1 + blocksize, self.columns)
72 | count = i2 - i1
73 |
74 | W1 = W[:, i1:i2].clone()
75 | Q1 = torch.zeros_like(W1)
76 | Err1 = torch.zeros_like(W1)
77 | Losses1 = torch.zeros_like(W1)
78 | Hinv1 = Hinv[i1:i2, i1:i2]
79 |
80 | if prune_n == 0:
81 | if mask is not None:
82 | mask1 = mask[:, i1:i2]
83 | else:
84 | tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
85 | thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
86 | mask1 = tmp <= thresh
87 | else:
88 | mask1 = torch.zeros_like(W1) == 1
89 |
90 | for i in range(count):
91 | w = W1[:, i]
92 | d = Hinv1[i, i]
93 |
94 | if prune_n != 0 and i % prune_m == 0:
95 | tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2
96 | mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
97 |
98 | q = w.clone()
99 | q[mask1[:, i]] = 0
100 |
101 | Q1[:, i] = q
102 | Losses1[:, i] = (w - q) ** 2 / d ** 2
103 |
104 | err1 = (w - q) / d
105 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
106 | Err1[:, i] = err1
107 |
108 | W[:, i1:i2] = Q1
109 | Losses += torch.sum(Losses1, 1) / 2
110 |
111 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
112 |
113 | torch.cuda.synchronize()
114 | if isinstance(self.layer, transformers.Conv1D):
115 | W = W.t()
116 | self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
117 |
118 | def free(self):
119 | self.H = None
120 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/lora_ft/README.md:
--------------------------------------------------------------------------------
1 | ## LoRA Fine-tuning of pruned LLMs
2 | Here we provide the code for the LoRA fine-tuning experiments in the paper. The commands for reproducing our experiments are in [script.sh](script.sh).
3 |
4 | This codebase is based on [run_clm.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling). Here we adapt this code to perform LoRA fine-tuning on the C4 training dataset. Some custom changes we make in the code include (see the sketch after this list):
5 | - [loc 1](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L374): set up LLaMA-7B for LoRA fine-tuning;
6 | - [loc 2](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L521): set up training arguments for the Trainer;
7 | - [loc 3](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L364): load the tokenizer from Vicuna, which is the same as the original LLaMA tokenizer but also fixes issues with some special tokens;
8 | - [loc 4](https://github.com/locuslab/wanda/blob/main/lora_ft/finetune_lm.py#L319): load the C4 training dataset.
9 |
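For orientation, here is a minimal sketch of the LoRA setup step ([loc 1] above), assuming the standard `peft` API. The rank, scaling, dropout, and target modules are illustrative placeholders, not necessarily the exact values used in [finetune_lm.py](finetune_lm.py).
```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Load the sparse pruned LLaMA-7B saved earlier with model.save_pretrained(PATH).
model = AutoModelForCausalLM.from_pretrained(
    "PATH_TO_PRUNED_LLAMA_7B",      # placeholder path
    torch_dtype=torch.float16,
    device_map="auto",
)

# Illustrative LoRA configuration: r, lora_alpha, lora_dropout, and target_modules
# are assumptions for this sketch, not the authoritative settings of finetune_lm.py.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)

# Wrap the frozen base model; only the injected LoRA matrices are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```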
10 | To train a LoRA adapter, run the command:
11 | ```sh
12 | CUDA_VISIBLE_DEVICES=0 python finetune_lm.py \
13 | --model_name_or_path [PATH to load sparse pruned LLaMA-7B] \
14 | --config_name "decapoda-research/llama-7b-hf" \
15 | --dataset_name c4 \
16 | --num_train_epochs 1 \
17 | --block_size 1024 \
18 | --per_device_train_batch_size 1 \
19 | --per_device_eval_batch_size 8 \
20 | --do_train \
21 | --do_eval \
22 | --max_train_samples 30000 \
23 | --max_eval_samples 128 \
24 | --learning_rate 1e-4 \
25 | --overwrite_output_dir \
26 | --output_dir [PATH to save the LoRA weights]
27 | ```
28 | We provide a quick overview of the arguments:
29 | - `--model_name_or_path`: the path/directory where the pruned LLaMA-7B model was saved with `model.save_pretrained(PATH)`;
30 | - `--block_size`: the context size; if you have an 80GB GPU, you can set it to 2048;
31 | - `--max_train_samples`: the number of training sequences; 30000 leads to roughly 12 hours of training on 1 GPU;
32 | - `--learning_rate`: the learning rate for LoRA fine-tuning.
33 |
34 | We provide the code to evaluate the LoRA adapter on the WikiText validation dataset in [evaluate_ppl.py](evaluate_ppl.py); the adapter-loading pattern it uses is sketched below. For zero-shot evaluation, additionally pass the `--eval_zero_shot` argument.
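For reference, loading the trained adapter back onto the pruned base model follows the standard `peft` pattern used in [evaluate_ppl.py](evaluate_ppl.py); the paths below are placeholders.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Pruned LLaMA-7B base model, loaded in half precision.
base = AutoModelForCausalLM.from_pretrained(
    "PATH_TO_PRUNED_LLAMA_7B",      # placeholder path
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-delta-v0", use_fast=True)

# Attach the LoRA adapter produced by finetune_lm.py and switch to eval mode.
model = PeftModel.from_pretrained(base, "PATH_TO_LORA_WEIGHTS", torch_dtype=torch.float16)
model.eval()
```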
--------------------------------------------------------------------------------
/lora_ft/evaluate_ppl.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from transformers import AutoModelForCausalLM
3 | from datasets import load_dataset
4 | import torch
5 | import torch.nn as nn
6 | from peft import PeftModel, PeftConfig
7 | from tqdm import tqdm
8 | import sys
9 | import json
10 | import time
11 | import os
12 |
13 | import fnmatch
14 |
15 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
16 | if type(module) in layers:
17 | return {name: module}
18 | res = {}
19 | for name1, child in module.named_children():
20 | res.update(find_layers(
21 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
22 | ))
23 | return res
24 |
25 | def check_sparsity(model):
26 | use_cache = model.config.use_cache
27 | model.config.use_cache = False
28 |
29 | try:
30 | layers = model.model.layers
31 |     except AttributeError:  # a PeftModel wraps the base model one level deeper
32 | layers = model.model.model.layers
33 | count = 0
34 | total_params = 0
35 | for i in range(len(layers)):
36 | layer = layers[i]
37 | subset = find_layers(layer)
38 |
39 | for name in subset:
40 | W = subset[name].weight.data
41 | cur_zeros = (W==0).sum().item()
42 | cur_total = W.numel()
43 |
44 | count += cur_zeros
45 | total_params += cur_total
46 |
47 | print(f"layer {i} name {name} {W.shape} sparsity {float(cur_zeros)/cur_total}")
48 |
49 | print(f"total number of params {total_params}")
50 | model.config.use_cache = use_cache
51 | return float(count)/total_params
52 |
53 | def evaluate_ppl(dataset_name, model, tokenizer, ctx_length):
54 | # max_length = model.seqlen
55 | model_seqlen = ctx_length
56 | max_length = ctx_length
57 | stride = ctx_length
58 |
59 | if dataset_name == "wikitext":
60 | test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
61 | encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")
62 | seq_len = encodings.input_ids.size(1)
63 | elif dataset_name == "ptb":
64 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
65 | encodings = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
66 | seq_len = encodings.input_ids.size(1)
67 | elif dataset_name == "c4":
68 | valdata = load_dataset(
69 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
70 | )
71 | encodings = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
72 | # encodings = encodings.input_ids[:, :(256 * model.seqlen)]
73 | seq_len = 256 * model_seqlen
74 |
75 | nlls = []
76 | prev_end_loc = 0
77 | for begin_loc in tqdm(range(0, seq_len, stride)):
78 | end_loc = min(begin_loc + max_length, seq_len)
79 | trg_len = end_loc - prev_end_loc # may be different from stride on last loop
80 | input_ids = encodings.input_ids[:, begin_loc:end_loc].cuda()
81 | target_ids = input_ids.clone()
82 | target_ids[:, :-trg_len] = -100
83 |
84 | with torch.no_grad():
85 | outputs = model(input_ids, labels=target_ids)
86 |
87 | neg_log_likelihood = outputs.loss
88 |
89 | nlls.append(neg_log_likelihood)
90 |
91 | prev_end_loc = end_loc
92 | if end_loc == seq_len:
93 | break
94 |
95 | ppl = torch.exp(torch.stack(nlls).mean())
96 | return ppl.item()
97 |
98 | def eval_llm(model, tokenizer, task_list=["boolq","piqa","hellaswag","winogrande","arc_challenge","arc_easy","openbookqa"], num_fewshot=0):
99 | from lm_eval import tasks, evaluator
100 | def pattern_match(patterns, source_list):
101 | task_names = set()
102 | for pattern in patterns:
103 | for matching in fnmatch.filter(source_list, pattern):
104 | task_names.add(matching)
105 | return list(task_names)
106 | task_names = pattern_match(task_list, tasks.ALL_TASKS)
107 | results = evaluator.simple_evaluate(
108 | model="hf-causal-experimental",
109 | model_args="pretrained=decapoda-research/llama-7b-hf",
110 | tasks=task_names,
111 | num_fewshot=num_fewshot,
112 | batch_size=None,
113 | # device='cuda:0',
114 | device=None,
115 | no_cache=True,
116 | limit=None,
117 | description_dict={},
118 | decontamination_ngrams_path=None,
119 | check_integrity=False,
120 | pretrained_model=model,
121 | tokenizer=tokenizer
122 | )
123 |
124 | return results
125 |
126 | def main(args):
127 | model = AutoModelForCausalLM.from_pretrained(
128 | args.model,
129 | torch_dtype=torch.float16, cache_dir=args.cache_dir, low_cpu_mem_usage=True, device_map="auto")
130 | tokenizer = AutoTokenizer.from_pretrained(
131 | "lmsys/vicuna-13b-delta-v0",
132 | cache_dir=args.cache_dir,
133 | padding_side="right",
134 | use_fast=True,
135 | )
136 |
137 | model = PeftModel.from_pretrained(model,args.lora_weights,torch_dtype=torch.float16)
138 |
139 | model.eval()
140 |
141 | ppl = evaluate_ppl("wikitext", model, tokenizer, args.ctx_length)
142 | print(f"perplexity on wikitext {ppl}")
143 |
144 | if args.eval_zero_shot:
145 |         task_list_dict = {0: ["boolq", "rte", "hellaswag", "winogrande", "arc_easy", "arc_challenge", "openbookqa"]}
146 |         for num_shot in [0]:
147 |             task_list = task_list_dict[num_shot]
148 |             results = eval_llm(model, tokenizer, task_list, num_shot)
149 |             print(f"zero-shot ({num_shot} shot) evaluation results: {results}")
150 |
151 | if __name__ == "__main__":
152 | import argparse
153 |
154 | parser = argparse.ArgumentParser()
155 |
156 | parser.add_argument(
157 | '--model', type=str
158 | )
159 | parser.add_argument(
160 | '--cache_dir', type=str, default="llm_weights"
161 | )
162 | parser.add_argument(
163 | '--lora_weights', type=str, default=None
164 | )
165 | parser.add_argument(
166 | '--ctx_length', type=int, default=2048
167 | )
168 | parser.add_argument("--eval_zero_shot", action="store_true")
169 |
170 | args = parser.parse_args()
171 | main(args)
--------------------------------------------------------------------------------
/lora_ft/script.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python finetune_lm.py \
2 | --model_name_or_path [PATH to load sparse pruned LLaMA-7B] \
3 | --config_name "decapoda-research/llama-7b-hf" \
4 | --dataset_name c4 \
5 | --num_train_epochs 1 \
6 | --block_size 1024 \
7 | --per_device_train_batch_size 1 \
8 | --per_device_eval_batch_size 8 \
9 | --do_train \
10 | --do_eval \
11 | --max_train_samples 30000 \
12 | --max_eval_samples 128 \
13 | --learning_rate 1e-4 \
14 | --overwrite_output_dir \
15 | --output_dir [PATH to save the LoRA weights]
16 |
17 | CUDA_VISIBLE_DEVICES=0 python evaluate_ppl.py \
18 | --model [PATH to load sparse pruned LLaMA-7B] \
19 | --lora_weights [PATH to load the LoRA weights]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 | from importlib.metadata import version
7 |
8 | from lib.prune import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers
9 | from lib.eval import eval_ppl, eval_zero_shot
10 |
11 | print('torch', version('torch'))
12 | print('transformers', version('transformers'))
13 | print('accelerate', version('accelerate'))
14 | print('# of gpus: ', torch.cuda.device_count())
15 |
16 | def get_llm(model_name, cache_dir="llm_weights"):
17 | model = AutoModelForCausalLM.from_pretrained(
18 | model_name,
19 | torch_dtype=torch.float16,
20 | cache_dir=cache_dir,
21 | low_cpu_mem_usage=True,
22 | device_map="auto"
23 | )
24 |
25 | model.seqlen = model.config.max_position_embeddings
26 | return model
27 |
28 | def main():
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--model', type=str, help='LLaMA model')
31 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
32 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
33 | parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
34 | parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
35 | parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
36 | "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
37 | parser.add_argument("--cache_dir", default="llm_weights", type=str )
38 | parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
39 | parser.add_argument('--save', type=str, default=None, help='Path to save results.')
40 | parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
41 |
42 | parser.add_argument("--eval_zero_shot", action="store_true")
43 | args = parser.parse_args()
44 |
45 | # Setting seeds for reproducibility
46 | np.random.seed(args.seed)
47 | torch.random.manual_seed(args.seed)
48 |
49 | # Handling n:m sparsity
50 | prune_n, prune_m = 0, 0
51 | if args.sparsity_type != "unstructured":
52 | assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
53 | prune_n, prune_m = map(int, args.sparsity_type.split(":"))
54 |
55 | model_name = args.model.split("/")[-1]
56 | print(f"loading llm model {args.model}")
57 | model = get_llm(args.model, args.cache_dir)
58 | model.eval()
59 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
60 |
61 | device = torch.device("cuda:0")
62 | if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
63 | device = model.hf_device_map["lm_head"]
64 | print("use device ", device)
65 |
66 | if args.sparsity_ratio != 0:
67 | print("pruning starts")
68 | if args.prune_method == "wanda":
69 | prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
70 | elif args.prune_method == "magnitude":
71 | prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
72 | elif args.prune_method == "sparsegpt":
73 | prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
74 | elif "ablate" in args.prune_method:
75 | prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
76 |
77 | ################################################################
78 | print("*"*30)
79 | sparsity_ratio = check_sparsity(model)
80 | print(f"sparsity sanity check {sparsity_ratio:.4f}")
81 | print("*"*30)
82 | ################################################################
83 | ppl_test = eval_ppl(args, model, tokenizer, device)
84 | print(f"wikitext perplexity {ppl_test}")
85 |
86 |     if args.save:
87 |         os.makedirs(args.save, exist_ok=True)
88 |         save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt")
89 |         with open(save_filepath, "w") as f:
90 |             print("method\tactual_sparsity\tppl_test", file=f, flush=True)
91 |             print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)
92 |
93 | if args.eval_zero_shot:
94 | accelerate=False
95 | if "30b" in args.model or "65b" in args.model or "70b" in args.model:
96 | accelerate=True
97 |
98 | task_list = ["boolq", "rte","hellaswag","winogrande", "arc_easy","arc_challenge", "openbookqa"]
99 | num_shot = 0
100 | results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate)
101 | print("********************************")
102 | print("zero_shot evaluation results")
103 | print(results)
104 |
105 | if args.save_model:
106 | model.save_pretrained(args.save_model)
107 | tokenizer.save_pretrained(args.save_model)
108 |
109 | if __name__ == '__main__':
110 | main()
--------------------------------------------------------------------------------
/main_opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 | from importlib.metadata import version
7 |
8 | from lib.prune_opt import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers
9 | from lib.eval import eval_ppl, eval_zero_shot
10 |
11 | print('torch', version('torch'))
12 | print('transformers', version('transformers'))
13 | print('accelerate', version('accelerate'))
14 | print('# of gpus: ', torch.cuda.device_count())
15 |
16 | def get_llm(model_name, cache_dir="llm_weights"):
17 | model = AutoModelForCausalLM.from_pretrained(
18 | model_name,
19 | torch_dtype=torch.float16,
20 | cache_dir=cache_dir,
21 | low_cpu_mem_usage=True,
22 | device_map="auto"
23 | )
24 |
25 | model.seqlen = model.config.max_position_embeddings
26 | return model
27 |
28 | def main():
29 | parser = argparse.ArgumentParser()
30 |     parser.add_argument('--model', type=str, help='OPT model')
31 | parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
32 | parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
33 | parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
34 | parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
35 | parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
36 | "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
37 | parser.add_argument("--cache_dir", default="llm_weights", type=str )
38 | parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
39 | parser.add_argument('--save', type=str, default=None, help='Path to save results.')
40 | parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
41 |
42 | parser.add_argument("--eval_zero_shot", action="store_true")
43 | args = parser.parse_args()
44 |
45 | # Setting seeds for reproducibility
46 | np.random.seed(args.seed)
47 | torch.random.manual_seed(args.seed)
48 |
49 | # Handling n:m sparsity
50 | prune_n, prune_m = 0, 0
51 | if args.sparsity_type != "unstructured":
52 | assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
53 | prune_n, prune_m = map(int, args.sparsity_type.split(":"))
54 |
55 | model_name = args.model.split("/")[-1]
56 | print(f"loading llm model {args.model}")
57 | model = get_llm(args.model, args.cache_dir)
58 | model.eval()
59 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
60 |
61 | device = torch.device("cuda:0")
62 | if "30b" in args.model or "66b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
63 | device = model.hf_device_map["lm_head"]
64 | print("use device ", device)
65 |
66 | if args.sparsity_ratio != 0:
67 | print("pruning starts")
68 | if args.prune_method == "wanda":
69 | prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
70 | elif args.prune_method == "magnitude":
71 | prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
72 | elif args.prune_method == "sparsegpt":
73 | prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
74 | elif "ablate" in args.prune_method:
75 | prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
76 |
77 | ################################################################
78 | print("*"*30)
79 | sparsity_ratio = check_sparsity(model)
80 | print(f"sparsity sanity check {sparsity_ratio:.4f}")
81 | print("*"*30)
82 | ################################################################
83 | ppl_test = eval_ppl(args, model, tokenizer, device)
84 | print(f"wikitext perplexity {ppl_test}")
85 |
86 |     if args.save:
87 |         os.makedirs(args.save, exist_ok=True)
88 |         save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt")
89 |         with open(save_filepath, "w") as f:
90 |             print("method\tactual_sparsity\tppl_test", file=f, flush=True)
91 |             print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)
92 |
93 | if args.eval_zero_shot:
94 | accelerate=False
95 | if "30b" in args.model or "66b" in args.model:
96 | accelerate=True
97 |
98 | task_list = ["boolq", "rte","hellaswag","winogrande", "arc_easy","arc_challenge", "openbookqa"]
99 | num_shot = 0
100 | results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate)
101 | print("********************************")
102 | print("zero_shot evaluation results")
103 | print(results)
104 |
105 | if args.save_model:
106 | model.save_pretrained(args.save_model)
107 | tokenizer.save_pretrained(args.save_model)
108 |
109 | if __name__ == '__main__':
110 | main()
--------------------------------------------------------------------------------
/scripts/ablate_weight_update.sh:
--------------------------------------------------------------------------------
1 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
2 | do
3 | CUDA_VISIBLE_DEVICES=0 python main.py \
4 | --model decapoda-research/llama-7b-hf \
5 | --nsamples 128 \
6 | --sparsity_ratio 0.5 \
7 | --sparsity_type unstructured \
8 | --prune_method ${method} \
9 | --save out/llama_7b_ablation/unstructured/
10 | done
11 |
12 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
13 | do
14 | CUDA_VISIBLE_DEVICES=0 python main.py \
15 | --model decapoda-research/llama-7b-hf \
16 | --nsamples 128 \
17 | --sparsity_ratio 0.5 \
18 | --sparsity_type 4:8 \
19 | --prune_method ${method} \
20 | --save out/llama_7b_ablation/4:8/
21 | done
22 |
23 | for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
24 | do
25 | CUDA_VISIBLE_DEVICES=0 python main.py \
26 | --model decapoda-research/llama-7b-hf \
27 | --nsamples 128 \
28 | --sparsity_ratio 0.5 \
29 | --sparsity_type 2:4 \
30 | --prune_method ${method} \
31 | --save out/llama_7b_ablation/2:4/
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/llama_13b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set common variables
4 | model="decapoda-research/llama-13b-hf"
5 | sparsity_ratio=0.5
6 | cuda_device=0
7 |
8 | # Set CUDA device visibility
9 | export CUDA_VISIBLE_DEVICES=$cuda_device
10 |
11 | # Define function to run python command
12 | run_python_command () {
13 | python main.py \
14 | --model $model \
15 | --prune_method $1 \
16 | --sparsity_ratio $sparsity_ratio \
17 | --sparsity_type $2 \
18 | --save $3
19 | }
20 |
21 | # llama-13b with wanda pruning method
22 | echo "Running with wanda pruning method"
23 | run_python_command "wanda" "unstructured" "out/llama_13b/unstructured/wanda/"
24 | run_python_command "wanda" "2:4" "out/llama_13b/2-4/wanda/"
25 | run_python_command "wanda" "4:8" "out/llama_13b/4-8/wanda/"
26 | echo "Finished wanda pruning method"
27 |
28 | # llama-13b with sparsegpt pruning method
29 | echo "Running with sparsegpt pruning method"
30 | run_python_command "sparsegpt" "unstructured" "out/llama_13b/unstructured/sparsegpt/"
31 | run_python_command "sparsegpt" "2:4" "out/llama_13b/2-4/sparsegpt/"
32 | run_python_command "sparsegpt" "4:8" "out/llama_13b/4-8/sparsegpt/"
33 | echo "Finished sparsegpt pruning method"
34 |
35 | # llama-13b with magnitude pruning method
36 | echo "Running with magnitude pruning method"
37 | run_python_command "magnitude" "unstructured" "out/llama_13b/unstructured/magnitude/"
38 | run_python_command "magnitude" "2:4" "out/llama_13b/2-4/magnitude/"
39 | run_python_command "magnitude" "4:8" "out/llama_13b/4-8/magnitude/"
40 | echo "Finished magnitude pruning method"
--------------------------------------------------------------------------------
/scripts/llama_30b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set common variables
4 | model="decapoda-research/llama-30b-hf"
5 | sparsity_ratio=0.5
6 | cuda_devices="0,1"
7 |
8 | # Set CUDA device visibility
9 | export CUDA_VISIBLE_DEVICES=$cuda_devices
10 |
11 | # Define function to run python command
12 | run_python_command () {
13 | python main.py \
14 | --model $model \
15 | --prune_method $1 \
16 | --sparsity_ratio $sparsity_ratio \
17 | --sparsity_type $2 \
18 | --save $3
19 | }
20 |
21 | # llama-30b with wanda pruning method
22 | echo "Running with wanda pruning method"
23 | run_python_command "wanda" "unstructured" "out/llama_30b/unstructured/wanda/"
24 | run_python_command "wanda" "2:4" "out/llama_30b/2-4/wanda/"
25 | run_python_command "wanda" "4:8" "out/llama_30b/4-8/wanda/"
26 | echo "Finished wanda pruning method"
27 |
28 | # llama-30b with sparsegpt pruning method
29 | echo "Running with sparsegpt pruning method"
30 | run_python_command "sparsegpt" "unstructured" "out/llama_30b/unstructured/sparsegpt/"
31 | run_python_command "sparsegpt" "2:4" "out/llama_30b/2-4/sparsegpt/"
32 | run_python_command "sparsegpt" "4:8" "out/llama_30b/4-8/sparsegpt/"
33 | echo "Finished sparsegpt pruning method"
34 |
35 | # llama-30b with magnitude pruning method
36 | echo "Running with magnitude pruning method"
37 | run_python_command "magnitude" "unstructured" "out/llama_30b/unstructured/magnitude/"
38 | run_python_command "magnitude" "2:4" "out/llama_30b/2-4/magnitude/"
39 | run_python_command "magnitude" "4:8" "out/llama_30b/4-8/magnitude/"
40 | echo "Finished magnitude pruning method"
--------------------------------------------------------------------------------
/scripts/llama_65b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set common variables
4 | model="decapoda-research/llama-65b-hf"
5 | sparsity_ratio=0.5
6 |
7 | # Define function to run python command
8 | run_python_command () {
9 | CUDA_VISIBLE_DEVICES=$1 python main.py \
10 | --model $model \
11 | --prune_method $2 \
12 | --sparsity_ratio $sparsity_ratio \
13 | --sparsity_type $3 \
14 | --save $4
15 | }
16 |
17 | # llama-65b with wanda pruning method
18 | echo "Running with wanda pruning method"
19 | run_python_command "0,1,2,3,4" "wanda" "unstructured" "out/llama_65b/unstructured/wanda/"
20 | run_python_command "0,1,2,3,4" "wanda" "2:4" "out/llama_65b/2-4/wanda/"
21 | run_python_command "0,1,2,3,4" "wanda" "4:8" "out/llama_65b/4-8/wanda/"
22 | echo "Finished wanda pruning method"
23 |
24 | # llama-65b with sparsegpt pruning method
25 | echo "Running with sparsegpt pruning method"
26 | run_python_command "0,1,2,3,4" "sparsegpt" "unstructured" "out/llama_65b/unstructured/sparsegpt/"
27 | run_python_command "0,1,2,3,4" "sparsegpt" "2:4" "out/llama_65b/2-4/sparsegpt/"
28 | run_python_command "0,1,2,3,4" "sparsegpt" "4:8" "out/llama_65b/4-8/sparsegpt/"
29 | echo "Finished sparsegpt pruning method"
30 |
31 | # llama-65b with magnitude pruning method
32 | echo "Running with magnitude pruning method"
33 | run_python_command "0,1,2,3" "magnitude" "unstructured" "out/llama_65b/unstructured/magnitude/"
34 | run_python_command "0,1,2,3" "magnitude" "2:4" "out/llama_65b/2-4/magnitude/"
35 | run_python_command "0,1,2,3" "magnitude" "4:8" "out/llama_65b/4-8/magnitude/"
36 | echo "Finished magnitude pruning method"
--------------------------------------------------------------------------------
/scripts/llama_7b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set common variables
4 | model="decapoda-research/llama-7b-hf"
5 | sparsity_ratio=0.5
6 | cuda_device=0
7 |
8 | # Set CUDA device visibility
9 | export CUDA_VISIBLE_DEVICES=$cuda_device
10 |
11 | # Define function to run python command
12 | run_python_command () {
13 | python main.py \
14 | --model $model \
15 | --prune_method $1 \
16 | --sparsity_ratio $sparsity_ratio \
17 | --sparsity_type $2 \
18 | --save $3
19 | }
20 |
21 | # llama-7b with wanda pruning method
22 | echo "Running with wanda pruning method"
23 | run_python_command "wanda" "unstructured" "out/llama_7b/unstructured/wanda/"
24 | run_python_command "wanda" "2:4" "out/llama_7b/2-4/wanda/"
25 | run_python_command "wanda" "4:8" "out/llama_7b/4-8/wanda/"
26 | echo "Finished wanda pruning method"
27 |
28 | # llama-7b with sparsegpt pruning method
29 | echo "Running with sparsegpt pruning method"
30 | run_python_command "sparsegpt" "unstructured" "out/llama_7b/unstructured/sparsegpt/"
31 | run_python_command "sparsegpt" "2:4" "out/llama_7b/2-4/sparsegpt/"
32 | run_python_command "sparsegpt" "4:8" "out/llama_7b/4-8/sparsegpt/"
33 | echo "Finished sparsegpt pruning method"
34 |
35 | # llama-7b with magnitude pruning method
36 | echo "Running with magnitude pruning method"
37 | run_python_command "magnitude" "unstructured" "out/llama_7b/unstructured/magnitude/"
38 | run_python_command "magnitude" "2:4" "out/llama_7b/2-4/magnitude/"
39 | run_python_command "magnitude" "4:8" "out/llama_7b/4-8/magnitude/"
40 | echo "Finished magnitude pruning method"
--------------------------------------------------------------------------------