├── quantization
│   ├── __init__.py
│   ├── quantizers
│   │   ├── __init__.py
│   │   ├── quantizer_utils.py
│   │   ├── base_quantizers.py
│   │   └── uniform_quantizers.py
│   ├── qstates.py
│   ├── utils.py
│   ├── quantization_manager.py
│   ├── hijacker.py
│   ├── base_quantized_model.py
│   ├── base_quantized_classes.py
│   ├── autoquant_utils.py
│   └── range_estimators.py
├── transformers_language
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── softmax.py
│   │   ├── opt_attention.py
│   │   └── bert_attention.py
│   ├── dataset_setups.py
│   ├── quant_configs.py
│   ├── utils.py
│   └── args.py
├── img
│   └── bert_attention_patterns.png
├── model_configs
│   ├── opt-1.3b.yaml
│   ├── opt-12L12H.yaml
│   ├── opt-350m.yaml
│   ├── opt-6L12H.yaml
│   └── bert-6L12H.yaml
├── docker
│   ├── requirements.txt
│   └── Dockerfile
├── .gitignore
├── accelerate_configs
│   ├── 1gpu_fp16.yaml
│   └── 1gpu_no_mp.yaml
├── scripts
│   ├── opt_1.3b_vanilla.sh
│   ├── opt_350m_vanilla.sh
│   ├── opt_350m_gated_attention.sh
│   ├── opt_1.3b_gated_attention.sh
│   ├── bert_base_vanilla.sh
│   ├── bert_base_clipped_softmax.sh
│   ├── opt_125m_clipped_softmax.sh
│   ├── bert_base_gated_attention.sh
│   ├── opt_125m_vanilla.sh
│   └── opt_125m_gated_attention.sh
├── setup.py
├── LICENSE
├── README.md
├── validate_mlm.py
├── validate_clm.py
└── run_mlm.py
/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
--------------------------------------------------------------------------------
/transformers_language/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
--------------------------------------------------------------------------------
/transformers_language/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 |
--------------------------------------------------------------------------------
/img/bert_attention_patterns.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qualcomm-AI-research/outlier-free-transformers/HEAD/img/bert_attention_patterns.png
--------------------------------------------------------------------------------
/model_configs/opt-1.3b.yaml:
--------------------------------------------------------------------------------
1 | max_position_embeddings: 1024
2 | ffn_dim: 8192
3 | hidden_size: 2048
4 | num_hidden_layers: 24
5 | num_attention_heads: 32
6 | init_std: 0.006
7 | dropout: 0.1
8 |
--------------------------------------------------------------------------------
/model_configs/opt-12L12H.yaml:
--------------------------------------------------------------------------------
1 | max_position_embeddings: 512
2 | ffn_dim: 3072
3 | hidden_size: 768
4 | num_hidden_layers: 12
5 | num_attention_heads: 12
6 | init_std: 0.006
7 | dropout: 0.1
8 |
--------------------------------------------------------------------------------
/model_configs/opt-350m.yaml:
--------------------------------------------------------------------------------
1 | max_position_embeddings: 1024
2 | ffn_dim: 4096
3 | hidden_size: 1024
4 | num_hidden_layers: 24
5 | num_attention_heads: 16
6 | init_std: 0.006
7 | dropout: 0.1
8 |
--------------------------------------------------------------------------------
/model_configs/opt-6L12H.yaml:
--------------------------------------------------------------------------------
1 | max_position_embeddings: 512
2 | ffn_dim: 3072
3 | hidden_size: 768
4 | num_hidden_layers: 6
5 | num_attention_heads: 12
6 | init_std: 0.006
7 | dropout: 0.1
8 |
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate~=0.21.0
2 | datasets~=2.14.0
3 | numpy~=1.23.5
4 | pyyaml
5 | scipy~=1.10.1
6 | tensorboard~=2.14.0
7 | timm~=0.4.12
8 | tqdm
9 | transformers~=4.31.0
10 |
--------------------------------------------------------------------------------
/model_configs/bert-6L12H.yaml:
--------------------------------------------------------------------------------
1 | max_position_embeddings: 256
2 | num_hidden_layers: 6
3 | num_attention_heads: 12
4 | hidden_size: 768
5 | intermediate_size: 3072
6 | gradient_checkpointing: False
7 |
8 |
--------------------------------------------------------------------------------
/transformers_language/dataset_setups.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from enum import auto
4 |
5 | from quantization.utils import BaseEnumOptions
6 |
7 |
8 | class DatasetSetups(BaseEnumOptions):
9 | wikitext_2 = auto()
10 | wikitext_103 = auto()
11 | bookcorpus_and_wiki = auto()
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # PyCharm
2 | .idea
3 |
4 | # MAC OS
5 | .DS_Store
6 |
7 | # pytest
8 | .coverage
9 | .pytest
10 | .pytest_cache
11 |
12 | # Python
13 | *__pycache__*
14 | *.py[cod]
15 | *.cpython-38.pyc
16 | *.egg-info
17 |
18 | # Patches (not to be committed by mistake)
19 | *.patch
20 |
21 | # Notebooks
22 | *.ipynb
23 |
24 | # Plots
25 | *.pdf
26 | *.png
27 | *.jpg
28 |
29 | # Hidden files
30 | .*
31 | !.gitignore
32 |
33 |
--------------------------------------------------------------------------------
/quantization/quantizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from quantization.quantizers.uniform_quantizers import (
4 | AsymmetricUniformQuantizer,
5 | SymmetricUniformQuantizer,
6 | )
7 | from quantization.utils import ClassEnumOptions, MethodMap
8 |
9 |
10 | class QMethods(ClassEnumOptions):
11 | symmetric_uniform = MethodMap(SymmetricUniformQuantizer)
12 | asymmetric_uniform = MethodMap(AsymmetricUniformQuantizer)
13 |
--------------------------------------------------------------------------------
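
`QMethods` above pairs each quantization scheme with its implementing class via `ClassEnumOptions`/`MethodMap` (defined in `quantization/utils.py` further down): `.cls` returns the wrapped class, and calling the member instantiates it. A minimal usage sketch; the `n_bits`/`per_channel` keyword arguments are assumptions about the quantizer constructors (`uniform_quantizers.py` is not shown in this dump):

```python
from quantization.quantizers import QMethods

print(QMethods.list_names())           # ['symmetric_uniform', 'asymmetric_uniform']
print(QMethods.symmetric_uniform.cls)  # the SymmetricUniformQuantizer class itself

# Calling the enum member forwards to the wrapped class (ClassEnumOptions.__call__).
# The kwargs below are assumed to be accepted by SymmetricUniformQuantizer.
quantizer = QMethods.symmetric_uniform(n_bits=8, per_channel=False)
```
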
/accelerate_configs/1gpu_fp16.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | compute_environment: LOCAL_MACHINE
4 | deepspeed_config: {}
5 | distributed_type: 'NO'
6 | downcast_bf16: 'no'
7 | fsdp_config: {}
8 | gpu_ids: all
9 | machine_rank: 0
10 | main_process_ip: null
11 | main_process_port: null
12 | main_training_function: main
13 | megatron_lm_config: {}
14 | mixed_precision: fp16
15 | num_machines: 1
16 | num_processes: 1
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_name: null
20 | tpu_zone: null
21 | use_cpu: false
22 |
--------------------------------------------------------------------------------
/accelerate_configs/1gpu_no_mp.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | compute_environment: LOCAL_MACHINE
4 | deepspeed_config: {}
5 | distributed_type: 'NO'
6 | downcast_bf16: 'no'
7 | fsdp_config: {}
8 | gpu_ids: all
9 | machine_rank: 0
10 | main_process_ip: null
11 | main_process_port: null
12 | main_training_function: main
13 | megatron_lm_config: {}
14 | mixed_precision: 'no'
15 | num_machines: 1
16 | num_processes: 1
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_name: null
20 | tpu_zone: null
21 | use_cpu: false
22 |
--------------------------------------------------------------------------------
/quantization/qstates.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from enum import auto
4 |
5 | from quantization.utils import BaseEnumOptions
6 |
7 |
8 | class Qstates(BaseEnumOptions):
9 | estimate_ranges = auto() # ranges are updated in eval and train mode
10 | fix_ranges = auto() # quantization ranges are fixed for train and eval
11 | learn_ranges = auto() # quantization params are nn.Parameters
12 | estimate_ranges_train = (
13 | auto()
14 | ) # quantization ranges are updated during train and fixed for eval
15 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/cuda:11.3.1-base-ubuntu20.04
2 |
3 | WORKDIR /app
4 | ENV LC_ALL=C.UTF-8
5 | ENV LANG=C.UTF-8
6 |
7 | RUN apt update && \
8 | DEBIAN_FRONTEND=noninteractive apt install --yes --no-install-recommends \
9 | git \
10 | less \
11 | python-is-python3 \
12 | python3 \
13 | python3-pip \
14 | tree \
15 | vim \
16 | && \
17 | rm -rf /var/lib/apt/lists/*
18 |
19 | COPY docker/requirements.txt requirements.txt
20 |
21 | RUN python3 -m pip install --no-cache-dir --upgrade pip && \
22 | python3 -m pip install --no-cache-dir --upgrade --requirement requirements.txt
23 |
24 | COPY . /app
25 | RUN python3 -m pip install --no-cache-dir /app
26 |
--------------------------------------------------------------------------------
/scripts/opt_1.3b_vanilla.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \
3 | --pad_to_max_length \
4 | --with_tracking \
5 | --report_to tensorboard \
6 | --seed 1000 \
7 | --dataset_setup bookcorpus_and_wiki \
8 | --preprocessing_num_workers 4 \
9 | --data_cache_dir ~/.hf_data \
10 | --model_cache_dir ~/.hf_cache \
11 | --model_type opt \
12 | --tokenizer_name facebook/opt-350m \
13 | --max_seq_length 2048 \
14 | --block_size 512 \
15 | --learning_rate 0.0004 \
16 | --lr_scheduler_type linear \
17 | --max_train_steps 100000 \
18 | --num_warmup_steps 2000 \
19 | --per_device_train_batch_size 14 \
20 | --per_device_eval_batch_size 14 \
21 | --gradient_accumulation_steps 18 \
22 | --max_grad_norm 1.0 \
23 | --weight_decay 0.1 \
24 | --checkpointing_steps 1000 \
25 | --config_path model_configs/opt-1.3b.yaml \
26 | --attn_softmax vanilla \
27 | --output_dir output
28 |
--------------------------------------------------------------------------------
/scripts/opt_350m_vanilla.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \
3 | --pad_to_max_length \
4 | --with_tracking \
5 | --report_to tensorboard \
6 | --seed 1000 \
7 | --dataset_setup bookcorpus_and_wiki \
8 | --preprocessing_num_workers 4 \
9 | --data_cache_dir ~/.hf_data \
10 | --model_cache_dir ~/.hf_cache \
11 | --model_type opt \
12 | --tokenizer_name facebook/opt-350m \
13 | --max_seq_length 2048 \
14 | --block_size 512 \
15 | --learning_rate 0.0004 \
16 | --lr_scheduler_type linear \
17 | --max_train_steps 100000 \
18 | --num_warmup_steps 2000 \
19 | --per_device_train_batch_size 24 \
20 | --per_device_eval_batch_size 24 \
21 | --gradient_accumulation_steps 11 \
22 | --max_grad_norm 1.0 \
23 | --weight_decay 0.1 \
24 | --checkpointing_steps 1000 \
25 | --config_path model_configs/opt-350m.yaml \
26 | --attn_softmax vanilla \
27 | --output_dir output
28 |
--------------------------------------------------------------------------------
/scripts/opt_350m_gated_attention.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \
3 | --pad_to_max_length \
4 | --with_tracking \
5 | --report_to tensorboard \
6 | --seed 1000 \
7 | --dataset_setup bookcorpus_and_wiki \
8 | --preprocessing_num_workers 4 \
9 | --data_cache_dir ~/.hf_data \
10 | --model_cache_dir ~/.hf_cache \
11 | --model_type opt \
12 | --tokenizer_name facebook/opt-350m \
13 | --max_seq_length 2048 \
14 | --block_size 512 \
15 | --learning_rate 0.0004 \
16 | --lr_scheduler_type linear \
17 | --max_train_steps 100000 \
18 | --num_warmup_steps 2000 \
19 | --per_device_train_batch_size 24 \
20 | --per_device_eval_batch_size 24 \
21 | --gradient_accumulation_steps 11 \
22 | --max_grad_norm 1.0 \
23 | --weight_decay 0.1 \
24 | --checkpointing_steps 1000 \
25 | --config_path model_configs/opt-350m.yaml \
26 | --attn_gate_type conditional_per_token \
27 | --attn_gate_init 0.25 \
28 | --output_dir output
29 |
--------------------------------------------------------------------------------
/scripts/opt_1.3b_gated_attention.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \
3 | --pad_to_max_length \
4 | --with_tracking \
5 | --report_to tensorboard \
6 | --seed 1000 \
7 | --dataset_setup bookcorpus_and_wiki \
8 | --preprocessing_num_workers 4 \
9 | --data_cache_dir ~/.hf_data \
10 | --model_cache_dir ~/.hf_cache \
11 | --model_type opt \
12 | --tokenizer_name facebook/opt-350m \
13 | --max_seq_length 2048 \
14 | --block_size 512 \
15 | --learning_rate 0.0004 \
16 | --lr_scheduler_type linear \
17 | --max_train_steps 100000 \
18 | --num_warmup_steps 2000 \
19 | --per_device_train_batch_size 14 \
20 | --per_device_eval_batch_size 14 \
21 | --gradient_accumulation_steps 18 \
22 | --max_grad_norm 1.0 \
23 | --weight_decay 0.1 \
24 | --checkpointing_steps 1000 \
25 | --config_path model_configs/opt-1.3b.yaml \
26 | --attn_gate_type conditional_per_token \
27 | --attn_gate_linear_all_features \
28 | --output_dir output
29 |
--------------------------------------------------------------------------------
/scripts/bert_base_vanilla.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_mlm.py \
3 | --with_tracking \
4 | --report_to tensorboard \
5 | --extra_tb_stats \
6 | --seed 1000 \
7 | --dataset_setup bookcorpus_and_wiki \
8 | --preprocessing_num_workers 4 \
9 | --data_cache_dir ~/.hf_data \
10 | --model_cache_dir ~/.hf_cache \
11 | --model_type bert \
12 | --tokenizer_name bert-base-uncased \
13 | --max_seq_length 128 \
14 | --mlm_probability 0.15 \
15 | --learning_rate 0.0001 \
16 | --lr_scheduler_type linear \
17 | --max_train_steps 1000000 \
18 | --num_warmup_steps 10000 \
19 | --per_device_train_batch_size 256 \
20 | --per_device_eval_batch_size 256 \
21 | --gradient_accumulation_steps 1 \
22 | --max_grad_norm 1.0 \
23 | --weight_decay 0.01 \
24 | --config_name bert-base-uncased \
25 | --checkpointing_steps 100000 \
26 | --tb_scalar_log_interval 2000 \
27 | --tb_hist_log_interval 100000 \
28 | --attn_softmax vanilla \
29 | --output_dir output
30 |
--------------------------------------------------------------------------------
/scripts/bert_base_clipped_softmax.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_mlm.py \
3 | --with_tracking \
4 | --report_to tensorboard \
5 | --extra_tb_stats \
6 | --seed 1000 \
7 | --dataset_setup bookcorpus_and_wiki \
8 | --preprocessing_num_workers 4 \
9 | --data_cache_dir ~/.hf_data \
10 | --model_cache_dir ~/.hf_cache \
11 | --model_type bert \
12 | --tokenizer_name bert-base-uncased \
13 | --max_seq_length 128 \
14 | --mlm_probability 0.15 \
15 | --learning_rate 0.0001 \
16 | --lr_scheduler_type linear \
17 | --max_train_steps 1000000 \
18 | --num_warmup_steps 10000 \
19 | --per_device_train_batch_size 256 \
20 | --per_device_eval_batch_size 256 \
21 | --gradient_accumulation_steps 1 \
22 | --max_grad_norm 1.0 \
23 | --weight_decay 0.01 \
24 | --config_name bert-base-uncased \
25 | --checkpointing_steps 100000 \
26 | --tb_scalar_log_interval 2000 \
27 | --tb_hist_log_interval 100000 \
28 | --attn_softmax "clipped(-.025:1)" \
29 | --output_dir output
30 |
--------------------------------------------------------------------------------
/scripts/opt_125m_clipped_softmax.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \
3 | --pad_to_max_length \
4 | --wd_LN_gamma \
5 | --with_tracking \
6 | --report_to tensorboard \
7 | --extra_tb_stats \
8 | --seed 1000 \
9 | --dataset_setup bookcorpus_and_wiki \
10 | --preprocessing_num_workers 4 \
11 | --data_cache_dir ~/.hf_data \
12 | --model_cache_dir ~/.hf_cache \
13 | --model_type opt \
14 | --tokenizer_name facebook/opt-350m \
15 | --max_seq_length 2048 \
16 | --block_size 512 \
17 | --learning_rate 0.0004 \
18 | --lr_scheduler_type linear \
19 | --max_train_steps 125000 \
20 | --num_warmup_steps 2000 \
21 | --per_device_train_batch_size 48 \
22 | --per_device_eval_batch_size 48 \
23 | --gradient_accumulation_steps 4 \
24 | --max_grad_norm 1.0 \
25 | --weight_decay 0.1 \
26 | --checkpointing_steps 10000 \
27 | --tb_scalar_log_interval 2000 \
28 | --tb_hist_log_interval 10000 \
29 | --config_path model_configs/opt-12L12H.yaml \
30 | --alpha 12 \
31 | --output_dir output
32 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from pathlib import Path
3 |
4 | scripts = ["run_mlm", "run_clm", "validate_mlm", "validate_clm"]
5 | bash_scripts = Path("scripts").glob("*.sh")
6 |
7 | setup(
8 | name="outlier_free_transformers",
9 | version="1.0.0",
10 | packages=[
11 | "quantization",
12 | "quantization.quantizers",
13 | "transformers_language",
14 | "transformers_language.models",
15 | ],
16 | py_modules=scripts,
17 | scripts=[str(path) for path in bash_scripts],
18 | entry_points={"console_scripts": [f"{script} = {script}:main" for script in scripts]},
19 | url="https://github.com/Qualcomm-AI-research/outlier-free-transformers",
20 | license="BSD 3-Clause Clear License",
21 | author="Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort",
22 | author_email="{ybond, markusn, tijmen}@qti.qualcomm.com",
23 | description='Code for "Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing"',
24 | )
25 |
--------------------------------------------------------------------------------
/scripts/bert_base_gated_attention.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_mlm.py \
3 | --with_tracking \
4 | --report_to tensorboard \
5 | --extra_tb_stats \
6 | --seed 1000 \
7 | --dataset_setup bookcorpus_and_wiki \
8 | --preprocessing_num_workers 4 \
9 | --data_cache_dir ~/.hf_data \
10 | --model_cache_dir ~/.hf_cache \
11 | --model_type bert \
12 | --tokenizer_name bert-base-uncased \
13 | --max_seq_length 128 \
14 | --mlm_probability 0.15 \
15 | --learning_rate 0.0001 \
16 | --lr_scheduler_type linear \
17 | --max_train_steps 1000000 \
18 | --num_warmup_steps 10000 \
19 | --per_device_train_batch_size 256 \
20 | --per_device_eval_batch_size 256 \
21 | --gradient_accumulation_steps 1 \
22 | --max_grad_norm 1.0 \
23 | --weight_decay 0.01 \
24 | --config_name bert-base-uncased \
25 | --checkpointing_steps 100000 \
26 | --tb_scalar_log_interval 2000 \
27 | --tb_hist_log_interval 100000 \
28 | --attn_gate_type conditional_per_token \
29 | --attn_gate_mlp \
30 | --output_dir output
31 |
--------------------------------------------------------------------------------
/scripts/opt_125m_vanilla.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \
3 | --pad_to_max_length \
4 | --wd_LN_gamma \
5 | --with_tracking \
6 | --report_to tensorboard \
7 | --extra_tb_stats \
8 | --seed 1000 \
9 | --dataset_setup bookcorpus_and_wiki \
10 | --preprocessing_num_workers 4 \
11 | --data_cache_dir ~/.hf_data \
12 | --model_cache_dir ~/.hf_cache \
13 | --model_type opt \
14 | --tokenizer_name facebook/opt-350m \
15 | --max_seq_length 2048 \
16 | --block_size 512 \
17 | --learning_rate 0.0004 \
18 | --lr_scheduler_type linear \
19 | --max_train_steps 125000 \
20 | --num_warmup_steps 2000 \
21 | --per_device_train_batch_size 48 \
22 | --per_device_eval_batch_size 48 \
23 | --gradient_accumulation_steps 4 \
24 | --max_grad_norm 1.0 \
25 | --weight_decay 0.1 \
26 | --checkpointing_steps 10000 \
27 | --tb_scalar_log_interval 2000 \
28 | --tb_hist_log_interval 10000 \
29 | --config_path model_configs/opt-12L12H.yaml \
30 | --attn_softmax vanilla \
31 | --output_dir output
32 |
--------------------------------------------------------------------------------
/quantization/quantizers/quantizer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import torch
4 | from torch.autograd import Function
5 |
6 |
7 | class RoundStraightThrough(Function):
8 | @staticmethod
9 | def forward(ctx, x):
10 | return torch.round(x)
11 |
12 | @staticmethod
13 | def backward(ctx, output_grad):
14 | return output_grad
15 |
16 |
17 | class ScaleGradient(Function):
18 | @staticmethod
19 | def forward(ctx, x, scale):
20 | ctx.scale = scale
21 | return x
22 |
23 | @staticmethod
24 | def backward(ctx, output_grad):
25 | return output_grad * ctx.scale, None
26 |
27 |
28 | round_ste_func = RoundStraightThrough.apply
29 | scale_grad_func = ScaleGradient.apply
30 |
31 |
32 | class QuantizerNotInitializedError(Exception):
33 |     """Raised when a quantizer has not been initialized."""
34 |
35 | def __init__(self):
36 | super(QuantizerNotInitializedError, self).__init__(
37 | "Quantizer has not been initialized yet"
38 | )
39 |
--------------------------------------------------------------------------------
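
`RoundStraightThrough` above implements the standard straight-through estimator: `torch.round` in the forward pass, identity gradient in the backward pass. A quick self-contained check using only the file above:

```python
import torch

from quantization.quantizers.quantizer_utils import round_ste_func

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = round_ste_func(x)  # forward: tensor([ 0.,  2., -1.])
y.sum().backward()
print(x.grad)          # tensor([1., 1., 1.]) -- the gradient passes straight through,
                       # whereas torch.round alone has zero gradient almost everywhere
```
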
/scripts/opt_125m_gated_attention.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \
3 | --pad_to_max_length \
4 | --wd_LN_gamma \
5 | --with_tracking \
6 | --report_to tensorboard \
7 | --extra_tb_stats \
8 | --seed 1000 \
9 | --dataset_setup bookcorpus_and_wiki \
10 | --preprocessing_num_workers 4 \
11 | --data_cache_dir ~/.hf_data \
12 | --model_cache_dir ~/.hf_cache \
13 | --model_type opt \
14 | --tokenizer_name facebook/opt-350m \
15 | --max_seq_length 2048 \
16 | --block_size 512 \
17 | --learning_rate 0.0004 \
18 | --lr_scheduler_type linear \
19 | --max_train_steps 125000 \
20 | --num_warmup_steps 2000 \
21 | --per_device_train_batch_size 48 \
22 | --per_device_eval_batch_size 48 \
23 | --gradient_accumulation_steps 4 \
24 | --max_grad_norm 1.0 \
25 | --weight_decay 0.1 \
26 | --checkpointing_steps 10000 \
27 | --tb_scalar_log_interval 2000 \
28 | --tb_hist_log_interval 10000 \
29 | --config_path model_configs/opt-12L12H.yaml \
30 | --attn_gate_type conditional_per_token \
31 | --attn_gate_init 0.25 \
32 | --output_dir output
33 |
--------------------------------------------------------------------------------
/transformers_language/quant_configs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from quantization.quantizers import QMethods
4 | from quantization.range_estimators import RangeEstimators
5 | from transformers_language.utils import DotDict
6 |
7 |
8 | def get_quant_config():
9 | config = DotDict()
10 | config.act_quant = DotDict(
11 | {
12 | "cross_entropy_layer": None,
13 | "num_batches": 16,
14 | "options": {},
15 | "quant_method": RangeEstimators.running_minmax,
16 | "std_dev": None,
17 | }
18 | )
19 | config.quant = DotDict(
20 | {
21 | "act_quant": True,
22 | "n_bits": 8,
23 | "n_bits_act": 8,
24 | "num_candidates": None,
25 | "per_channel": False,
26 | "percentile": None,
27 | "quant_setup": "all",
28 | "qmethod": QMethods.symmetric_uniform,
29 | "qmethod_act": QMethods.asymmetric_uniform,
30 | "weight_quant": True,
31 | "weight_quant_method": RangeEstimators.current_minmax,
32 | }
33 | )
34 | return config
35 |
--------------------------------------------------------------------------------
/quantization/quantizers/base_quantizers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from torch import nn
4 |
5 |
6 | class QuantizerBase(nn.Module):
7 | def __init__(self, n_bits, *args, per_channel=False, act_quant=False, **kwargs):
8 | super().__init__(*args, **kwargs)
9 | self.n_bits = n_bits
10 | self.act_quant = act_quant
11 | self.per_channel = per_channel
12 | self.state = None
13 | self.x_min_fp32 = self.x_max_fp32 = None
14 |
15 | @property
16 | def is_initialized(self):
17 | raise NotImplementedError()
18 |
19 | @property
20 | def x_max(self):
21 | raise NotImplementedError()
22 |
23 | @property
24 | def symmetric(self):
25 | raise NotImplementedError()
26 |
27 | @property
28 | def x_min(self):
29 | raise NotImplementedError()
30 |
31 | def forward(self, x_float):
32 | raise NotImplementedError()
33 |
34 | def _adjust_params_per_channel(self, x):
35 | raise NotImplementedError()
36 |
37 | def set_quant_range(self, x_min, x_max):
38 | raise NotImplementedError()
39 |
40 | def extra_repr(self):
41 |         return "n_bits={}, per_channel={}, is_initialized={}".format(
42 | self.n_bits, self.per_channel, self.is_initialized
43 | )
44 |
45 | def reset(self):
46 | self._delta = None
47 |
48 | def fix_ranges(self):
49 | raise NotImplementedError()
50 |
51 | def make_range_trainable(self):
52 | raise NotImplementedError()
53 |
--------------------------------------------------------------------------------
/quantization/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from collections import namedtuple
4 | from enum import Flag, auto
5 | from functools import partial
6 |
7 | import numpy as np
8 |
9 |
10 | def to_numpy(tensor):
11 | """
12 | Helper function that turns the given tensor into a numpy array
13 |
14 | Parameters
15 | ----------
16 | tensor : torch.Tensor
17 |
18 | Returns
19 | -------
20 | tensor : float or np.array
21 |
22 | """
23 | if isinstance(tensor, np.ndarray):
24 | return tensor
25 | if hasattr(tensor, "is_cuda"):
26 | if tensor.is_cuda:
27 | return tensor.cpu().detach().numpy()
28 | if hasattr(tensor, "detach"):
29 | return tensor.detach().numpy()
30 | if hasattr(tensor, "numpy"):
31 | return tensor.numpy()
32 |
33 | return np.array(tensor)
34 |
35 |
36 | class BaseEnumOptions(Flag):
37 | def __str__(self):
38 | return self.name
39 |
40 | @classmethod
41 | def list_names(cls):
42 | return [m.name for m in cls]
43 |
44 |
45 | class ClassEnumOptions(BaseEnumOptions):
46 | @property
47 | def cls(self):
48 | return self.value.cls
49 |
50 | def __call__(self, *args, **kwargs):
51 | return self.value.cls(*args, **kwargs)
52 |
53 |
54 | class StopForwardException(Exception):
55 | """Used to throw and catch an exception to stop traversing the graph."""
56 |
57 | pass
58 |
59 |
60 | MethodMap = partial(namedtuple("MethodMap", ["value", "cls"]), auto())
61 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023 Qualcomm Technologies, Inc.
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without modification, are permitted
6 | (subject to the limitations in the disclaimer below) provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this list of conditions
9 | and the following disclaimer:
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions
12 | and the following disclaimer in the documentation and/or other materials provided with the
13 | distribution.
14 |
15 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to
16 | endorse or promote products derived from this software without specific prior written permission.
17 |
18 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS
19 | SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
20 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
26 | THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 |
--------------------------------------------------------------------------------
/transformers_language/models/softmax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from functools import partial
4 |
5 | import torch
6 |
7 |
8 | def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
9 | sm_out = torch.nn.functional.softmax(data, dim=dim, **kw)
10 | stretched_out = sm_out * (eta - gamma) + gamma
11 | return torch.clip(stretched_out, 0, 1)
12 |
13 |
14 | SOFTMAX_MAPPING = {
15 | "vanilla": torch.nn.functional.softmax,
16 | # Clipped softmax
17 | "clipped(0:1.0003)": partial(clipped_softmax, gamma=0, eta=1.0003),
18 | "clipped(0:1.001)": partial(clipped_softmax, gamma=0, eta=1.001),
19 | "clipped(0:1.002)": partial(clipped_softmax, gamma=0, eta=1.002),
20 | "clipped(0:1.003)": partial(clipped_softmax, gamma=0, eta=1.003),
21 | "clipped(0:1.004)": partial(clipped_softmax, gamma=0, eta=1.004),
22 | "clipped(0:1.01)": partial(clipped_softmax, gamma=0, eta=1.01),
23 | "clipped(0:1.02)": partial(clipped_softmax, gamma=0, eta=1.02),
24 | "clipped(0:1.03)": partial(clipped_softmax, gamma=0, eta=1.03),
25 | "clipped(0:1.1)": partial(clipped_softmax, gamma=0, eta=1.1),
26 | "clipped(-.1:1)": partial(clipped_softmax, gamma=-0.1, eta=1.0),
27 | "clipped(-.00001:1)": partial(clipped_softmax, gamma=-0.00001, eta=1.0),
28 | "clipped(-.00003:1)": partial(clipped_softmax, gamma=-0.00003, eta=1.0),
29 | "clipped(-.0001:1)": partial(clipped_softmax, gamma=-0.0001, eta=1.0),
30 | "clipped(-.0003:1)": partial(clipped_softmax, gamma=-0.0003, eta=1.0),
31 | "clipped(-.0005:1)": partial(clipped_softmax, gamma=-0.0005, eta=1.0),
32 | "clipped(-.001:1)": partial(clipped_softmax, gamma=-0.001, eta=1.0),
33 | "clipped(-.002:1)": partial(clipped_softmax, gamma=-0.002, eta=1.0),
34 | "clipped(-.0025:1)": partial(clipped_softmax, gamma=-0.0025, eta=1.0),
35 | "clipped(-.003:1)": partial(clipped_softmax, gamma=-0.003, eta=1.0),
36 | "clipped(-.004:1)": partial(clipped_softmax, gamma=-0.004, eta=1.0),
37 | "clipped(-.005:1)": partial(clipped_softmax, gamma=-0.005, eta=1.0),
38 | "clipped(-.01:1)": partial(clipped_softmax, gamma=-0.01, eta=1.0),
39 | "clipped(-.015:1)": partial(clipped_softmax, gamma=-0.015, eta=1.0),
40 | "clipped(-.02:1)": partial(clipped_softmax, gamma=-0.02, eta=1.0),
41 | "clipped(-.025:1)": partial(clipped_softmax, gamma=-0.025, eta=1.0),
42 | "clipped(-.03:1)": partial(clipped_softmax, gamma=-0.03, eta=1.0),
43 | "clipped(-.04:1)": partial(clipped_softmax, gamma=-0.04, eta=1.0),
44 | "clipped(-.001:1.001)": partial(clipped_softmax, gamma=-0.001, eta=1.001),
45 | "clipped(-.002:1.002)": partial(clipped_softmax, gamma=-0.002, eta=1.002),
46 | "clipped(-.003:1.003)": partial(clipped_softmax, gamma=-0.003, eta=1.003),
47 |     "clipped(-.005:1.005)": partial(clipped_softmax, gamma=-0.005, eta=1.005),
48 | "clipped(-.01:1.01)": partial(clipped_softmax, gamma=-0.01, eta=1.01),
49 | "clipped(-.03:1.03)": partial(clipped_softmax, gamma=-0.03, eta=1.03),
50 | "clipped(-.1:1.1)": partial(clipped_softmax, gamma=-0.1, eta=1.1),
51 | }
52 |
--------------------------------------------------------------------------------
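
For intuition: `clipped_softmax` stretches the softmax output by `eta - gamma`, shifts it by `gamma`, and clips back to `[0, 1]`, so exact zeros (and ones) become reachable without pushing the logits to extreme values. A minimal sketch using the mapping above (assumes only PyTorch and this package are installed):

```python
import torch

from transformers_language.models.softmax import SOFTMAX_MAPPING

attn_logits = torch.tensor([[10.0, -10.0, -10.0]])

vanilla = SOFTMAX_MAPPING["vanilla"](attn_logits, dim=-1)
clipped = SOFTMAX_MAPPING["clipped(-.025:1)"](attn_logits, dim=-1)  # gamma=-0.025, eta=1.0

print(vanilla)  # close to [1, 0, 0] but the small entries are never exactly zero
print(clipped)  # tensor([[1., 0., 0.]]) -- exact zeros, i.e. an attention "no-op"
```
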
/transformers_language/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import torch
4 | import torch.nn as nn
5 |
6 | from quantization.range_estimators import RangeEstimators
7 | from quantization.utils import StopForwardException
8 |
9 |
10 | def kurtosis(x, eps=1e-6):
11 | """x - (B, d)"""
12 | mu = x.mean(dim=1, keepdims=True)
13 | s = x.std(dim=1)
14 | mu4 = ((x - mu) ** 4.0).mean(dim=1)
15 | k = mu4 / (s**4.0 + eps)
16 | return k
17 |
18 |
19 | def count_params(module):
20 | return len(nn.utils.parameters_to_vector(module.parameters()))
21 |
22 |
23 | def val_qparams(config):
24 | weight_range_options = {}
25 | if config.quant.weight_quant_method == RangeEstimators.MSE:
26 | weight_range_options = dict(opt_method=config.quant.weight_opt_method)
27 | if config.quant.num_candidates is not None:
28 | weight_range_options["num_candidates"] = config.quant.num_candidates
29 |
30 | params = {
31 | "method": config.quant.qmethod.cls,
32 | "n_bits": config.quant.n_bits,
33 | "n_bits_act": config.quant.n_bits_act,
34 | "act_method": config.quant.qmethod_act.cls,
35 | "per_channel_weights": config.quant.per_channel,
36 | "percentile": config.quant.percentile,
37 | "quant_setup": config.quant.quant_setup,
38 | "weight_range_method": config.quant.weight_quant_method.cls,
39 | "weight_range_options": weight_range_options,
40 | "act_range_method": config.act_quant.quant_method.cls,
41 | "act_range_options": config.act_quant.options,
42 | }
43 | return params
44 |
45 |
46 | def pass_data_for_range_estimation(loader, model, act_quant=None, max_num_batches=20, inp_idx=0):
47 | model.eval()
48 | batches = []
49 | device = next(model.parameters()).device
50 | with torch.no_grad():
51 | for i, data in enumerate(loader):
52 | try:
53 | if isinstance(data, (tuple, list)):
54 | x = data[inp_idx].to(device=device)
55 | batches.append(x.data.cpu().numpy())
56 | model(x)
57 |                     print(f"processed step={i}")
58 | else:
59 | x = {k: v.to(device=device) for k, v in data.items()}
60 | model(**x)
61 |                     print(f"processed step={i}")
62 |
63 | if i >= max_num_batches - 1 or not act_quant:
64 | break
65 | except StopForwardException:
66 | pass
67 | return batches
68 |
69 |
70 | class DotDict(dict):
71 | """
72 | This class enables access to its attributes as both ['attr'] and .attr .
73 | Its advantage is that content of its `instance` can be accessed with `.`
74 | and still passed to functions as `**instance` (as dictionaries) for
75 | implementing variable-length arguments.
76 | """
77 |
78 | def __setattr__(self, key, value):
79 | self.__setitem__(key, value)
80 |
81 | def __delattr__(self, key):
82 | self.__delitem__(key)
83 |
84 | def __getattr__(self, key):
85 | if key in self:
86 | return self.__getitem__(key)
87 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})")
88 |
--------------------------------------------------------------------------------
/quantization/quantization_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from torch import nn
4 |
5 | from quantization.qstates import Qstates
6 | from quantization.quantizers import QMethods
7 | from quantization.quantizers.base_quantizers import QuantizerBase
8 | from quantization.quantizers.quantizer_utils import QuantizerNotInitializedError
9 | from quantization.range_estimators import RangeEstimatorBase, RangeEstimators
10 |
11 |
12 | class QuantizationManager(nn.Module):
13 | """Implementation of Quantization and Quantization Range Estimation
14 |
15 | Parameters
16 | ----------
17 | n_bits: int
18 | Number of bits for the quantization.
19 | qmethod: QMethods member (Enum)
20 | The quantization scheme to use, e.g. symmetric_uniform, asymmetric_uniform,
21 | qmn_uniform etc.
22 | init: RangeEstimators member (Enum)
23 |         Range estimation method used to initialize the quantization grid.
24 | per_channel: bool
25 | If true, will use a separate quantization grid for each kernel/channel.
26 | x_min: float or PyTorch Tensor
27 | The minimum value which needs to be represented.
28 | x_max: float or PyTorch Tensor
29 | The maximum value which needs to be represented.
30 | """
31 |
32 | def __init__(
33 | self,
34 | qmethod: QuantizerBase = QMethods.symmetric_uniform.cls,
35 | init: RangeEstimatorBase = RangeEstimators.current_minmax.cls,
36 | per_channel=False,
37 | x_min=None,
38 | x_max=None,
39 | qparams=None,
40 | init_params=None,
41 | ):
42 | super().__init__()
43 | self.state = Qstates.estimate_ranges
44 | self.qmethod = qmethod
45 | self.init = init
46 | self.per_channel = per_channel
47 | self.qparams = qparams if qparams else {}
48 | self.init_params = init_params if init_params else {}
49 | self.range_estimator = None
50 |
51 | # define quantizer
52 |         self.quantizer = self.qmethod(per_channel=self.per_channel, **self.qparams)
53 | self.quantizer.state = self.state
54 |
55 | # define range estimation method for quantizer initialisation
56 | if x_min is not None and x_max is not None:
57 | self.set_quant_range(x_min, x_max)
58 | self.fix_ranges()
59 | else:
60 | # set up the collector function to set the ranges
61 | self.range_estimator = self.init(
62 | per_channel=self.per_channel, quantizer=self.quantizer, **self.init_params
63 | )
64 |
65 | @property
66 | def n_bits(self):
67 | return self.quantizer.n_bits
68 |
69 | def estimate_ranges(self):
70 | self.state = Qstates.estimate_ranges
71 | self.quantizer.state = self.state
72 |
73 | def fix_ranges(self):
74 | if self.quantizer.is_initialized:
75 | self.state = Qstates.fix_ranges
76 | self.quantizer.state = self.state
77 | self.quantizer.fix_ranges()
78 | else:
79 | raise QuantizerNotInitializedError()
80 |
81 | def learn_ranges(self):
82 | self.quantizer.make_range_trainable()
83 | self.state = Qstates.learn_ranges
84 | self.quantizer.state = self.state
85 |
86 | def estimate_ranges_train(self):
87 | self.state = Qstates.estimate_ranges_train
88 | self.quantizer.state = self.state
89 |
90 | def reset_ranges(self):
91 | self.range_estimator.reset()
92 | self.quantizer.reset()
93 | self.estimate_ranges()
94 |
95 | def forward(self, x):
96 | if self.state == Qstates.estimate_ranges or (
97 | self.state == Qstates.estimate_ranges_train and self.training
98 | ):
99 | # Note this can be per tensor or per channel
100 | cur_xmin, cur_xmax = self.range_estimator(x)
101 | self.set_quant_range(cur_xmin, cur_xmax)
102 |
103 | return self.quantizer(x)
104 |
105 | def set_quant_range(self, x_min, x_max):
106 | self.quantizer.set_quant_range(x_min, x_max)
107 |
108 | def extra_repr(self):
109 | return "state={}".format(self.state.name)
110 |
--------------------------------------------------------------------------------
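
A minimal sketch of the state machine documented above: estimate ranges on some data, then freeze them. Only `n_bits` is passed explicitly; relying on the `symmetric_uniform`/`current_minmax` defaults working out of the box is an assumption:

```python
import torch

from quantization.quantization_manager import QuantizationManager

qm = QuantizationManager(qparams={"n_bits": 8})

x = torch.randn(4, 16)
_ = qm(x)        # Qstates.estimate_ranges: the range estimator sets x_min/x_max from x
qm.fix_ranges()  # freeze the range for train and eval (quantizer must be initialized)
x_q = qm(x)      # fake-quantize with the frozen 8-bit grid
print(qm)        # extra_repr -> "state=fix_ranges"
```
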
/quantization/hijacker.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import copy
4 |
5 | import torch
6 | from timm.models.layers import Swish
7 | from timm.models.layers.activations_me import SwishMe
8 | from torch import nn
9 |
10 | from quantization.base_quantized_classes import QuantizedModule
11 | from quantization.quantization_manager import QuantizationManager
12 | from quantization.range_estimators import RangeEstimators
13 | from quantization.utils import to_numpy
14 |
15 | activations_set = [
16 | nn.ReLU,
17 | nn.ReLU6,
18 | nn.Hardtanh,
19 | nn.Sigmoid,
20 | nn.Tanh,
21 | nn.GELU,
22 | nn.PReLU,
23 | Swish,
24 | SwishMe,
25 | ]
26 |
27 |
28 | class QuantizationHijacker(QuantizedModule):
29 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and
30 | dequantization on the weights and output distributions.
31 |
32 | Usage:
33 | To make a quantized nn.Linear layer:
34 | class HijackedLinear(QuantizationHijacker, nn.Linear):
35 | pass
36 |
37 |     It is vital that QuantizationHijacker is the first parent class, and that the second parent
38 | class derives from nn.Module, otherwise it will not be reached by a super(., .) call.
39 |
40 | NB: this implementation (for now) assumes that there will always be some training involved,
41 | e.g. to estimate the activation ranges.
42 | """
43 |
44 | def __init__(self, *args, activation: nn.Module = None, **kwargs):
45 | super().__init__(*args, **kwargs)
46 | if activation:
47 | assert isinstance(activation, tuple(activations_set)), str(activation)
48 |
49 | self.activation_function = copy.deepcopy(activation) if activation else None
50 |
51 | self.activation_quantizer = QuantizationManager(
52 | qmethod=self.act_method,
53 | init=self.act_range_method,
54 | qparams=self.act_qparams,
55 | init_params=self.act_range_options,
56 | )
57 |
58 | if self.weight_range_method == RangeEstimators.current_minmax:
59 | weight_init_params = dict(percentile=self.percentile)
60 | else:
61 | weight_init_params = self.weight_range_options
62 |
63 | self.weight_quantizer = QuantizationManager(
64 | qmethod=self.method,
65 | init=self.weight_range_method,
66 | per_channel=self.per_channel_weights,
67 | qparams=self.weight_qparams,
68 | init_params=weight_init_params,
69 | )
70 |
71 | self.prune_manager = None
72 | if hasattr(self, "prune_method") and self.prune_method is not None:
73 | self.prune_manager = self.prune_method(self, **self.prune_kwargs)
74 |
75 | self.activation_save_target = None
76 | self.activation_save_name = None
77 |
78 | def forward(self, x, offsets=None):
79 | weight, bias = self.get_params()
80 | res = self.run_forward(x, weight, bias, offsets=offsets)
81 | res = self.quantize_activations(res)
82 | return res
83 |
84 | def get_params(self):
85 | if not self.training and self.cached_params:
86 | return self.cached_params
87 |
88 | weight, bias = self.get_weight_bias()
89 |
90 | if self.prune_manager is not None:
91 | weight, bias = self.prune_manager(weight, bias)
92 |
93 | if self._quant_w:
94 | weight = self.quantize_weights(weight)
95 |
96 | if self._caching and not self.training and self.cached_params is None:
97 | self.cached_params = (
98 | torch.Tensor(to_numpy(weight)).to(weight.device),
99 | torch.Tensor(to_numpy(bias)).to(bias.device) if bias is not None else None,
100 | )
101 | return weight, bias
102 |
103 | def quantize_weights(self, weights):
104 | return self.weight_quantizer(weights)
105 |
106 | def get_weight_bias(self):
107 | bias = None
108 | if hasattr(self, "bias"):
109 | bias = self.bias
110 | return self.weight, bias
111 |
112 | def run_forward(self, x, weight, bias, offsets=None):
113 | # Performs the actual linear operation of the layer
114 | raise NotImplementedError()
115 |
116 | def quantize_activations(self, activations):
117 | """Quantize a single activation tensor or all activations from a layer. I'm assuming that
118 | we should quantize all outputs for a layer with the same quantization scheme.
119 | """
120 | if self.activation_function is not None:
121 | activations = self.activation_function(activations)
122 |
123 | if self.activation_save_target is not None:
124 | self.activation_save_target[self.activation_save_name] = activations.data.cpu().numpy()
125 |
126 | if self._quant_a:
127 | activations = self.activation_quantizer(activations)
128 |
129 | if self.activation_save_target is not None:
130 | self.activation_save_target[
131 | self.activation_save_name + "_Q"
132 | ] = activations.data.cpu().numpy()
133 |
134 | return activations
135 |
--------------------------------------------------------------------------------
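
Following the docstring's `HijackedLinear` example, a hedged end-to-end sketch. The `run_forward` body and the constructor defaults being sufficient are assumptions; `autoquant_utils.py` (not shown in this dump) is presumably where the repository builds such wrappers in practice:

```python
import torch
import torch.nn.functional as F
from torch import nn

from quantization.hijacker import QuantizationHijacker


class HijackedLinear(QuantizationHijacker, nn.Linear):
    # The mixin hijacks forward(); the concrete layer only supplies the raw linear op.
    def run_forward(self, x, weight, bias, offsets=None):
        return F.linear(x, weight, bias)


layer = HijackedLinear(16, 8, n_bits=8)  # nn.Linear(16, 8) plus 8-bit quantizers
layer.quantized()                        # enable weight and activation quantization
out = layer(torch.randn(2, 16))          # weights quantized, output fake-quantized
```
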
/quantization/base_quantized_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import warnings
4 | from typing import Dict
5 |
6 | import torch
7 | from torch import Tensor, nn
8 |
9 | from quantization.base_quantized_classes import (
10 | QuantizedModule,
11 | _set_layer_estimate_ranges,
12 | _set_layer_estimate_ranges_train,
13 | _set_layer_fix_ranges,
14 | _set_layer_learn_ranges,
15 | )
16 | from quantization.quantizers.base_quantizers import QuantizerBase
17 |
18 |
19 | class QuantizedModel(nn.Module):
20 | """
21 | Parent class for a quantized model. This allows you to have convenience functions to put the
22 | whole model into quantization or full precision or to freeze BN. Otherwise it does not add any
23 | further functionality, so it is not a necessity that a quantized model uses this class.
24 | """
25 |
26 | def __init__(self, input_size=(1, 3, 224, 224)):
27 | """
28 | Parameters
29 | ----------
30 | input_size: Tuple with the input dimension for the model (including batch dimension)
31 | """
32 | super().__init__()
33 | self.input_size = input_size
34 |
35 | def load_state_dict(
36 |         self, state_dict: Dict[str, Tensor], strict: bool = True
37 | ):
38 | """
39 | This function overwrites the load_state_dict of nn.Module to ensure that quantization
40 | parameters are loaded correctly for quantized model.
41 |
42 | """
43 | quant_state_dict = {
44 | k: v for k, v in state_dict.items() if k.endswith("_quant_a") or k.endswith("_quant_w")
45 | }
46 |
47 | if quant_state_dict:
48 | # Case 1: the quantization states are stored in the state_dict
49 | super().load_state_dict(quant_state_dict, strict=False)
50 |
51 | else:
52 | # Case 2 (older models): the quantization states are NOT stored in the state_dict but
53 | # only the scale factor _delta.
54 | warnings.warn(
55 | "Old state_dict without quantization state included. Checking for " "_delta instead"
56 | )
57 | # Add quantization flags to the state_dict
58 | for name, module in self.named_modules():
59 | if isinstance(module, QuantizedModule):
60 | state_dict[".".join((name, "_quant_a"))] = torch.BoolTensor([False])
61 | state_dict[".".join((name, "_quant_w"))] = torch.BoolTensor([False])
62 | if (
63 | ".".join((name, "activation_quantizer", "quantizer", "_delta"))
64 | in state_dict.keys()
65 | ):
66 | module.quantized_acts()
67 | state_dict[".".join((name, "_quant_a"))] = torch.BoolTensor([True])
68 | if (
69 | ".".join((name, "weight_quantizer", "quantizer", "_delta"))
70 | in state_dict.keys()
71 | ):
72 | module.quantized_weights()
73 | state_dict[".".join((name, "_quant_w"))] = torch.BoolTensor([True])
74 |
75 | # Pass dummy data through quantized model to ensure all quantization parameters are
76 | # initialized with the correct dimensions (None tensors will lead to issues in state dict loading)
77 | device = next(self.parameters()).device
78 | dummy_input = torch.rand(*self.input_size, device=device)
79 | with torch.no_grad():
80 | self.forward(dummy_input)
81 |
82 | # Load state dict
83 | super().load_state_dict(state_dict, strict)
84 |
85 | def disable_caching(self):
86 | def _fn(layer):
87 | if isinstance(layer, QuantizedModule):
88 | layer.disable_caching()
89 |
90 | self.apply(_fn)
91 |
92 | def quantized_weights(self):
93 | def _fn(layer):
94 | if isinstance(layer, QuantizedModule):
95 | layer.quantized_weights()
96 |
97 | self.apply(_fn)
98 |
99 | def full_precision_weights(self):
100 | def _fn(layer):
101 | if isinstance(layer, QuantizedModule):
102 | layer.full_precision_weights()
103 |
104 | self.apply(_fn)
105 |
106 | def quantized_acts(self):
107 | def _fn(layer):
108 | if isinstance(layer, QuantizedModule):
109 | layer.quantized_acts()
110 |
111 | self.apply(_fn)
112 |
113 | def full_precision_acts(self):
114 | def _fn(layer):
115 | if isinstance(layer, QuantizedModule):
116 | layer.full_precision_acts()
117 |
118 | self.apply(_fn)
119 |
120 | def quantized(self):
121 | def _fn(layer):
122 | if isinstance(layer, QuantizedModule):
123 | layer.quantized()
124 |
125 | self.apply(_fn)
126 |
127 | def full_precision(self):
128 | def _fn(layer):
129 | if isinstance(layer, QuantizedModule):
130 | layer.full_precision()
131 |
132 | self.apply(_fn)
133 |
134 | # Methods for switching quantizer quantization states
135 | def learn_ranges(self):
136 | self.apply(_set_layer_learn_ranges)
137 |
138 | def fix_ranges(self):
139 | self.apply(_set_layer_fix_ranges)
140 |
141 | def estimate_ranges(self):
142 | self.apply(_set_layer_estimate_ranges)
143 |
144 | def estimate_ranges_train(self):
145 | self.apply(_set_layer_estimate_ranges_train)
146 |
147 | def set_quant_state(self, weight_quant, act_quant):
148 | if act_quant:
149 | self.quantized_acts()
150 | else:
151 | self.full_precision_acts()
152 |
153 | if weight_quant:
154 | self.quantized_weights()
155 | else:
156 | self.full_precision_weights()
157 |
158 | def grad_scaling(self, grad_scaling=True):
159 | def _fn(module):
160 | if isinstance(module, QuantizerBase):
161 | module.grad_scaling = grad_scaling
162 |
163 | self.apply(_fn)
164 |
--------------------------------------------------------------------------------
/quantization/base_quantized_classes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import torch
4 | from torch import nn
5 |
6 | from quantization.quantization_manager import QuantizationManager
7 | from quantization.quantizers.base_quantizers import QuantizerBase
8 | from quantization.quantizers.uniform_quantizers import AsymmetricUniformQuantizer
9 | from quantization.range_estimators import (
10 | CurrentMinMaxEstimator,
11 | RangeEstimatorBase,
12 | RunningMinMaxEstimator,
13 | )
14 |
15 |
16 | def _set_layer_learn_ranges(layer):
17 | if isinstance(layer, QuantizationManager):
18 | if layer.quantizer.is_initialized:
19 | layer.learn_ranges()
20 |
21 |
22 | def _set_layer_fix_ranges(layer):
23 | if isinstance(layer, QuantizationManager):
24 | if layer.quantizer.is_initialized:
25 | layer.fix_ranges()
26 |
27 |
28 | def _set_layer_estimate_ranges(layer):
29 | if isinstance(layer, QuantizationManager):
30 | layer.estimate_ranges()
31 |
32 |
33 | def _set_layer_estimate_ranges_train(layer):
34 | if isinstance(layer, QuantizationManager):
35 | if layer.quantizer.is_initialized:
36 | layer.estimate_ranges_train()
37 |
38 |
39 | class QuantizedModule(nn.Module):
40 | """
41 | Parent class for a quantized module. It adds the basic functionality of switching the module
42 | between quantized and full precision mode. It also defines the cached parameters and handles
43 | the reset of the cache properly.
44 | """
45 |
46 | def __init__(
47 | self,
48 | *args,
49 | method: QuantizerBase = AsymmetricUniformQuantizer,
50 | act_method=None,
51 | weight_range_method: RangeEstimatorBase = CurrentMinMaxEstimator,
52 | act_range_method: RangeEstimatorBase = RunningMinMaxEstimator,
53 | n_bits=8,
54 | n_bits_act=None,
55 | per_channel_weights=False,
56 | percentile=None,
57 | weight_range_options=None,
58 | act_range_options=None,
59 | scale_domain="linear",
60 | bayesian_bits_kwargs=None,
61 | prune_method=None,
62 | prune_kwargs=None,
63 | **kwargs
64 | ):
65 | kwargs.pop("act_quant_dict", None)
66 | kwargs.pop("quant_dict", None)
67 |
68 | super().__init__(*args, **kwargs)
69 |
70 | self.method = method
71 | self.act_method = act_method or method
72 | self.n_bits = n_bits
73 | self.n_bits_act = n_bits_act or n_bits
74 | self.per_channel_weights = per_channel_weights
75 | self.percentile = percentile
76 | self.weight_range_method = weight_range_method
77 | self.weight_range_options = weight_range_options if weight_range_options else {}
78 | self.act_range_method = act_range_method
79 | self.act_range_options = act_range_options if act_range_options else {}
80 | self.scale_domain = scale_domain
81 |
82 | self.bayesian_bits_kwargs = bayesian_bits_kwargs or {}
83 | self.prune_method = prune_method
84 | self.prune_kwargs = prune_kwargs
85 |
86 | self.cached_params = None
87 | self._caching = True
88 |
89 | self.quant_params = None
90 | self.register_buffer("_quant_w", torch.BoolTensor([False]))
91 | self.register_buffer("_quant_a", torch.BoolTensor([False]))
92 |
93 | self.act_qparams = dict(
94 | n_bits=self.n_bits_act,
95 | scale_domain=self.scale_domain,
96 | act_quant=True,
97 | **self.bayesian_bits_kwargs
98 | )
99 | self.weight_qparams = dict(
100 | n_bits=self.n_bits,
101 | scale_domain=self.scale_domain,
102 | act_quant=False,
103 | **self.bayesian_bits_kwargs
104 | )
105 |
106 | @property
107 | def caching(self):
108 | return self._caching
109 |
110 | @caching.setter
111 | def caching(self, value: bool):
112 | self._caching = value
113 | if not value:
114 | self.cached_params = None
115 |
116 | def quantized_weights(self):
117 | self.cached_params = None
118 | self._quant_w = torch.BoolTensor([True])
119 |
120 | def full_precision_weights(self):
121 | self.cached_params = None
122 | self._quant_w = torch.BoolTensor([False])
123 |
124 | def quantized_acts(self):
125 | self._quant_a = torch.BoolTensor([True])
126 |
127 | def full_precision_acts(self):
128 | self._quant_a = torch.BoolTensor([False])
129 |
130 | def quantized(self):
131 | self.quantized_weights()
132 | self.quantized_acts()
133 |
134 | def full_precision(self):
135 | self.full_precision_weights()
136 | self.full_precision_acts()
137 |
138 | def get_quantizer_status(self):
139 | return dict(quant_a=self._quant_a.item(), quant_w=self._quant_w.item())
140 |
141 | def set_quantizer_status(self, quantizer_status):
142 | if quantizer_status["quant_a"]:
143 | self.quantized_acts()
144 | else:
145 | self.full_precision_acts()
146 |
147 | if quantizer_status["quant_w"]:
148 | self.quantized_weights()
149 | else:
150 | self.full_precision_weights()
151 |
152 | def learn_ranges(self):
153 | self.apply(_set_layer_learn_ranges)
154 |
155 | def fix_ranges(self):
156 | self.apply(_set_layer_fix_ranges)
157 |
158 | def estimate_ranges(self):
159 | self.apply(_set_layer_estimate_ranges)
160 |
161 | def estimate_ranges_train(self):
162 | self.apply(_set_layer_estimate_ranges_train)
163 |
164 | def train(self, mode=True):
165 | super().train(mode)
166 | if mode:
167 | self.cached_params = None
168 | return self
169 |
170 | def _apply(self, *args, **kwargs):
171 | self.cached_params = None
172 | return super(QuantizedModule, self)._apply(*args, **kwargs)
173 |
174 | def extra_repr(self):
175 | quant_state = "weight_quant={}, act_quant={}".format(
176 | self._quant_w.item(), self._quant_a.item()
177 | )
178 | parent_repr = super().extra_repr()
179 | return "{},\n{}".format(parent_repr, quant_state) if parent_repr else quant_state
180 |
181 |
182 | class QuantizedActivation(QuantizedModule):
183 | def __init__(self, *args, **kwargs):
184 | super().__init__(*args, **kwargs)
185 | self.activation_quantizer = QuantizationManager(
186 | qmethod=self.act_method,
187 | qparams=self.act_qparams,
188 | init=self.act_range_method,
189 | init_params=self.act_range_options,
190 | )
191 |
192 | def quantize_activations(self, x):
193 | if self._quant_a:
194 | return self.activation_quantizer(x)
195 | else:
196 | return x
197 |
198 | def forward(self, x):
199 | return self.quantize_activations(x)
200 |
201 |
202 | class FP32Acts(nn.Module):
203 | def forward(self, x):
204 | return x
205 |
206 | def reset_ranges(self):
207 | pass
208 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Outlier-free Transformers
2 | This repository contains the implementation and LLM experiments for the paper presented in
3 |
4 | **Yelysei Bondarenko¹, Markus Nagel¹, Tijmen Blankevoort¹,
5 | "Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing", NeurIPS 2023.** [[ArXiv]](https://arxiv.org/abs/2306.12929)
6 |
7 | ¹ Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.)
8 |
9 |
10 | ## Reference
11 | If you find our work useful, please cite
12 | ```
13 | @article{bondarenko2023quantizable,
14 | title={Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
15 | author={Bondarenko, Yelysei and Nagel, Markus and Blankevoort, Tijmen},
16 | journal={arXiv preprint arXiv:2306.12929},
17 | year={2023}
18 | }
19 | ```
20 |
21 | ## Abstract
22 |
23 |
24 |
25 |
26 | Many studies have shown that modern transformer models tend to learn strong outliers in their activations,
27 | making them difficult to quantize. We show that strong outliers are related to very specific behavior of attention
28 | heads that try to learn a _"no-op"_ or just a _partial update_ of the residual. To achieve the exact zeros needed in the
29 | attention matrix for a no-update, the input to the softmax is pushed to be larger and larger during training, causing
30 | outliers in other parts of the network.
31 |
32 | Based on these observations, we propose two simple (independent) modifications to the attention mechanism -
33 | **clipped softmax** and **gated attention**. We empirically show that models pre-trained using our methods learn
34 | significantly smaller outliers while maintaining and sometimes even improving the floating-point task performance.
35 | This enables us to quantize transformers to full INT8 quantization without any additional effort.
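
For reference, here is a minimal sketch of the clipped softmax (the helper name and the default `gamma`/`eta` values are illustrative; the implementation actually used by the attention modules lives in `transformers_language/models/softmax.py` and is selected via `--attn_softmax`):

```python
import torch
import torch.nn.functional as F


def clipped_softmax_sketch(x, dim=-1, gamma=-0.025, eta=1.0):
    """Stretch the softmax output to [gamma, eta], then clip it back to [0, 1].

    With gamma < 0 exact zeros become reachable (and with eta > 1, exact ones),
    so an attention head can assign exactly zero probability without pushing
    the softmax inputs to extreme magnitudes.
    """
    return torch.clamp((eta - gamma) * F.softmax(x, dim=dim) + gamma, min=0.0, max=1.0)


scores = torch.tensor([[8.0, -4.0, -4.0, -4.0]])
print(clipped_softmax_sketch(scores))  # the three small entries are exactly 0.0
```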
36 |
37 |
38 | ## Repository structure
39 | ```bash
40 | .
41 | ├── accelerate_configs # HuggingFace accelerate configs
42 | ├── docker # dockerfile and requirements files
43 | ├── img
44 | ├── model_configs # YAML configs for different model sizes
45 | ├── quantization # quantization tools and functionality
46 | │ └── quantizers
47 | ├── scripts # preset scripts for pre-training
48 | └── transformers_language # source code for quantized LLMs and our methods
49 | └── models
50 | └── run_clm.py # train a Causal Language Model (e.g., OPT)
51 | └── run_mlm.py # train a Masked Language Model (e.g., BERT)
52 | └── validate_clm.py # validate a Causal Language Model (e.g., OPT)
53 | └── validate_mlm.py # validate a Masked Language Model (e.g., BERT)
54 | ```
55 |
56 | ## How to install
57 | ### using docker (recommended)
58 | You can build and run the docker container as follows:
59 | ```bash
60 | docker build -f docker/Dockerfile --tag outlier_free_transformers:latest .
61 | docker run -ti outlier_free_transformers:latest
62 | ```
63 |
64 |
65 | ### without docker
66 | Set the locale variables and add the project root directory to your `PYTHONPATH`:
67 | ```bash
68 | export LC_ALL=C.UTF-8
69 | export LANG=C.UTF-8
70 | export PYTHONPATH=${PYTHONPATH}:$(realpath "$PWD")
71 | ```
72 |
73 | Make sure to have Python ≥3.6 (tested with Python 3.8.10) and
74 | ensure the latest version of `pip` (tested with 23.2.1):
75 | ```bash
76 | pip install --upgrade --no-deps pip
77 | ```
78 |
79 | Next, install PyTorch 1.11 with the appropriate CUDA version (tested with CUDA 11.3, CuDNN 8.2.0):
80 | ```bash
81 | pip install torch==1.11.0 torchvision==0.12.0
82 | ```
83 |
84 | Install the remaining dependencies using pip:
85 | ```bash
86 | pip install -r docker/requirements.txt
87 | ```
88 |
89 | Finally, make all scripts executable:
90 | ```bash
91 | chmod +x scripts/*.sh
92 | ```
93 |
94 | ## Pre-training commands
95 | All the training scripts (batch size, etc.) are set up to fit on a single A100 80GB GPU.
96 |
97 | | Model | Softmax | Script |
98 | |:----------|:----------------|:---------------------------------------------------------------------|
99 | | BERT-base | vanilla | [scripts/bert_base_vanilla.sh](scripts/bert_base_vanilla.sh) |
100 | | BERT-base | clipped softmax | [scripts/bert_base_clipped_softmax.sh](scripts/bert_base_clipped_softmax.sh) |
101 | | BERT-base | gated attention | [scripts/bert_base_gated_attention.sh](scripts/bert_base_gated_attention.sh) |
102 | | OPT-125m | vanilla | [scripts/opt_125m_vanilla.sh](scripts/opt_125m_vanilla.sh) |
103 | | OPT-125m | clipped softmax | [scripts/opt_125m_clipped_softmax.sh](scripts/opt_125m_clipped_softmax.sh) |
104 | | OPT-125m | gated attention | [scripts/opt_125m_gated_attention.sh](scripts/opt_125m_gated_attention.sh) |
105 | | OPT-350m | vanilla | [scripts/opt_350m_vanilla.sh](scripts/opt_350m_vanilla.sh) |
106 | | OPT-350m | gated attention | [scripts/opt_350m_gated_attention.sh](scripts/opt_350m_gated_attention.sh) |
107 | | OPT-1.3B | vanilla | [scripts/opt_1.3b_vanilla.sh](scripts/opt_1.3b_vanilla.sh) |
108 | | OPT-1.3B | gated attention | [scripts/opt_1.3b_gated_attention.sh](scripts/opt_1.3b_gated_attention.sh) |
109 |
110 | ## Validation commands
111 | After the model is trained, you can run evaluation (both floating-point and quantized) using
112 | the following commands.
113 | Make sure to pass the same softmax method arguments that were used for pre-training (e.g., `--attn_softmax vanilla`, `--attn_softmax "clipped(-.025:1)"`, `--alpha 12`, `--attn_gate_type conditional_per_token --attn_gate_mlp`, `--attn_gate_type conditional_per_token --attn_gate_init 0.25` etc.)
114 |
115 | ### FP validation for BERT models
116 | Run command:
117 | ```bash
118 | accelerate launch validate_mlm.py \
119 | --seed 1000 \
120 | --dataset_setup bookcorpus_and_wiki \
121 | --preprocessing_num_workers 8 \
122 | --model_type bert \
123 | --max_seq_length 128 \
124 | --mlm_probability 0.15 \
125 | --per_device_eval_batch_size 32 \
126 | --data_cache_dir ~/.hf_data \
127 | --model_cache_dir ~/.hf_cache \
128 | --model_name_or_path /path/to/saved/checkpoint \
129 | --output_dir output_metrics
130 | ```
131 | Expected (example) output:
132 | ```
133 | INFO - validate_mlm - perplexity: 4.5438
134 | INFO - validate_mlm - max FFN output inf norm: 20.8
135 | INFO - validate_mlm - max FFN input + output inf norm: 47.0
136 | INFO - validate_mlm - max LN(FFN i + o) inf norm: 25.9
137 | INFO - validate_mlm - Avg Kurtosis: 72.19
138 | INFO - validate_mlm - Max Kurtosis: 197.6
139 | ```
140 |
141 |
142 | ### INT8 validation for BERT models
143 | Run command:
144 | ```bash
145 | accelerate launch validate_mlm.py \
146 | --quantize \
147 | --est_num_batches 16 \
148 | --seed 2000 \
149 | --dataset_setup bookcorpus_and_wiki \
150 | --preprocessing_num_workers 8 \
151 | --model_type bert \
152 | --max_seq_length 128 \
153 | --mlm_probability 0.15 \
154 | --per_device_eval_batch_size 32 \
155 | --data_cache_dir ~/.hf_data \
156 | --model_cache_dir ~/.hf_cache \
157 | --model_name_or_path /path/to/saved/checkpoint \
158 | --output_dir output_metrics
159 | ```
160 | Expected (example) output:
161 | ```
162 | INFO - validate_mlm - perplexity: 4.6550
163 | INFO - validate_mlm - max FFN output inf norm: 20.6
164 | INFO - validate_mlm - max FFN input + output inf norm: 47.0
165 | INFO - validate_mlm - max LN(FFN i + o) inf norm: 25.7
166 | INFO - validate_mlm - Avg Kurtosis: 70.32
167 | INFO - validate_mlm - Max Kurtosis: 188.4
168 | ```
169 |
170 | ### FP validation for OPT models
171 | Run command:
172 | ```bash
173 | accelerate launch validate_clm.py \
174 | --seed 1000 \
175 | --dataset_setup bookcorpus_and_wiki \
176 | --preprocessing_num_workers 4 \
177 | --model_type opt \
178 | --block_size 512 \
179 | --per_device_eval_batch_size 4 \
180 | --data_cache_dir ~/.hf_data \
181 | --model_cache_dir ~/.hf_cache \
182 | --model_name_or_path /path/to/saved/checkpoint \
183 | --output_dir output_metrics
184 | ```
185 | Expected (example) output:
186 | ```bash
187 | INFO - validate_clm - perplexity: 15.5449
188 | INFO - validate_clm - Max inf norm: 122.4
189 | INFO - validate_clm - Max FFN inf norm: 0.5
190 | INFO - validate_clm - Max layer inf norm: 9.1
191 | INFO - validate_clm - Avg Kurtosis: 18.26
192 | INFO - validate_clm - Max Kurtosis: 151.5
193 | INFO - validate_clm - Max Kurtosis layers: 151.5
194 | INFO - validate_clm -
195 | # (...)
196 | # detailed per-layer stats
197 | ```
198 |
199 | ### INT8 validation for OPT models
200 | Run command:
201 | ```bash
202 | accelerate launch validate_clm.py \
203 | --quantize \
204 | --quant_setup fp32_head \
205 | --ranges_acts running_minmax \
206 | --qmethod_acts asymmetric_uniform \
207 | --percentile 99.999 \
208 | --est_num_batches 4 \
209 | --seed 2000 \
210 | --dataset_setup bookcorpus_and_wiki \
211 | --preprocessing_num_workers 4 \
212 | --model_type opt \
213 | --block_size 512 \
214 | --per_device_eval_batch_size 1 \
215 | --data_cache_dir ~/.hf_data \
216 | --model_cache_dir ~/.hf_cache \
217 | --model_name_or_path /path/to/saved/checkpoint \
218 | --output_dir output_metrics
219 | ```
220 | Expected (example) output:
221 | ```
222 | perplexity: 16.0132
223 | ```
224 |
--------------------------------------------------------------------------------
/quantization/autoquant_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import copy
4 | import warnings
5 |
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd
9 |
10 | from quantization.base_quantized_classes import (
11 | FP32Acts,
12 | QuantizedActivation,
13 | QuantizedModule,
14 | )
15 | from quantization.hijacker import QuantizationHijacker, activations_set
16 | from quantization.quantization_manager import QuantizationManager
17 |
18 |
19 | class QuantLinear(QuantizationHijacker, nn.Linear):
20 | def run_forward(self, x, weight, bias, offsets=None):
21 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias)
22 |
23 |
24 | class QuantizedActivationWrapper(QuantizedActivation):
25 |     """
26 |     Wraps a layer and quantizes its output activation.
27 |     It also allows tying the input and output quantizers, which is helpful
28 |     for layers such as average pooling.
29 |     """
30 |
31 | def __init__(
32 | self,
33 | layer,
34 | *args,
35 | tie_activation_quantizers=False,
36 | input_quantizer: QuantizationManager = None,
37 | **kwargs,
38 | ):
39 | super().__init__(*args, **kwargs)
40 | self.tie_activation_quantizers = tie_activation_quantizers
41 | if input_quantizer:
42 | assert isinstance(input_quantizer, QuantizationManager)
43 | self.activation_quantizer = input_quantizer
44 | self.layer = layer
45 |
46 | def quantize_activations_no_range_update(self, x):
47 | if self._quant_a:
48 | return self.activation_quantizer.quantizer(x)
49 | else:
50 | return x
51 |
52 | def forward(self, x):
53 | x = self.layer(x)
54 | if self.tie_activation_quantizers:
55 | # The input activation quantizer is used to quantize the activation
56 | # but without updating the quantization range
57 | return self.quantize_activations_no_range_update(x)
58 | else:
59 | return self.quantize_activations(x)
60 |
61 | def extra_repr(self):
62 | return f"tie_activation_quantizers={self.tie_activation_quantizers}"
63 |
64 |
65 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm):
66 | def run_forward(self, x, weight, bias, offsets=None):
67 | return F.layer_norm(
68 | input=x.contiguous(),
69 | normalized_shape=self.normalized_shape,
70 | weight=weight.contiguous(),
71 | bias=bias.contiguous(),
72 | eps=self.eps,
73 | )
74 |
75 |
76 | class QuantEmbedding(QuantizationHijacker, nn.Embedding):
77 | def __init__(self, *args, activation=None, **kwargs):
78 | super().__init__(*args, activation=activation, **kwargs)
79 | # NB: We should not (re-)quantize activations of this module, as it is a
80 | # lookup table (=weights), which is already quantized
81 | self.activation_quantizer = FP32Acts()
82 |
83 | def run_forward(self, x, weight, bias, offsets=None):
84 | return F.embedding(
85 | input=x.contiguous(),
86 | weight=weight.contiguous(),
87 | padding_idx=self.padding_idx,
88 | max_norm=self.max_norm,
89 | norm_type=self.norm_type,
90 | scale_grad_by_freq=self.scale_grad_by_freq,
91 | sparse=self.sparse,
92 | )
93 |
94 |
95 | # Modules Map
96 | module_map = {nn.Linear: QuantLinear, nn.LayerNorm: QuantLayerNorm, nn.Embedding: QuantEmbedding}
97 |
98 |
99 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd)
100 |
101 |
102 | def next_bn(module, i):
103 | return len(module) > i + 1 and isinstance(module[i + 1], (nn.BatchNorm2d, nn.BatchNorm1d))
104 |
105 |
106 | def get_act(module, i):
107 | # Case 1: conv + act
108 | if len(module) - i > 1 and isinstance(module[i + 1], tuple(activations_set)):
109 | return module[i + 1], i + 1
110 |
111 | # Case 2: conv + bn + act
112 | if (
113 | len(module) - i > 2
114 | and next_bn(module, i)
115 | and isinstance(module[i + 2], tuple(activations_set))
116 | ):
117 | return module[i + 2], i + 2
118 |
119 | # Case 3: conv + bn + X -> return false
120 | # Case 4: conv + X -> return false
121 | return None, None
122 |
123 |
124 | def get_linear_args(module):
125 | args = dict(
126 | in_features=module.in_features,
127 | out_features=module.out_features,
128 | bias=module.bias is not None,
129 | )
130 | return args
131 |
132 |
133 | def get_layernorm_args(module):
134 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps)
135 | return args
136 |
137 |
138 | def get_embedding_args(module):
139 | args = dict(
140 | num_embeddings=module.num_embeddings,
141 | embedding_dim=module.embedding_dim,
142 | padding_idx=module.padding_idx,
143 | max_norm=module.max_norm,
144 | norm_type=module.norm_type,
145 | scale_grad_by_freq=module.scale_grad_by_freq,
146 | sparse=module.sparse,
147 | )
148 | return args
149 |
150 |
151 | def get_module_args(mod, act):
152 | if isinstance(mod, nn.Linear):
153 | kwargs = get_linear_args(mod)
154 | elif isinstance(mod, nn.LayerNorm):
155 | kwargs = get_layernorm_args(mod)
156 | elif isinstance(mod, nn.Embedding):
157 | kwargs = get_embedding_args(mod)
158 | else:
159 | raise ValueError
160 |
161 | kwargs["activation"] = act
162 |
163 | return kwargs
164 |
165 |
166 | def quant_module(module, i, **quant_params):
167 | act, _ = get_act(module, i)
168 | modtype = module_map[type(module[i])]
169 |
170 | kwargs = get_module_args(module[i], act)
171 | new_module = modtype(**kwargs, **quant_params)
172 | new_module.weight.data = module[i].weight.data.clone()
173 |
174 | if module[i].bias is not None:
175 | new_module.bias.data = module[i].bias.data.clone()
176 |
177 | return new_module, i + int(bool(act)) + 1
178 |
179 |
180 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params):
181 | specials = specials or dict()
182 |
183 | i = 0
184 | quant_modules = []
185 | while i < len(model):
186 | if isinstance(model[i], QuantizedModule):
187 | quant_modules.append(model[i])
188 |
189 | elif type(model[i]) in module_map:
190 | new_module, new_i = quant_module(model, i, **quant_params)
191 | quant_modules.append(new_module)
192 | i = new_i
193 | continue
194 |
195 | elif type(model[i]) in specials:
196 | quant_modules.append(specials[type(model[i])](model[i], **quant_params))
197 |
198 | elif isinstance(model[i], non_param_modules):
199 | # Check for last quantizer
200 | input_quantizer = None
201 | if quant_modules and isinstance(quant_modules[-1], QuantizedModule):
202 | last_layer = quant_modules[-1]
203 | input_quantizer = quant_modules[-1].activation_quantizer
204 | elif (
205 | quant_modules
206 | and isinstance(quant_modules[-1], nn.Sequential)
207 | and isinstance(quant_modules[-1][-1], QuantizedModule)
208 | ):
209 | last_layer = quant_modules[-1][-1]
210 | input_quantizer = quant_modules[-1][-1].activation_quantizer
211 |
212 | if input_quantizer and tie_activation_quantizers:
213 |                 # If an input quantizer is found, tie the input/output activation quantizers
214 | print(
215 |                     f"Tying input quantizer of the {i-1}-th layer of type {type(last_layer)} to the "
216 | f"quantized {type(model[i])} following it"
217 | )
218 | quant_modules.append(
219 | QuantizedActivationWrapper(
220 | model[i],
221 | tie_activation_quantizers=tie_activation_quantizers,
222 | input_quantizer=input_quantizer,
223 | **quant_params,
224 | )
225 | )
226 | else:
227 | # Input quantizer not found
228 | quant_modules.append(QuantizedActivationWrapper(model[i], **quant_params))
229 | if tie_activation_quantizers:
230 | warnings.warn("Input quantizer not found, so we do not tie quantizers")
231 | else:
232 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params))
233 | i += 1
234 | return nn.Sequential(*quant_modules)
235 |
236 |
237 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params):
238 | specials = specials or dict()
239 |
240 | if isinstance(model, nn.Sequential):
241 | quant_model = quantize_sequential(
242 | model, specials, tie_activation_quantizers, **quant_params
243 | )
244 |
245 | elif type(model) in specials:
246 | quant_model = specials[type(model)](model, **quant_params)
247 |
248 | elif isinstance(model, non_param_modules):
249 | quant_model = QuantizedActivationWrapper(model, **quant_params)
250 |
251 | elif type(model) in module_map:
252 | # If we do isinstance() then we might run into issues with modules that inherit from
253 | # one of these classes, for whatever reason
254 | modtype = module_map[type(model)]
255 | kwargs = get_module_args(model, None)
256 | quant_model = modtype(**kwargs, **quant_params)
257 |
258 | quant_model.weight.data = model.weight.data
259 | if getattr(model, "bias", None) is not None:
260 | quant_model.bias.data = model.bias.data
261 |
262 | else:
263 | # Unknown type, try to quantize all child modules
264 | quant_model = copy.deepcopy(model)
265 | for name, module in quant_model._modules.items():
266 | new_model = quantize_model(module, specials=specials, **quant_params)
267 | if new_model is not None:
268 | setattr(quant_model, name, new_model)
269 |
270 | return quant_model
271 |
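# Usage sketch (illustrative only; the model below and the contents of
# `quant_params` are assumptions, not part of this file):
#
#   from torch import nn
#   from quantization.autoquant_utils import quantize_model
#
#   fp_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
#   quant_model = quantize_model(fp_model, **quant_params)
#   # `quant_params` holds the quantizer settings (bit-widths, quantization
#   # method, range-estimation options) forwarded to the hijacked modules.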
--------------------------------------------------------------------------------
/quantization/quantizers/uniform_quantizers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import torch
4 |
5 | from quantization.quantizers.base_quantizers import QuantizerBase
6 | from quantization.quantizers.quantizer_utils import (
7 | QuantizerNotInitializedError,
8 | round_ste_func,
9 | scale_grad_func,
10 | )
11 |
12 |
13 | class AsymmetricUniformQuantizer(QuantizerBase):
14 | """
15 | PyTorch Module that implements Asymmetric Uniform Quantization using STE.
16 | Quantizes its argument in the forward pass, passes the gradient 'straight
17 | through' on the backward pass, ignoring the quantization that occurred.
18 |
19 | Parameters
20 | ----------
21 | n_bits: int
22 | Number of bits for quantization.
23 |     scale_domain: str ('log' or 'linear'), default='linear'
24 | Domain of scale factor
25 | per_channel: bool
26 | If True: allows for per-channel quantization
27 | """
28 |
29 | def __init__(self, n_bits, scale_domain="linear", grad_scaling=False, eps=1e-8, **kwargs):
30 | super().__init__(n_bits=n_bits, **kwargs)
31 |
32 | assert scale_domain in ("linear", "log")
33 | self.register_buffer("_delta", None)
34 | self.register_buffer("_zero_float", None)
35 | self.scale_domain = scale_domain
36 | self.grad_scaling = grad_scaling
37 | self.eps = eps
38 |
39 | # A few useful properties
40 | @property
41 | def delta(self):
42 | if self._delta is not None:
43 | return self._delta
44 | else:
45 | raise QuantizerNotInitializedError()
46 |
47 | @property
48 | def zero_float(self):
49 | if self._zero_float is not None:
50 | return self._zero_float
51 | else:
52 | raise QuantizerNotInitializedError()
53 |
54 | @property
55 | def is_initialized(self):
56 | return self._delta is not None
57 |
58 | @property
59 | def symmetric(self):
60 | return False
61 |
62 | @property
63 | def int_min(self):
64 | # integer grid minimum
65 | return 0.0
66 |
67 | @property
68 | def int_max(self):
69 | # integer grid maximum
70 | return 2.0**self.n_bits - 1
71 |
72 | @property
73 | def scale(self):
74 | if self.scale_domain == "linear":
75 | return torch.clamp(self.delta, min=self.eps)
76 | elif self.scale_domain == "log":
77 | return torch.exp(self.delta)
78 |
79 | @property
80 | def zero_point(self):
81 | zero_point = round_ste_func(self.zero_float)
82 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max)
83 | return zero_point
84 |
85 | @property
86 | def x_max(self):
87 | return self.scale * (self.int_max - self.zero_point)
88 |
89 | @property
90 | def x_min(self):
91 | return self.scale * (self.int_min - self.zero_point)
92 |
93 | def to_integer_forward(self, x_float):
94 | """
95 |         Quantizes the input to its integer representation
96 | Parameters
97 | ----------
98 | x_float: PyTorch Float Tensor
99 | Full-precision Tensor
100 |
101 | Returns
102 | -------
103 | x_int: PyTorch Float Tensor of integers
104 | """
105 | if self.grad_scaling:
106 | grad_scale = self.calculate_grad_scale(x_float)
107 | scale = scale_grad_func(self.scale, grad_scale)
108 | zero_point = (
109 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale)
110 | )
111 | else:
112 | scale = self.scale
113 | zero_point = self.zero_point
114 |
115 | x_int = round_ste_func(x_float / scale) + zero_point
116 | x_int = torch.clamp(x_int, self.int_min, self.int_max)
117 |
118 | return x_int
119 |
120 | def forward(self, x_float):
121 | """
122 |         Quantizes the input (maps it to the integer grid, then scales back to the original domain)
123 | Parameters
124 | ----------
125 | x_float: PyTorch Float Tensor
126 | Full-precision Tensor
127 |
128 | Returns
129 | -------
130 | x_quant: PyTorch Float Tensor
131 | Quantized-Dequantized Tensor
132 | """
133 | if self.per_channel:
134 | self._adjust_params_per_channel(x_float)
135 |
136 | if self.grad_scaling:
137 | grad_scale = self.calculate_grad_scale(x_float)
138 | scale = scale_grad_func(self.scale, grad_scale)
139 | zero_point = (
140 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale)
141 | )
142 | else:
143 | scale = self.scale
144 | zero_point = self.zero_point
145 |
146 | x_int = self.to_integer_forward(x_float)
147 | x_quant = scale * (x_int - zero_point)
148 |
149 | return x_quant
150 |
151 | def calculate_grad_scale(self, quant_tensor):
152 | num_pos_levels = self.int_max # Qp in LSQ paper
153 | num_elements = quant_tensor.numel() # nfeatures or nweights in LSQ paper
154 | if self.per_channel:
155 |             # In the per-channel case gradients are not summed over the output channel dimension
156 | num_elements /= quant_tensor.shape[0]
157 |
158 |         return (num_pos_levels * num_elements) ** -0.5  # 1 / sqrt(Qp * nfeatures)
159 |
160 | def _adjust_params_per_channel(self, x):
161 | """
162 | Adjusts the quantization parameter tensors (delta, zero_float)
163 | to the input tensor shape if they don't match
164 | Parameters
165 | ----------
166 | x: input tensor
167 | """
168 | if x.ndim != self.delta.ndim:
169 | new_shape = [-1] + [1] * (len(x.shape) - 1)
170 | self._delta = self.delta.view(new_shape)
171 | if self._zero_float is not None:
172 | self._zero_float = self._zero_float.view(new_shape)
173 |
174 | def _tensorize_min_max(self, x_min, x_max):
175 | """
176 | Converts provided min max range into tensors
177 | Parameters
178 | ----------
179 | x_min: float or PyTorch 1D tensor
180 | x_max: float or PyTorch 1D tensor
181 |
182 | Returns
183 | -------
184 | x_min: PyTorch Tensor 0 or 1-D
185 | x_max: PyTorch Tensor 0 or 1-D
186 | """
187 | # Ensure a torch tensor
188 | if not torch.is_tensor(x_min):
189 | x_min = torch.tensor(x_min).float()
190 | x_max = torch.tensor(x_max).float()
191 |
192 | if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel:
193 | print(x_min)
194 | print(self.per_channel)
195 | raise ValueError(
196 | "x_min and x_max must be a float or 1-D Tensor"
197 | " for per-tensor quantization (per_channel=False)"
198 | )
199 |         # Ensure the range always includes zero and avoid division by zero
200 | x_min = torch.min(x_min, torch.zeros_like(x_min))
201 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps)
202 |
203 | return x_min, x_max
204 |
205 | def set_quant_range(self, x_min, x_max):
206 | """
207 | Instantiates the quantization parameters based on the provided
208 | min and max range
209 |
210 | Parameters
211 | ----------
212 | x_min: tensor or float
213 | Quantization range minimum limit
214 |         x_max: tensor or float
215 |             Quantization range maximum limit
216 | """
217 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max
218 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
219 | self._delta = (x_max - x_min) / self.int_max
220 | self._zero_float = (-x_min / self.delta).detach()
221 |
222 | if self.scale_domain == "log":
223 | self._delta = torch.log(self.delta)
224 |
225 | self._delta = self._delta.detach()
226 |
227 | def make_range_trainable(self):
228 | # Converts trainable parameters to nn.Parameters
229 | if self.delta not in self.parameters():
230 | self._delta = torch.nn.Parameter(self._delta)
231 | self._zero_float = torch.nn.Parameter(self._zero_float)
232 |
233 | def fix_ranges(self):
234 | # Removes trainable quantization params from nn.Parameters
235 | if self.delta in self.parameters():
236 | _delta = self._delta.data
237 | _zero_float = self._zero_float.data
238 | del self._delta # delete the parameter
239 | del self._zero_float
240 | self.register_buffer("_delta", _delta)
241 | self.register_buffer("_zero_float", _zero_float)
242 |
243 |
244 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer):
245 | """
246 | PyTorch Module that implements Symmetric Uniform Quantization using STE.
247 | Quantizes its argument in the forward pass, passes the gradient 'straight
248 | through' on the backward pass, ignoring the quantization that occurred.
249 |
250 | Parameters
251 | ----------
252 | n_bits: int
253 | Number of bits for quantization.
254 |     scale_domain: str ('log' or 'linear'), default='linear'
255 | Domain of scale factor
256 | per_channel: bool
257 | If True: allows for per-channel quantization
258 | """
259 |
260 | def __init__(self, *args, **kwargs):
261 | super().__init__(*args, **kwargs)
262 | self.register_buffer("_signed", None)
263 |
264 | @property
265 | def signed(self):
266 | if self._signed is not None:
267 | return self._signed.item()
268 | else:
269 | raise QuantizerNotInitializedError()
270 |
271 | @property
272 | def symmetric(self):
273 | return True
274 |
275 | @property
276 | def int_min(self):
277 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0
278 |
279 | @property
280 | def int_max(self):
281 | pos_n_bits = self.n_bits - self.signed
282 | return 2.0**pos_n_bits - 1
283 |
284 | @property
285 | def zero_point(self):
286 | return 0.0
287 |
288 | def set_quant_range(self, x_min, x_max):
289 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max
290 | x_min, x_max = self._tensorize_min_max(x_min, x_max)
291 | self._signed = x_min.min() < 0
292 |
293 | x_absmax = torch.max(x_min.abs(), x_max)
294 | self._delta = x_absmax / self.int_max
295 |
296 | if self.scale_domain == "log":
297 | self._delta = torch.log(self._delta)
298 |
299 | self._delta = self._delta.detach()
300 |
301 | def make_range_trainable(self):
302 | # Converts trainable parameters to nn.Parameters
303 | if self.delta not in self.parameters():
304 | self._delta = torch.nn.Parameter(self._delta)
305 |
306 | def fix_ranges(self):
307 | # Removes trainable quantization params from nn.Parameters
308 | if self.delta in self.parameters():
309 | _delta = self._delta.data
310 | del self._delta # delete the parameter
311 | self.register_buffer("_delta", _delta)
312 |
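# Usage sketch (illustrative values; constructor defaults such as per-tensor
# quantization are assumed):
#
#   q = AsymmetricUniformQuantizer(n_bits=8)
#   q.set_quant_range(x_min=-0.6, x_max=1.2)
#   x = torch.randn(4, 16)
#   x_q = q(x)                       # fake-quantized (quantize-dequantize) tensor
#   x_int = q.to_integer_forward(x)  # integer representation on the [0, 255] grid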
--------------------------------------------------------------------------------
/transformers_language/args.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import argparse
4 |
5 | from transformers import MODEL_MAPPING, SchedulerType
6 |
7 | from transformers_language.dataset_setups import DatasetSetups
8 | from transformers_language.models.bert_attention import AttentionGateType
9 | from transformers_language.models.softmax import SOFTMAX_MAPPING
10 |
11 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
12 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(
17 |         description="Fine-tune/pre-train a transformers model on an MLM/CLM task"
18 | )
19 |
20 | # *** Options from example script ***
21 |
22 | #
23 | ## Base
24 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
25 |
26 | #
27 | ## Data
28 | parser.add_argument(
29 | "--preprocessing_num_workers",
30 | type=int,
31 | default=None,
32 | help="The number of processes to use for the preprocessing.",
33 | )
34 | parser.add_argument(
35 | "--overwrite_cache",
36 | action="store_true",
37 | help="Overwrite the cached training and evaluation sets",
38 | )
39 |
40 | #
41 | ## Model & tokenizer
42 | parser.add_argument(
43 | "--model_type",
44 | type=str,
45 | default=None,
46 | help="Model type to use if training from scratch.",
47 | choices=MODEL_TYPES,
48 | )
49 | parser.add_argument(
50 | "--model_name_or_path",
51 | type=str,
52 | help="Path to pretrained model or model identifier from huggingface.co/models.",
53 | required=False,
54 | )
55 | parser.add_argument(
56 | "--config_name",
57 | type=str,
58 | default=None,
59 | help="Pretrained config name or path if not the same as model_name",
60 | )
61 | parser.add_argument(
62 | "--tokenizer_name",
63 | type=str,
64 | default=None,
65 | help="Pretrained tokenizer name or path if not the same as model_name",
66 | )
67 | parser.add_argument(
68 | "--use_slow_tokenizer",
69 | action="store_true",
70 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
71 | )
72 | parser.add_argument(
73 | "--pad_to_max_length",
74 | action="store_true",
75 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
76 | )
77 | parser.add_argument(
78 | "--max_seq_length",
79 | type=int,
80 | default=None,
81 | help=(
82 | "The maximum total input sequence length after tokenization. Sequences longer than "
83 | "this will be truncated."
84 | ),
85 | )
86 | parser.add_argument(
87 | "--block_size",
88 | type=int,
89 | default=None,
90 | help=(
91 | "Optional input sequence length after tokenization. The training dataset will be truncated in block of"
92 | " this size for training. Default to the model max input length for single sentence inputs (take into"
93 | " account special tokens)."
94 | ),
95 | )
96 |
97 | #
98 | ## Task
99 | parser.add_argument(
100 | "--mlm_probability",
101 | type=float,
102 | default=0.15,
103 | help="Ratio of tokens to mask for masked language modeling loss",
104 | )
105 |
106 | #
107 | ## Training
108 | parser.add_argument(
109 | "--per_device_train_batch_size",
110 | type=int,
111 | default=8,
112 | help="Batch size (per device) for the training dataloader.",
113 | )
114 | parser.add_argument(
115 | "--per_device_eval_batch_size",
116 | type=int,
117 | default=8,
118 | help="Batch size (per device) for the evaluation dataloader.",
119 | )
120 | parser.add_argument(
121 | "--learning_rate",
122 | type=float,
123 | default=5e-5,
124 | help="Initial learning rate (after the potential warmup period) to use.",
125 | )
126 | parser.add_argument(
127 | "--lr_scheduler_type",
128 | type=SchedulerType,
129 | default="linear",
130 | help="The scheduler type to use.",
131 | choices=[
132 | "linear",
133 | "cosine",
134 | "cosine_with_restarts",
135 | "polynomial",
136 | "constant",
137 | "constant_with_warmup",
138 | ],
139 | )
140 | parser.add_argument(
141 | "--num_train_epochs",
142 | type=int,
143 | default=3,
144 | help="Total number of training epochs to perform.",
145 | )
146 | parser.add_argument(
147 | "--max_train_steps",
148 | type=int,
149 | default=None,
150 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
151 | )
152 | parser.add_argument(
153 | "--num_warmup_steps",
154 | type=int,
155 | default=0,
156 | help="Number of steps for the warmup in the lr scheduler.",
157 | )
158 | parser.add_argument(
159 | "--gradient_accumulation_steps",
160 | type=int,
161 | default=1,
162 | help="Number of updates steps to accumulate before performing a backward/update pass.",
163 | )
164 |
165 | #
166 | ## Regularization
167 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
168 |
169 | #
170 | ## Saving/loading & logging
171 | parser.add_argument(
172 | "--output_dir", type=str, default=None, help="Where to store the final model."
173 | )
174 | parser.add_argument(
175 | "--checkpointing_steps",
176 | type=str,
177 | default=None,
178 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' "
179 | "for each epoch.",
180 | )
181 | parser.add_argument(
182 | "--resume_from_checkpoint",
183 | type=str,
184 | default=None,
185 | help="If the training should continue from a checkpoint folder.",
186 | )
187 | parser.add_argument(
188 | "--with_tracking",
189 | action="store_true",
190 | help="Whether to enable experiment trackers for logging.",
191 | )
192 | parser.add_argument(
193 | "--extra_tb_stats",
194 | action="store_true",
195 | help="Whether to log extra scalars and histograms to TensorBoard.",
196 | )
197 | parser.add_argument(
198 | "--report_to",
199 | type=str,
200 | default="all",
201 | help=(
202 | "The integration to report the results and logs to. Supported platforms are "
203 | '`"tensorboard"`, `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to '
204 | "report to all integrations. Only applicable when `--with_tracking` is passed."
205 | ),
206 | )
207 | parser.add_argument(
208 | "--low_cpu_mem_usage",
209 | action="store_true",
210 | help=(
211 |             "If passed, create the model as an empty shell and only materialize its parameters "
212 |             "once the pretrained weights are loaded. This reduces LLM loading time and RAM consumption."
213 | ),
214 | )
215 |
216 | # *** New options ***
217 |
218 | #
219 | ## Data
220 | parser.add_argument(
221 | "--dataset_setup",
222 | choices=DatasetSetups.list_names(),
223 | default=DatasetSetups.wikitext_103.name,
224 |         help="The setup/preset of the datasets to use.",
225 | )
226 | parser.add_argument(
227 | "--data_cache_dir",
228 | type=str,
229 | default="/local/mnt/workspace/.hf_data",
230 | help="Where to store data.",
231 | )
232 | parser.add_argument(
233 | "--train_percentage",
234 | type=int,
235 | default=None,
236 | help="Percentage of training set to use.",
237 | )
238 | parser.add_argument(
239 | "--validation_percentage",
240 | type=int,
241 | default=None,
242 | help="Percentage of validation set to use.",
243 | )
244 |
245 | #
246 | ## Model & tokenizer
247 | parser.add_argument(
248 | "--config_path",
249 | type=str,
250 | default=None,
251 | help="Path to a yaml file with model config modifications.",
252 | )
253 | parser.add_argument(
254 | "--model_cache_dir",
255 | type=str,
256 | default="/local/mnt/workspace/.hf_cache",
257 | help="Where to store models & tokenizers.",
258 | )
259 |
260 | #
261 | ## Training
262 | parser.add_argument(
263 | "--final_lr_fraction",
264 | type=float,
265 | default=0.0,
266 | help="Final LR as a fraction of the maximum LR (only for CLM).",
267 | )
268 |
269 | #
270 | ## Logging
271 | parser.add_argument(
272 | "--tqdm_update_interval",
273 | type=int,
274 | default=100,
275 | help="How often to update the progress bar. "
276 |         "Note that using a small value might generate large log files.",
277 | )
278 |
279 | #
280 | ## Regularization
281 | parser.add_argument(
282 | "--max_grad_norm",
283 | type=float,
284 | default=None,
285 | help="Max gradient norm. If set to 0, no clipping will be applied.",
286 | )
287 | parser.add_argument(
288 | "--grad_norm_type",
289 | type=float,
290 | default=2.0,
291 | help="Norm type to use for gradient clipping.",
292 | )
293 | parser.add_argument(
294 | "--attn_dropout",
295 | type=float,
296 | default=None,
297 | help="Dropout rate to set for attention probs.",
298 | )
299 | parser.add_argument(
300 | "--hidden_dropout",
301 | type=float,
302 | default=None,
303 | help="Dropout rate to set for hidden states.",
304 | )
305 |
306 | #
307 | ## Logging
308 | parser.add_argument(
309 | "--tb_scalar_log_interval",
310 | type=int,
311 | default=1000,
312 | help="How often to log scalar stats of weights and activations to TensorBoard.",
313 | )
314 | parser.add_argument(
315 | "--tb_hist_log_interval",
316 | type=int,
317 | default=10000,
318 | help="How often to log histograms of weights and activations to TensorBoard.",
319 | )
320 |
321 | #
322 | ## Extra options
323 | parser.add_argument("--wd_LN_gamma", action="store_true")
324 |
325 | parser.add_argument(
326 | "--skip_attn",
327 | action="store_true",
328 | help="Skip attention (don't update the residual).",
329 | )
330 |
331 | parser.add_argument(
332 | "--attn_softmax",
333 | type=str,
334 | default="vanilla",
335 | help="Softmax variation to use in attention module.",
336 | choices=SOFTMAX_MAPPING.keys(),
337 | )
338 | parser.add_argument(
339 | "--alpha",
340 | type=float,
341 | default=None,
342 | help="If specified, use clipped softmax gamma = -alpha / seq_length.",
343 | )
344 | parser.add_argument(
345 | "--attn_gate_type",
346 | type=str,
347 | default=AttentionGateType.none.name,
348 | help="The type of gating to use for the self-attention.",
349 | choices=AttentionGateType.list_names(),
350 | )
351 | parser.add_argument(
352 | "--attn_gate_init",
353 | type=float,
354 | default=0.5,
355 |         help="Initialize the gate bias such that the initial gate probability is approximately this value.",
356 | )
357 | parser.add_argument(
358 | "--attn_gate_mlp",
359 | action="store_true",
360 | help="Use MLP instead of single linear layer to predict the gate.",
361 | )
362 | parser.add_argument(
363 | "--attn_gate_mlp2",
364 | action="store_true",
365 | help="Use bigger MLP instead of single linear layer to predict the gate.",
366 | )
367 | parser.add_argument(
368 | "--attn_gate_linear_all_features",
369 | action="store_true",
370 |         help="Use a single Linear (d_model -> n_heads) instead of n_heads separate Linear layers (d_head -> 1).",
371 | )
372 |
373 | #
374 | ## Quantization
375 | parser.add_argument("--quantize", action="store_true")
376 | parser.add_argument("--est_num_batches", type=int, default=1)
377 | parser.add_argument("--n_bits", type=int, default=8)
378 | parser.add_argument("--n_bits_act", type=int, default=8)
379 | parser.add_argument("--no_weight_quant", action="store_true")
380 | parser.add_argument("--no_act_quant", action="store_true")
381 | parser.add_argument("--qmethod_acts", type=str, default="asymmetric_uniform")
382 | parser.add_argument("--ranges_weights", type=str, default="minmax")
383 | parser.add_argument("--ranges_acts", type=str, default="running_minmax")
384 | parser.add_argument(
385 |         "--percentile", type=float, default=None, help="Percentile (in %%) for range estimation."
386 | )
387 | parser.add_argument("--quant_setup", type=str, default="all")
388 |
389 | # Fine-tuning
390 | parser.add_argument("--fine_tuning", action="store_true")
391 |
392 | # Parse options
393 | args = parser.parse_args()
394 |
395 | return args
396 |
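# Usage sketch (illustrative; the run_*/validate_* entry points are assumed to
# consume these options roughly like this):
#
#   from transformers_language.args import parse_args
#
#   args = parse_args()
#   if args.quantize:
#       # build a quantized model from args.n_bits, args.ranges_acts,
#       # args.qmethod_acts, args.quant_setup, ...
#       pass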
--------------------------------------------------------------------------------
/transformers_language/models/opt_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | from functools import partial
4 | from typing import Optional, Tuple
5 |
6 | import torch
7 | import torch.utils.checkpoint
8 | from torch import nn
9 |
10 | from transformers_language.models.bert_attention import AttentionGateType, logit
11 | from transformers_language.models.softmax import clipped_softmax
12 |
13 |
14 | class OPTAttentionWithExtras(nn.Module):
15 | """Multi-headed attention from 'Attention Is All You Need' paper"""
16 |
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | num_heads: int,
21 | dropout: float = 0.0,
22 | is_decoder: bool = False,
23 | bias: bool = True,
24 | ## new
25 | softmax_fn=torch.nn.functional.softmax,
26 | alpha=None,
27 | max_seq_length=None,
28 | ssm_eps=None,
29 | tau=None,
30 | skip_attn=False,
31 | attn_gate_type=AttentionGateType.none,
32 | attn_gate_init=None,
33 | attn_gate_mlp=False,
34 | attn_gate_mlp2=False,
35 | attn_gate_linear_all_features=False,
36 | fine_tuning=False,
37 | ):
38 | super().__init__()
39 | self.embed_dim = embed_dim
40 | self.num_heads = num_heads
41 | self.dropout = dropout
42 | self.head_dim = embed_dim // num_heads
43 |
44 | if (self.head_dim * num_heads) != self.embed_dim:
45 | raise ValueError(
46 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
47 | f" and `num_heads`: {num_heads})."
48 | )
49 | self.scaling = self.head_dim**-0.5
50 | self.is_decoder = is_decoder
51 |
52 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
53 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
54 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
55 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
56 |
57 | # YB: capture the input and output of the softmax
58 | self.attn_scores = nn.Identity() # before attention mask
59 | self.attn_probs_before_dropout = nn.Identity()
60 | self.attn_probs_after_dropout = nn.Identity()
61 |
62 | self.alpha = alpha
63 | self.max_seq_length = max_seq_length
64 | self.ssm_eps = ssm_eps
65 | self.tau = tau
66 |
67 | # define softmax function
68 | if self.alpha is not None:
69 | assert self.max_seq_length is not None
70 | gamma = -self.alpha / self.max_seq_length
71 | self.softmax_fn = partial(clipped_softmax, gamma=gamma, eta=1.0)
72 | else:
73 | self.softmax_fn = softmax_fn
74 |
75 | self.skip_attn = skip_attn
76 |
77 | # attention gating
78 | self.last_gate_avg_prob = None
79 | self.last_gate_all_probs = None
80 |
81 | self.attn_gate_type = attn_gate_type
82 | self.attn_gate_init = attn_gate_init
83 | self.attn_gate_mlp = attn_gate_mlp
84 | self.attn_gate_mlp2 = attn_gate_mlp2
85 | self.attn_gate_linear_all_features = attn_gate_linear_all_features
86 |
87 | self.alpha = None
88 | self.ssm_eps = ssm_eps
89 | self.gate_fn = torch.sigmoid
90 | self.pooling_fn = partial(torch.mean, dim=1, keepdims=True)
91 |
92 | self.fine_tuning = fine_tuning
93 |
94 | # gate scaling factor
95 | self.gate_scaling_factor = 1.0
96 | if self.fine_tuning and self.attn_gate_init is not None:
97 | self.gate_scaling_factor = 1.0 / self.attn_gate_init
98 |
99 | # define gate
100 | if self.attn_gate_type == AttentionGateType.unconditional_per_head:
101 | init_alpha = torch.zeros(size=(self.num_heads,))
102 | self.alpha = nn.Parameter(init_alpha, requires_grad=True)
103 |
104 | elif self.attn_gate_type in (
105 | AttentionGateType.conditional_per_head,
106 | AttentionGateType.conditional_per_token,
107 | ):
108 | if self.attn_gate_linear_all_features:
109 | self.alpha = nn.Linear(self.embed_dim, self.num_heads, bias=True)
110 |
111 | else: # separate predictors for each head
112 | module_list = []
113 | for _ in range(self.num_heads):
114 | if self.attn_gate_mlp:
115 | fc = nn.Sequential(
116 | nn.Linear(self.head_dim, self.head_dim // 4, bias=True),
117 | nn.ReLU(),
118 | nn.Linear(self.head_dim // 4, 1, bias=True),
119 | )
120 | elif self.attn_gate_mlp2:
121 | fc = nn.Sequential(
122 | nn.Linear(self.head_dim, self.head_dim, bias=True),
123 | nn.ReLU(),
124 | nn.Linear(self.head_dim, 1, bias=True),
125 | )
126 | else:
127 | fc = nn.Linear(self.head_dim, 1, bias=True)
128 |
129 | if self.attn_gate_init is not None:
130 | init_bias = logit(self.attn_gate_init)
131 | torch.nn.init.constant_(fc.bias, init_bias)
132 |
133 | if self.fine_tuning:
134 |                     # initialize the gate weights to very small values
135 | torch.nn.init.normal_(fc.weight, mean=0.0, std=0.001)
136 |
137 | module_list.append(fc)
138 | self.alpha = nn.ModuleList(module_list)
139 |
140 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
141 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
142 |
143 | def forward(
144 | self,
145 | hidden_states: torch.Tensor,
146 | key_value_states: Optional[torch.Tensor] = None,
147 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
148 | attention_mask: Optional[torch.Tensor] = None,
149 | layer_head_mask: Optional[torch.Tensor] = None,
150 | output_attentions: bool = False,
151 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
152 | """Input shape: Batch x Time x Channel"""
153 |
154 | # if key_value_states are provided this layer is used as a cross-attention layer
155 | # for the decoder
156 | is_cross_attention = key_value_states is not None
157 |
158 | bsz, tgt_len, _ = hidden_states.size()
159 |
160 | # get query proj
161 | query_states = self.q_proj(hidden_states) * self.scaling
162 | # get key, value proj
163 | if is_cross_attention and past_key_value is not None:
164 | # reuse k,v, cross_attentions
165 | key_states = past_key_value[0]
166 | value_states = past_key_value[1]
167 | elif is_cross_attention:
168 | # cross_attentions
169 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
170 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
171 | elif past_key_value is not None:
172 | # reuse k, v, self_attention
173 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
174 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
175 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
176 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
177 | else:
178 | # self_attention
179 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
180 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
181 |
182 | if self.is_decoder:
183 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
184 | # Further calls to cross_attention layer can then reuse all cross-attention
185 | # key/value_states (first "if" case)
186 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
187 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
188 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
189 | # if encoder bi-directional self-attention `past_key_value` is always `None`
190 | past_key_value = (key_states, value_states)
191 |
192 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
193 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
194 | key_states = key_states.view(*proj_shape)
195 | value_states = value_states.view(*proj_shape)
196 |
197 | src_len = key_states.size(1)
198 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
199 |
200 | # YB: for logging softmax input
201 | attn_weights = self.attn_scores(attn_weights)
202 |
203 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
204 | raise ValueError(
205 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
206 | f" {attn_weights.size()}"
207 | )
208 |
209 | if attention_mask is not None:
210 | if attention_mask.size() != (bsz, 1, tgt_len, src_len):
211 | raise ValueError(
212 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
213 | )
214 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
215 | attn_weights = torch.max(
216 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
217 | )
218 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
219 |
220 | # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
221 | if attn_weights.dtype == torch.float16:
222 | attn_weights = self.softmax_fn(attn_weights, dim=-1, dtype=torch.float32).to(
223 | torch.float16
224 | )
225 | else:
226 | attn_weights = self.softmax_fn(attn_weights, dim=-1)
227 |
228 | if layer_head_mask is not None:
229 | if layer_head_mask.size() != (self.num_heads,):
230 | raise ValueError(
231 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
232 | f" {layer_head_mask.size()}"
233 | )
234 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
235 | bsz, self.num_heads, tgt_len, src_len
236 | )
237 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
238 |
239 | if output_attentions:
240 | # this operation is a bit awkward, but it's required to
241 | # make sure that attn_weights keeps its gradient.
242 | # In order to do so, attn_weights have to be reshaped
243 | # twice and have to be reused in the following
244 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
245 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
246 | else:
247 | attn_weights_reshaped = None
248 |
249 | # YB: for logging softmax output
250 | attn_weights = self.attn_probs_before_dropout(attn_weights)
251 |
252 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
253 |
254 | # YB: for logging softmax output
255 | attn_probs = self.attn_probs_after_dropout(attn_probs)
256 |
257 | attn_output = torch.bmm(attn_probs, value_states)
258 |
259 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
260 | raise ValueError(
261 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
262 | f" {attn_output.size()}"
263 | )
264 |
265 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
266 | # attn_output - (B, H, T, d_head)
267 |
268 | #
269 | # *** Gating ***
270 | if self.attn_gate_type == AttentionGateType.unconditional_per_head:
271 | gate = self.gate_fn(self.alpha) # (H,)
272 | attn_output *= gate.view(-1, 1, 1) # (B, H, T, d_head)
273 |
274 | self.last_gate_avg_prob = gate.view(-1)
275 |
276 | elif self.attn_gate_type in (
277 | AttentionGateType.conditional_per_head,
278 | AttentionGateType.conditional_per_token,
279 | ):
280 | x = hidden_states # (B, T, d_model)
281 |
282 | if self.attn_gate_linear_all_features: # assume per_token
283 | alpha = self.alpha(x) # (B, T, H)
284 | gate = self.gate_fn(alpha)
285 | gate = gate.permute(0, 2, 1).contiguous() # (B, H, T)
286 | gate = gate.unsqueeze(3) # (B, H, T, 1)
287 |
288 | else:
289 | # x = self.transpose_for_scores(x) # (B, H, T, d_head)
290 | x = self._shape(x, -1, bsz) # (B, H, T, d_head)
291 |
292 | alpha = []
293 | for head_idx in range(self.num_heads):
294 | x_head = x[:, head_idx, ...] # (B, T, d_head)
295 | fc_head = self.alpha[head_idx]
296 | alpha_head = fc_head(x_head) # (B, T, 1)
297 | if self.attn_gate_type == AttentionGateType.conditional_per_head:
298 | alpha_head = self.pooling_fn(alpha_head) # (B, 1, 1)
299 | alpha.append(alpha_head)
300 | alpha = torch.stack(alpha, dim=1) # (B, H, *, 1)
301 | gate = self.gate_fn(alpha)
302 |
303 | attn_output *= gate * self.gate_scaling_factor
304 |
305 | self.last_gate_all_probs = gate # all gates to see the distributions
306 | avg_gate = gate.mean(dim=0)
307 | self.last_gate_avg_prob = avg_gate.view(self.num_heads, -1).mean(dim=1)
308 |
309 | #
310 | ##
311 |
312 | attn_output = attn_output.transpose(1, 2)
313 |
314 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
315 |         # partitioned across GPUs when using tensor-parallelism.
316 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
317 |
318 | attn_output = self.out_proj(attn_output)
319 |
320 | return attn_output, attn_weights_reshaped, past_key_value
321 |
--------------------------------------------------------------------------------
/transformers_language/models/bert_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import math
4 | from functools import partial
5 | from typing import Optional, Tuple
6 |
7 | import numpy as np
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 |
12 | from quantization.utils import BaseEnumOptions
13 | from transformers_language.models.softmax import clipped_softmax
14 |
15 |
16 | def logit(p, eps=1e-16):
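    # Numerically-stable inverse sigmoid: sigmoid(logit(p)) == p (up to the eps clipping).
    # Used to initialize the gate bias so that the initial gate probability is
    # approximately `attn_gate_init`.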
17 | p = np.clip(p, eps, 1 - eps)
18 | return -np.log(1 / p - 1)
19 |
20 |
21 | class AttentionGateType(BaseEnumOptions):
22 | none = 0
23 | unconditional_per_head = 1
24 | conditional_per_head = 2
25 | conditional_per_token = 3
26 |
27 |
28 | class BertSelfAttentionWithExtras(nn.Module):
29 | def __init__(
30 | self,
31 | config,
32 | position_embedding_type=None,
33 | softmax_fn=torch.nn.functional.softmax,
34 | alpha=None,
35 | ssm_eps=None,
36 | tau=None,
37 | max_seq_length=None,
38 | skip_attn=False,
39 | attn_gate_type=AttentionGateType.none,
40 | attn_gate_init=None,
41 | attn_gate_mlp=False,
42 | attn_gate_mlp2=False,
43 | attn_gate_linear_all_features=False,
44 | fine_tuning=False,
45 | ):
46 | super().__init__()
47 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
48 | config, "embedding_size"
49 | ):
50 | raise ValueError(
51 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
52 | f"heads ({config.num_attention_heads})"
53 | )
54 |
55 | self.num_attention_heads = config.num_attention_heads
56 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
57 | self.all_head_size = self.num_attention_heads * self.attention_head_size
58 |
59 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
60 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
61 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
62 |
63 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
64 | self.position_embedding_type = position_embedding_type or getattr(
65 | config, "position_embedding_type", "absolute"
66 | )
67 | if (
68 | self.position_embedding_type == "relative_key"
69 | or self.position_embedding_type == "relative_key_query"
70 | ):
71 | self.max_position_embeddings = config.max_position_embeddings
72 | self.distance_embedding = nn.Embedding(
73 | 2 * config.max_position_embeddings - 1, self.attention_head_size
74 | )
75 |
76 | self.is_decoder = config.is_decoder
77 |
78 | # YB: capture the input and output of the softmax
79 | self.attn_scores = nn.Identity() # before attention mask
80 | self.attn_probs_before_dropout = nn.Identity()
81 | self.attn_probs_after_dropout = nn.Identity()
82 |
83 | self.alpha = alpha
84 | self.ssm_eps = ssm_eps
85 | self.tau = tau
86 | self.max_seq_length = max_seq_length
87 |
88 | # define softmax function
89 | if self.alpha is not None:
90 | assert self.max_seq_length is not None
91 | gamma = -self.alpha / self.max_seq_length
92 | self.softmax_fn = partial(clipped_softmax, gamma=gamma, eta=1.0)
93 | else:
94 | self.softmax_fn = softmax_fn
95 |
96 | self.skip_attn = skip_attn
97 |
98 | # attention gating
99 | self.last_gate_avg_prob = None
100 | self.last_gate_all_probs = None
101 |
102 | self.attn_gate_type = attn_gate_type
103 | self.attn_gate_init = attn_gate_init
104 | self.attn_gate_mlp = attn_gate_mlp
105 | self.attn_gate_mlp2 = attn_gate_mlp2
106 | self.attn_gate_linear_all_features = attn_gate_linear_all_features
107 |
108 | self.alpha = None
109 | self.gate_fn = torch.sigmoid
110 | self.pooling_fn = partial(torch.mean, dim=1, keepdims=True)
111 |
112 | self.fine_tuning = fine_tuning
113 |
114 | # gate scaling factor
115 | self.gate_scaling_factor = 1.0
116 | if self.fine_tuning and self.attn_gate_init is not None:
117 | self.gate_scaling_factor = 1.0 / self.attn_gate_init
118 |
119 | # define gate
120 | if self.attn_gate_type == AttentionGateType.unconditional_per_head:
121 | init_alpha = torch.zeros(size=(self.num_attention_heads,))
122 | self.alpha = nn.Parameter(init_alpha, requires_grad=True)
123 |
124 | elif self.attn_gate_type in (
125 | AttentionGateType.conditional_per_head,
126 | AttentionGateType.conditional_per_token,
127 | ):
128 | if self.attn_gate_linear_all_features:
129 | self.alpha = nn.Linear(self.all_head_size, self.num_attention_heads, bias=True)
130 |
131 | else: # separate predictors for each head
132 | module_list = []
133 | for _ in range(self.num_attention_heads):
134 | if self.attn_gate_mlp:
135 | fc = nn.Sequential(
136 | nn.Linear(
137 | self.attention_head_size, self.attention_head_size // 4, bias=True
138 | ),
139 | nn.ReLU(),
140 | nn.Linear(self.attention_head_size // 4, 1, bias=True),
141 | )
142 | elif self.attn_gate_mlp2:
143 | fc = nn.Sequential(
144 | nn.Linear(
145 | self.attention_head_size, self.attention_head_size, bias=True
146 | ),
147 | nn.ReLU(),
148 | nn.Linear(self.attention_head_size, 1, bias=True),
149 | )
150 | else:
151 | fc = nn.Linear(self.attention_head_size, 1, bias=True)
152 |
153 | if self.attn_gate_init is not None:
154 | init_bias = logit(self.attn_gate_init)
155 | torch.nn.init.constant_(fc.bias, init_bias)
156 |
157 | if self.fine_tuning:
158 |                         # init to very small values
159 | torch.nn.init.normal_(fc.weight, mean=0.0, std=0.01)
160 |
161 | module_list.append(fc)
162 | self.alpha = nn.ModuleList(module_list)
163 |
164 | def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
165 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
166 | x = x.view(new_x_shape)
167 | return x.permute(0, 2, 1, 3)
168 |
169 | def forward(
170 | self,
171 | hidden_states: torch.Tensor,
172 | attention_mask: Optional[torch.FloatTensor] = None,
173 | head_mask: Optional[torch.FloatTensor] = None,
174 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
175 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
176 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
177 | output_attentions: Optional[bool] = False,
178 | ) -> Tuple[torch.Tensor]:
179 | if self.skip_attn:
180 | out = torch.zeros_like(hidden_states)
181 | return (out,)
182 |
183 | mixed_query_layer = self.query(hidden_states)
184 |
185 | # If this is instantiated as a cross-attention module, the keys
186 | # and values come from an encoder; the attention mask needs to be
187 | # such that the encoder's padding tokens are not attended to.
188 | is_cross_attention = encoder_hidden_states is not None
189 |
190 | if is_cross_attention and past_key_value is not None:
191 | # reuse k,v, cross_attentions
192 | key_layer = past_key_value[0]
193 | value_layer = past_key_value[1]
194 | attention_mask = encoder_attention_mask
195 | elif is_cross_attention:
196 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
197 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
198 | attention_mask = encoder_attention_mask
199 | elif past_key_value is not None:
200 | key_layer = self.transpose_for_scores(self.key(hidden_states))
201 | value_layer = self.transpose_for_scores(self.value(hidden_states))
202 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
203 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
204 | else:
205 | key_layer = self.transpose_for_scores(self.key(hidden_states))
206 | value_layer = self.transpose_for_scores(self.value(hidden_states))
207 |
208 | query_layer = self.transpose_for_scores(mixed_query_layer)
209 |
210 | use_cache = past_key_value is not None
211 | if self.is_decoder:
212 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
213 | # Further calls to cross_attention layer can then reuse all cross-attention
214 | # key/value_states (first "if" case)
215 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
216 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
217 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
218 | # if encoder bi-directional self-attention `past_key_value` is always `None`
219 | past_key_value = (key_layer, value_layer)
220 |
221 | # Take the dot product between "query" and "key" to get the raw attention scores.
222 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
223 |
224 | if (
225 | self.position_embedding_type == "relative_key"
226 | or self.position_embedding_type == "relative_key_query"
227 | ):
228 | query_length, key_length = query_layer.shape[2], key_layer.shape[2]
229 | if use_cache:
230 | position_ids_l = torch.tensor(
231 | key_length - 1, dtype=torch.long, device=hidden_states.device
232 | ).view(-1, 1)
233 | else:
234 | position_ids_l = torch.arange(
235 | query_length, dtype=torch.long, device=hidden_states.device
236 | ).view(-1, 1)
237 | position_ids_r = torch.arange(
238 | key_length, dtype=torch.long, device=hidden_states.device
239 | ).view(1, -1)
240 | distance = position_ids_l - position_ids_r
241 |
242 | positional_embedding = self.distance_embedding(
243 | distance + self.max_position_embeddings - 1
244 | )
245 | positional_embedding = positional_embedding.to(
246 | dtype=query_layer.dtype
247 | ) # fp16 compatibility
248 |
249 | if self.position_embedding_type == "relative_key":
250 | relative_position_scores = torch.einsum(
251 | "bhld,lrd->bhlr", query_layer, positional_embedding
252 | )
253 | attention_scores = attention_scores + relative_position_scores
254 | elif self.position_embedding_type == "relative_key_query":
255 | relative_position_scores_query = torch.einsum(
256 | "bhld,lrd->bhlr", query_layer, positional_embedding
257 | )
258 | relative_position_scores_key = torch.einsum(
259 | "bhrd,lrd->bhlr", key_layer, positional_embedding
260 | )
261 | attention_scores = (
262 | attention_scores + relative_position_scores_query + relative_position_scores_key
263 | )
264 |
265 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
266 |
267 | # YB: for logging softmax input
268 | attention_scores = self.attn_scores(attention_scores)
269 |
270 | if attention_mask is not None:
271 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
272 | attention_scores = attention_scores + attention_mask
273 |
274 | # Normalize the attention scores to probabilities.
275 | # MN: uses our own SM function as specified in the config
276 | attention_probs = self.softmax_fn(attention_scores, dim=-1)
277 |
278 | # YB: for logging softmax output
279 | attention_probs = self.attn_probs_before_dropout(attention_probs)
280 |
281 | # This is actually dropping out entire tokens to attend to, which might
282 | # seem a bit unusual, but is taken from the original Transformer paper.
283 | attention_probs = self.dropout(attention_probs)
284 |
285 | # YB: for logging softmax output
286 | attention_probs = self.attn_probs_after_dropout(attention_probs)
287 |
288 | # Mask heads if we want to
289 | if head_mask is not None:
290 | attention_probs = attention_probs * head_mask
291 |
292 | context_layer = torch.matmul(attention_probs, value_layer)
293 |
294 | # *** Gating ***
295 | if self.attn_gate_type == AttentionGateType.unconditional_per_head:
296 | gate = self.gate_fn(self.alpha) # (H,)
297 | context_layer *= gate.view(-1, 1, 1) # (B, H, T, d_head)
298 |
299 | self.last_gate_avg_prob = gate.view(-1)
300 |
301 | elif self.attn_gate_type in (
302 | AttentionGateType.conditional_per_head,
303 | AttentionGateType.conditional_per_token,
304 | ):
305 | x = hidden_states # (B, T, d_model)
306 |
307 | if self.attn_gate_linear_all_features: # assume per_token
308 | alpha = self.alpha(x) # (B, T, H)
309 | gate = self.gate_fn(alpha)
310 | gate = gate.permute(0, 2, 1).contiguous() # (B, H, T)
311 | gate = gate.unsqueeze(3) # (B, H, T, 1)
312 |
313 | else:
314 | x = self.transpose_for_scores(x) # (B, H, T, d_head)
315 |
316 | alpha = []
317 | for head_idx in range(self.num_attention_heads):
318 | x_head = x[:, head_idx, ...] # (B, T, d_head)
319 | fc_head = self.alpha[head_idx]
320 | alpha_head = fc_head(x_head) # (B, T, 1)
321 | if self.attn_gate_type == AttentionGateType.conditional_per_head:
322 | alpha_head = self.pooling_fn(alpha_head) # (B, 1, 1)
323 | alpha.append(alpha_head)
324 | alpha = torch.stack(alpha, dim=1) # (B, H, *, 1)
325 | gate = self.gate_fn(alpha)
326 |
327 | context_layer *= gate * self.gate_scaling_factor
328 |
329 | self.last_gate_all_probs = gate # all gates to see the distributions
330 | avg_gate = gate.mean(dim=0)
331 | self.last_gate_avg_prob = avg_gate.view(self.num_attention_heads, -1).mean(dim=1)
332 |
333 | #
334 |
335 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
336 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
337 | context_layer = context_layer.view(new_context_layer_shape)
338 |
339 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
340 |
341 | if self.is_decoder:
342 | outputs = outputs + (past_key_value,)
343 | return outputs
344 |
--------------------------------------------------------------------------------
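Note on the softmax wiring above: when `alpha` is given, the constructor binds `softmax_fn` to `clipped_softmax` with `gamma = -alpha / max_seq_length` and `eta = 1.0`, and the forward pass invokes it as `self.softmax_fn(attention_scores, dim=-1)`. The actual implementation lives in `transformers_language/models/softmax.py` (not shown in this excerpt); the sketch below is only an assumed stretch-and-clip variant with a compatible signature, included to illustrate why attention probabilities can reach exactly zero.

import torch
import torch.nn.functional as F

# Hedged sketch, not the repository's implementation: a (gamma, eta)
# stretch-and-clip softmax compatible with partial(fn, gamma=..., eta=...)
# and with the call fn(attention_scores, dim=-1) used in the module above.
def clipped_softmax_sketch(x, dim=-1, gamma=-0.03, eta=1.0):
    p = F.softmax(x, dim=dim)
    # Stretch outputs from (0, 1) to (gamma, eta), then clip back to [0, 1],
    # so probabilities can saturate at exactly 0 (and at 1 if eta > 1).
    return torch.clamp((eta - gamma) * p + gamma, min=0.0, max=1.0)

# Example: alpha = 12 and max_seq_length = 512 give gamma = -12 / 512.
scores = torch.randn(1, 2, 4, 4)  # (B, H, T_q, T_k)
probs = clipped_softmax_sketch(scores, dim=-1, gamma=-12 / 512)
assert 0.0 <= probs.min() and probs.max() <= 1.0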
/quantization/range_estimators.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
2 | # All Rights Reserved.
3 | import copy
4 | from enum import auto
5 |
6 | import numpy as np
7 | import torch
8 | from scipy.optimize import minimize_scalar
9 | from torch import nn
10 |
11 | from quantization.utils import BaseEnumOptions, ClassEnumOptions, MethodMap, to_numpy
12 |
13 |
14 | class RangeEstimatorBase(nn.Module):
15 | def __init__(self, *args, per_channel=False, quantizer=None, **kwargs):
16 | super().__init__(*args, **kwargs)
17 | self.register_buffer("current_xmin", None)
18 | self.register_buffer("current_xmax", None)
19 | self.per_channel = per_channel
20 | self.quantizer = quantizer
21 |
22 | def forward(self, x):
23 | """
24 | Accepts an input tensor, updates the current estimates of x_min and x_max
25 |         and returns them.
26 | Parameters
27 | ----------
28 | x: Input tensor
29 |
30 | Returns
31 | -------
32 | self.current_xmin: tensor
33 |
34 | self.current_xmax: tensor
35 |
36 | """
37 | raise NotImplementedError()
38 |
39 | def reset(self):
40 | """
41 | Reset the range estimator.
42 | """
43 | self.current_xmin = None
44 | self.current_xmax = None
45 |
46 | def __repr__(self):
47 |         # We override this from nn.Module as we do not want to have submodules such as
48 |         # self.quantizer in the repr. Otherwise it behaves as expected for an nn.Module.
49 | lines = self.extra_repr().split("\n")
50 | extra_str = lines[0] if len(lines) == 1 else "\n " + "\n ".join(lines) + "\n"
51 |
52 | return self._get_name() + "(" + extra_str + ")"
53 |
54 |
55 | class CurrentMinMaxEstimator(RangeEstimatorBase):
56 | def __init__(self, *args, percentile=None, **kwargs):
57 | self.percentile = percentile
58 | super().__init__(*args, **kwargs)
59 |
60 | def forward(self, x):
61 | if self.per_channel:
62 | x = x.view(x.shape[0], -1)
63 | if self.percentile:
64 | axis = -1 if self.per_channel else None
65 | data_np = to_numpy(x)
66 | x_min, x_max = np.percentile(
67 | data_np, (self.percentile, 100 - self.percentile), axis=axis
68 | )
69 | self.current_xmin = torch.tensor(x_min).to(x.device)
70 | self.current_xmax = torch.tensor(x_max).to(x.device)
71 | else:
72 | self.current_xmin = x.min(-1)[0].detach() if self.per_channel else x.min().detach()
73 | self.current_xmax = x.max(-1)[0].detach() if self.per_channel else x.max().detach()
74 |
75 | return self.current_xmin, self.current_xmax
76 |
77 |
78 | class RunningMinMaxEstimator(RangeEstimatorBase):
79 | def __init__(self, *args, momentum=0.9, percentile=None, **kwargs):
80 | self.momentum = momentum
81 | self.percentile = percentile
82 | super().__init__(*args, **kwargs)
83 |
84 | def forward(self, x):
85 | if self.per_channel:
86 | # Along 1st dim
87 | x_flattened = x.view(x.shape[0], -1)
88 | x_min = x_flattened.min(-1)[0].detach()
89 | x_max = x_flattened.max(-1)[0].detach()
90 | elif self.percentile:
91 | data_np = to_numpy(x)
92 | x_min, x_max = np.percentile(data_np, (100 - self.percentile, self.percentile))
93 |
94 | x_min = torch.tensor(x_min).to(x.device)
95 | x_max = torch.tensor(x_max).to(x.device)
96 | else:
97 | x_min = torch.min(x).detach()
98 | x_max = torch.max(x).detach()
99 |
100 | if self.current_xmin is None:
101 | self.current_xmin = x_min
102 | self.current_xmax = x_max
103 | else:
104 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin
105 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax
106 |
107 | return self.current_xmin, self.current_xmax
108 |
109 |
110 | class OptMethod(BaseEnumOptions):
111 | grid = auto()
112 | golden_section = auto()
113 |
114 |
115 | class MSE_Estimator(RangeEstimatorBase):
116 | def __init__(
117 | self, *args, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, **kwargs
118 | ):
119 | super().__init__(*args, **kwargs)
120 | assert opt_method in OptMethod
121 | self.opt_method = opt_method
122 | self.num_candidates = num_candidates
123 | self.loss_array = None
124 | self.max_pos_thr = None
125 | self.max_neg_thr = None
126 | self.max_search_range = None
127 | self.one_sided_dist = None
128 | self.range_margin = range_margin
129 | if self.quantizer is None:
130 | raise NotImplementedError(
131 |                 "A Quantizer must be given as an argument to the MSE Range Estimator"
132 | )
133 | self.max_int_skew = (2**self.quantizer.n_bits) // 4 # For asymmetric quantization
134 |
135 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False):
136 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr)
137 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1)
138 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel
139 | # grid search
140 | if per_channel_loss:
141 | return to_numpy(temp_sum)
142 | else:
143 | return to_numpy(torch.sum(temp_sum))
144 |
145 | @property
146 | def step_size(self):
147 | if self.one_sided_dist is None:
148 | raise NoDataPassedError()
149 |
150 | return self.max_search_range / self.num_candidates
151 |
152 | @property
153 | def optimization_method(self):
154 | if self.one_sided_dist is None:
155 | raise NoDataPassedError()
156 |
157 | if self.opt_method == OptMethod.grid:
158 | # Grid search method
159 | if self.one_sided_dist or self.quantizer.symmetric:
160 | # 1-D grid search
161 | return self._perform_1D_search
162 | else:
163 | # 2-D grid_search
164 | return self._perform_2D_search
165 | elif self.opt_method == OptMethod.golden_section:
166 | # Golden section method
167 | if self.one_sided_dist or self.quantizer.symmetric:
168 | return self._golden_section_symmetric
169 | else:
170 | return self._golden_section_asymmetric
171 | else:
172 | raise NotImplementedError("Optimization Method not Implemented")
173 |
174 | def quantize(self, x_float, x_min=None, x_max=None):
175 | temp_q = copy.deepcopy(self.quantizer)
176 | # In the current implementation no optimization procedure requires temp quantizer for
177 | # loss_fx to be per-channel
178 | temp_q.per_channel = False
179 | if x_min or x_max:
180 | temp_q.set_quant_range(x_min, x_max)
181 | return temp_q(x_float)
182 |
183 | def golden_sym_loss(self, range, data):
184 | """
185 | Loss function passed to the golden section optimizer from scipy in case of symmetric
186 | quantization
187 | """
188 | neg_thr = 0 if self.one_sided_dist else -range
189 | pos_thr = range
190 | return self.loss_fx(data, neg_thr, pos_thr)
191 |
192 | def golden_asym_shift_loss(self, shift, range, data):
193 | """
194 | Inner Loss function (shift) passed to the golden section optimizer from scipy
195 | in case of asymmetric quantization
196 | """
197 | pos_thr = range + shift
198 | neg_thr = -range + shift
199 | return self.loss_fx(data, neg_thr, pos_thr)
200 |
201 | def golden_asym_range_loss(self, range, data):
202 | """
203 | Outer Loss function (range) passed to the golden section optimizer from scipy in case of
204 | asymmetric quantization
205 | """
206 | temp_delta = 2 * range / (2**self.quantizer.n_bits - 1)
207 | max_shift = temp_delta * self.max_int_skew
208 | result = minimize_scalar(
209 | self.golden_asym_shift_loss,
210 | args=(range, data),
211 | bounds=(-max_shift, max_shift),
212 | method="Bounded",
213 | )
214 | return result.fun
215 |
216 | def _define_search_range(self, data):
217 | self.channel_groups = len(data) if self.per_channel else 1
218 | self.current_xmax = torch.zeros(self.channel_groups, device=data.device)
219 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device)
220 |
221 | if self.one_sided_dist or self.quantizer.symmetric:
222 | # 1D search space
223 | self.loss_array = np.zeros(
224 | (self.channel_groups, self.num_candidates + 1)
225 | ) # 1D search space
226 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish
227 | # Defining the search range for clipping thresholds
228 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + self.range_margin
229 | self.max_neg_thr = -self.max_pos_thr
230 | self.max_search_range = self.max_pos_thr
231 | else:
232 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth
233 | # index represents whether the skew is positive (0) or negative (1))
234 | self.loss_array = np.zeros(
235 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2]
236 | ) # 2D search space
237 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish
238 | # Define the search range for clipping thresholds in asymmetric case
239 | self.max_pos_thr = float(data.max()) + self.range_margin
240 | self.max_neg_thr = float(data.min()) - self.range_margin
241 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr))
242 |
243 | def _perform_1D_search(self, data):
244 | """
245 | Grid search through all candidate quantizers in 1D to find the best
246 | The loss is accumulated over all batches without any momentum
247 | :param data: input tensor
248 | """
249 | for cand_index in range(1, self.num_candidates + 1):
250 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index
251 | pos_thr = self.step_size * cand_index
252 |
253 | self.loss_array[:, cand_index] += self.loss_fx(
254 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel
255 | )
256 | # find the best clipping thresholds
257 | min_cand = self.loss_array.argmin(axis=1)
258 | xmin = (
259 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand
260 | ).astype(np.single)
261 | xmax = (self.step_size * min_cand).astype(np.single)
262 | self.current_xmax = torch.tensor(xmax).to(device=data.device)
263 | self.current_xmin = torch.tensor(xmin).to(device=data.device)
264 |
265 | def _perform_2D_search(self, data):
266 | """
267 |         Grid search through all candidate quantizers in 2D to find the best
268 | The loss is accumulated over all batches without any momentum
269 | Parameters
270 | ----------
271 | data: PyTorch Tensor
272 | Returns
273 | -------
274 |
275 | """
276 | for cand_index in range(1, self.num_candidates + 1):
277 | # defining the symmetric quantization range
278 | temp_start = -self.step_size * cand_index
279 | temp_finish = self.step_size * cand_index
280 | temp_delta = float(temp_finish - temp_start) / (2**self.quantizer.n_bits - 1)
281 | for shift in range(self.max_int_skew):
282 | for reverse in range(2):
283 | # introducing asymmetry in the quantization range
284 | skew = ((-1) ** reverse) * shift * temp_delta
285 | neg_thr = max(temp_start + skew, self.max_neg_thr)
286 | pos_thr = min(temp_finish + skew, self.max_pos_thr)
287 |
288 | self.loss_array[:, cand_index, shift, reverse] += self.loss_fx(
289 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel
290 | )
291 |
292 | for channel_index in range(self.channel_groups):
293 | min_cand, min_shift, min_reverse = np.unravel_index(
294 | np.argmin(self.loss_array[channel_index], axis=None),
295 | self.loss_array[channel_index].shape,
296 | )
297 | min_interval_start = -self.step_size * min_cand
298 | min_interval_finish = self.step_size * min_cand
299 | min_delta = float(min_interval_finish - min_interval_start) / (
300 | 2**self.quantizer.n_bits - 1
301 | )
302 | min_skew = ((-1) ** min_reverse) * min_shift * min_delta
303 | xmin = max(min_interval_start + min_skew, self.max_neg_thr)
304 | xmax = min(min_interval_finish + min_skew, self.max_pos_thr)
305 |
306 | self.current_xmin[channel_index] = torch.tensor(xmin).to(device=data.device)
307 | self.current_xmax[channel_index] = torch.tensor(xmax).to(device=data.device)
308 |
309 | def _golden_section_symmetric(self, data):
310 | for channel_index in range(self.channel_groups):
311 | if channel_index == 0 and not self.per_channel:
312 | data_segment = data
313 | else:
314 | data_segment = data[channel_index]
315 |
316 | self.result = minimize_scalar(
317 | self.golden_sym_loss,
318 | args=data_segment,
319 | bounds=(0.01 * self.max_search_range, self.max_search_range),
320 | method="Bounded",
321 | )
322 | self.current_xmax[channel_index] = torch.tensor(self.result.x).to(device=data.device)
323 | self.current_xmin[channel_index] = (
324 | torch.tensor(0.0).to(device=data.device)
325 | if self.one_sided_dist
326 | else -self.current_xmax[channel_index]
327 | )
328 |
329 | def _golden_section_asymmetric(self, data):
330 | for channel_index in range(self.channel_groups):
331 | if channel_index == 0 and not self.per_channel:
332 | data_segment = data
333 | else:
334 | data_segment = data[channel_index]
335 |
336 | self.result = minimize_scalar(
337 | self.golden_asym_range_loss,
338 | args=data_segment,
339 | bounds=(0.01 * self.max_search_range, self.max_search_range),
340 | method="Bounded",
341 | )
342 | self.final_range = self.result.x
343 | temp_delta = 2 * self.final_range / (2**self.quantizer.n_bits - 1)
344 | max_shift = temp_delta * self.max_int_skew
345 | self.subresult = minimize_scalar(
346 | self.golden_asym_shift_loss,
347 | args=(self.final_range, data_segment),
348 | bounds=(-max_shift, max_shift),
349 | method="Bounded",
350 | )
351 | self.final_shift = self.subresult.x
352 | self.current_xmax[channel_index] = torch.tensor(self.final_range + self.final_shift).to(
353 | device=data.device
354 | )
355 | self.current_xmin[channel_index] = torch.tensor(
356 | -self.final_range + self.final_shift
357 | ).to(device=data.device)
358 |
359 | def forward(self, data):
360 | if self.loss_array is None:
361 | # Initialize search range on first batch, and accumulate losses with subsequent calls
362 |
363 | # Decide whether input distribution is one-sided
364 | if self.one_sided_dist is None:
365 | self.one_sided_dist = bool((data.min() >= 0).item())
366 |
367 | # Define search
368 | self._define_search_range(data)
369 |
370 | # Perform Search/Optimization for Quantization Ranges
371 | self.optimization_method(data)
372 |
373 | return self.current_xmin, self.current_xmax
374 |
375 | def reset(self):
376 | super().reset()
377 | self.loss_array = None
378 |
379 | def extra_repr(self):
380 | repr = "opt_method={}".format(self.opt_method.name)
381 | if self.opt_method == OptMethod.grid:
382 |             repr += ", num_candidates={}".format(self.num_candidates)
383 | return repr
384 |
385 |
386 | class NoDataPassedError(Exception):
387 |     """Raised when no data has been passed into the Range Estimator"""
388 |
389 | def __init__(self):
390 |         super().__init__("Data must be passed through the range estimator to be initialized")
391 |
392 |
393 | class RangeEstimators(ClassEnumOptions):
394 | current_minmax = MethodMap(CurrentMinMaxEstimator)
395 | running_minmax = MethodMap(RunningMinMaxEstimator)
396 | MSE = MethodMap(MSE_Estimator)
397 |
--------------------------------------------------------------------------------
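Usage note for the estimators above: each one is a stateful `nn.Module` that returns `(current_xmin, current_xmax)` on every call. A minimal sketch (assuming this repository is on the Python path) for the two min/max estimators, which do not require a quantizer:

import torch

from quantization.range_estimators import CurrentMinMaxEstimator, RunningMinMaxEstimator

x1 = torch.randn(16, 128)
x2 = 2.0 * torch.randn(16, 128)

# Per-batch estimate: each call reflects only the latest tensor.
current_est = CurrentMinMaxEstimator()
print(current_est(x1))  # (x1.min(), x1.max())
print(current_est(x2))  # (x2.min(), x2.max())

# Running estimate: exponential moving average of per-batch min/max,
# new_range = (1 - momentum) * batch_range + momentum * old_range.
running_est = RunningMinMaxEstimator(momentum=0.9)
running_est(x1)               # first call initializes the running ranges
xmin, xmax = running_est(x2)  # subsequent calls smooth towards the new batch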
/validate_mlm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
4 | # All Rights Reserved.
5 | import json
6 | import logging
7 | import math
8 | import os
9 | import random
10 | from collections import OrderedDict
11 | from itertools import chain
12 | from pathlib import Path
13 |
14 | import datasets
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import transformers
19 | from accelerate import Accelerator
20 | from accelerate.utils import set_seed
21 | from datasets import DatasetDict, load_dataset, load_from_disk
22 | from timm.utils import AverageMeter
23 | from torch.utils.data import DataLoader
24 | from tqdm.auto import tqdm
25 | from transformers import (
26 | CONFIG_MAPPING,
27 | MODEL_MAPPING,
28 | AutoConfig,
29 | AutoModelForMaskedLM,
30 | AutoTokenizer,
31 | DataCollatorForLanguageModeling,
32 | )
33 |
34 | from quantization.range_estimators import OptMethod, RangeEstimators
35 | from transformers_language.args import parse_args
36 | from transformers_language.dataset_setups import DatasetSetups
37 | from transformers_language.models.bert_attention import (
38 | AttentionGateType,
39 | BertSelfAttentionWithExtras,
40 | )
41 | from transformers_language.models.quantized_bert import QuantizedBertForMaskedLM
42 | from transformers_language.models.softmax import SOFTMAX_MAPPING
43 | from transformers_language.quant_configs import get_quant_config
44 | from transformers_language.utils import (
45 | count_params,
46 | kurtosis,
47 | pass_data_for_range_estimation,
48 | val_qparams,
49 | )
50 |
51 | logger = logging.getLogger("validate_mlm")
52 | logging.basicConfig(
53 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
54 | datefmt="%m/%d/%Y %H:%M:%S",
55 | level=logging.INFO,
56 | )
57 |
58 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
59 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
60 |
61 | EXTRA_METRICS = True
62 |
63 |
64 | def attach_act_hooks(model):
65 | act_dict = OrderedDict()
66 |
67 | def _make_hook(name):
68 | def _hook(mod, inp, out):
69 | if isinstance(inp, tuple) and len(inp) > 0:
70 | inp = inp[0]
71 | act_dict[name] = (inp, out)
72 |
73 | return _hook
74 |
75 | for name, module in model.named_modules():
76 | module.register_forward_hook(_make_hook(name))
77 | return act_dict
78 |
79 |
80 | def main():
81 | args = parse_args()
82 | logger.info(args)
83 |
84 | # convert dataset setup to an enum
85 | dataset_setup = DatasetSetups[args.dataset_setup]
86 |
87 | # Initialize the accelerator. We will let the accelerator handle device placement for us in
88 | # this example.
89 | # If we're using tracking, we also need to initialize it here and it will by default pick up
90 | # all supported trackers in the environment
91 | accelerator_log_kwargs = {}
92 |
93 | if args.with_tracking:
94 | accelerator_log_kwargs["log_with"] = args.report_to
95 | accelerator_log_kwargs["logging_dir"] = args.output_dir
96 |
97 | accelerator = Accelerator(
98 | gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
99 | )
100 |
101 | logger.info(accelerator.state)
102 | if accelerator.is_local_main_process:
103 | datasets.utils.logging.set_verbosity_warning()
104 | transformers.utils.logging.set_verbosity_info()
105 | else:
106 | datasets.utils.logging.set_verbosity_error()
107 | transformers.utils.logging.set_verbosity_error()
108 |
109 | # If passed along, set the training seed now.
110 | if args.seed is not None:
111 | set_seed(args.seed)
112 |
113 | # Prepare HuggingFace config
114 | # In distributed training, the .from_pretrained methods guarantee that only one local process
115 | # can concurrently download model & vocab.
116 | config_kwargs = {
117 | "cache_dir": args.model_cache_dir,
118 | }
119 | if args.config_name:
120 | config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
121 | elif args.model_name_or_path:
122 | config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs)
123 | else:
124 | config = CONFIG_MAPPING[args.model_type]()
125 | logger.warning("You are instantiating a new config instance from scratch.")
126 |
127 | # Display config after changes
128 | logger.info("HuggingFace config after user changes:")
129 | logger.info(str(config))
130 |
131 | # Load tokenizer
132 | tokenizer_kwargs = {
133 | # 'cache_dir': args.model_cache_dir,
134 | }
135 | if args.model_name_or_path:
136 | tokenizer = AutoTokenizer.from_pretrained(
137 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs
138 | )
139 | else:
140 | raise ValueError(
141 | "You are instantiating a new tokenizer from scratch. This is not supported by this "
142 | "script. You can do it from another script, save it, and load it from here, "
143 | "using --tokenizer_name."
144 | )
145 |
146 | # Load and prepare model
147 | if args.model_name_or_path:
148 | model = AutoModelForMaskedLM.from_pretrained(
149 | args.model_name_or_path,
150 | from_tf=bool(".ckpt" in args.model_name_or_path),
151 | config=config,
152 | cache_dir=args.model_cache_dir,
153 | )
154 | else:
155 | logger.info("Training new model from scratch")
156 | model = AutoModelForMaskedLM.from_config(config)
157 |
158 | # replace GELUActivation with nn.GELU
159 | for layer_idx in range(len(model.bert.encoder.layer)):
160 | model.bert.encoder.layer[layer_idx].intermediate.intermediate_act_fn = nn.GELU()
161 | # (skip head since we do not quantize it anyway)
162 |
163 | # replace Self-attention module with ours
164 | # NOTE: currently assumes BERT
165 | logger.info("replace self-attention module with ours (+copy loaded weights for Q,K,V)")
166 | for layer_idx in range(len(model.bert.encoder.layer)):
167 | old_self = model.bert.encoder.layer[layer_idx].attention.self
168 | new_self = BertSelfAttentionWithExtras(
169 | config,
170 | softmax_fn=SOFTMAX_MAPPING[args.attn_softmax],
171 | alpha=args.alpha,
172 | max_seq_length=args.max_seq_length,
173 | skip_attn=args.skip_attn,
174 | attn_gate_type=AttentionGateType[args.attn_gate_type],
175 | attn_gate_init=args.attn_gate_init,
176 | attn_gate_mlp=args.attn_gate_mlp,
177 | attn_gate_mlp2=args.attn_gate_mlp2,
178 | attn_gate_linear_all_features=args.attn_gate_linear_all_features,
179 | )
180 |
181 | # copy loaded weights
182 | new_self.load_state_dict(old_self.state_dict(), strict=False)
183 | model.bert.encoder.layer[layer_idx].attention.self = new_self
184 |
185 | # Gating -> load the model again to load missing alpha
186 | if args.attn_gate_type != "none":
187 | state_dict = torch.load(str(Path(args.model_name_or_path) / "pytorch_model.bin"))
188 | new_state_dict = {}
189 | for name, val in state_dict.items():
190 | if "alpha" in name:
191 | new_state_dict[name] = val
192 | model.load_state_dict(new_state_dict, strict=False)
193 |
194 | # Display num params
195 | n_embeddings = count_params(model.bert.embeddings)
196 | n_encoder = count_params(model.bert.encoder)
197 | n_head = count_params(model.cls)
198 | logger.info(
199 | f"\nNumber of parameters:\n"
200 | f"\t* Embeddings:\t{n_embeddings}\n"
201 | f"\t* Encoder:\t{n_encoder}\n"
202 | f"\t* Head:\t{n_head}\n"
203 | f"\t= Total (pre-training):\t{n_embeddings + n_encoder + n_head}\n"
204 | f"\t= Total (encoder):\t{n_embeddings + n_encoder}\n"
205 | )
206 |
207 | # Get the datasets.
208 |     # In distributed training, the load_dataset function guarantees that only one local process can
209 | # concurrently download the dataset.
210 |
211 | pre_tokenized_path_map = OrderedDict(
212 | [
213 | # (data_setup, max_seq_length, validation_percentage) -> dirname
214 | ((DatasetSetups.bookcorpus_and_wiki, 128, None), "tokenized_wiki_val_128"),
215 | ((DatasetSetups.bookcorpus_and_wiki, 128, 5), "tokenized_wiki_val_128_5%"),
216 | ((DatasetSetups.bookcorpus_and_wiki, 128, 1), "tokenized_wiki_val_128_1%"),
217 | ((DatasetSetups.wikitext_103, 128, None), "tokenized_wikitext_103_val_128"),
218 | ((DatasetSetups.wikitext_103, 128, 5), "tokenized_wikitext_103_val_128_5%"),
219 | ]
220 | )
221 | for k, v in pre_tokenized_path_map.items():
222 | pre_tokenized_path_map[k] = Path(args.data_cache_dir) / v
223 |
224 | tokenized_configuration = (dataset_setup, args.max_seq_length, args.validation_percentage)
225 | pre_tokenized_path = pre_tokenized_path_map.get(tokenized_configuration, None)
226 |
227 | if pre_tokenized_path is not None and pre_tokenized_path.exists():
228 | pre_tokenized_path = str(pre_tokenized_path)
229 |
230 | accelerator.print(f"Loading pre-tokenized dataset from {pre_tokenized_path}")
231 | tokenized_datasets = load_from_disk(pre_tokenized_path)
232 |
233 | else: # do tokenization
234 | train_split = (
235 | "train" if args.train_percentage is None else f"train[:{args.train_percentage}%]"
236 | )
237 | val_split = (
238 | "validation"
239 | if args.validation_percentage is None
240 | else f"validation[:{args.validation_percentage}%]"
241 | )
242 |
243 | if dataset_setup == DatasetSetups.wikitext_2:
244 | raw_datasets = DatasetDict()
245 | raw_datasets["train"] = load_dataset(
246 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=train_split
247 | )
248 | raw_datasets["validation"] = load_dataset(
249 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=val_split
250 | )
251 |
252 | elif dataset_setup == DatasetSetups.wikitext_103:
253 | raw_datasets = DatasetDict()
254 | raw_datasets["train"] = load_dataset(
255 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=train_split
256 | )
257 | raw_datasets["validation"] = load_dataset(
258 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=val_split
259 | )
260 |
261 | elif dataset_setup == DatasetSetups.bookcorpus_and_wiki:
262 | bookcorpus = load_dataset(
263 | "bookcorpus", cache_dir=args.data_cache_dir, split=train_split
264 | )
265 |
266 | wiki_train = load_dataset(
267 | "wiki40b", "en", cache_dir=args.data_cache_dir, split=train_split
268 | )
269 | wiki_val = load_dataset("wiki40b", "en", cache_dir=args.data_cache_dir, split=val_split)
270 |
271 | # only keep the 'text' column
272 | wiki_train = wiki_train.remove_columns(
273 | [c for c in wiki_train.column_names if c != "text"]
274 | )
275 | wiki_val = wiki_val.remove_columns(
276 | [col for col in wiki_val.column_names if col != "text"]
277 | )
278 | assert bookcorpus.features.type == wiki_train.features.type
279 |
280 | raw_datasets = DatasetDict()
281 | raw_datasets["train_book"] = bookcorpus
282 | raw_datasets["train_wiki"] = wiki_train
283 | raw_datasets["validation"] = wiki_val
284 |
285 | else:
286 | raise ValueError(f"Unknown dataset, {dataset_setup}")
287 |
288 | # Preprocessing the datasets.
289 |
290 | # Check sequence length
291 | if args.max_seq_length is None:
292 | max_seq_length = tokenizer.model_max_length
293 | if max_seq_length > 1024:
294 | logger.warning(
295 | f"The tokenizer picked seems to have a very large `model_max_length` "
296 | f"({tokenizer.model_max_length}). Picking 1024 instead. You can change that "
297 | f"default value by passing --max_seq_length xxx."
298 | )
299 | max_seq_length = 1024
300 | else:
301 | if args.max_seq_length > tokenizer.model_max_length:
302 | logger.warning(
303 | f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum "
304 | f"length for the model ({tokenizer.model_max_length}). Using "
305 | f"max_seq_length={tokenizer.model_max_length}."
306 | )
307 | max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
308 |
309 | # Tokenize all the texts.
310 | # YB: removed line-by-line option as we'll likely never use it
311 | column_names = raw_datasets["validation"].column_names
312 | text_column_name = "text" if "text" in column_names else column_names[0]
313 |
314 | # ... we tokenize every text, then concatenate them together before splitting them in smaller
315 | # parts. We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling
316 | # (see below) is more efficient when it receives the `special_tokens_mask`.
317 | def tokenize_function(examples):
318 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
319 |
320 | # YB: make the default bs for text pre-processing explicit
321 | tokenizer_map_batch_size = 1000
322 | with accelerator.main_process_first():
323 | tokenized_datasets = raw_datasets.map(
324 | tokenize_function,
325 | batched=True,
326 | batch_size=tokenizer_map_batch_size,
327 | writer_batch_size=tokenizer_map_batch_size,
328 | num_proc=args.preprocessing_num_workers,
329 | remove_columns=column_names,
330 | load_from_cache_file=not args.overwrite_cache,
331 | desc="Running tokenizer on every text in dataset",
332 | )
333 |
334 | # Main data processing function that will concatenate all texts from our dataset and generate
335 | # chunks of max_seq_length.
336 | def group_texts(examples):
337 | # Concatenate all texts.
338 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
339 | total_length = len(concatenated_examples[list(examples.keys())[0]])
340 |         # We drop the small remainder; we could add padding instead of dropping if the model
341 |         # supported it. You can customize this part to your needs.
342 | if total_length >= max_seq_length:
343 | total_length = (total_length // max_seq_length) * max_seq_length
344 | # Split by chunks of max_len.
345 | result = {
346 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
347 | for k, t in concatenated_examples.items()
348 | }
349 | return result
350 |
351 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts
352 | # throws away a remainder for each of those groups of 1,000 texts. You can adjust that
353 | # batch_size here but a higher value might be slower to preprocess.
354 | #
355 | # To speed up this part, we use multiprocessing. See the documentation of the map method for
356 | # more information:
357 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
358 |
359 | with accelerator.main_process_first():
360 | tokenized_datasets = tokenized_datasets.map(
361 | group_texts,
362 | batched=True,
363 | batch_size=tokenizer_map_batch_size,
364 | num_proc=args.preprocessing_num_workers,
365 | load_from_cache_file=not args.overwrite_cache,
366 | desc=f"Grouping texts in chunks of {max_seq_length}",
367 | )
368 |
369 | #
370 |
371 | eval_dataset = tokenized_datasets["validation"]
372 |
373 | # Conditional for small test subsets
374 | if len(eval_dataset) > 3:
375 |         # Log a few random samples from the validation set:
376 | for index in random.sample(range(len(eval_dataset)), 3):
377 | logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.")
378 |
379 | # Data collator
380 | # This one will take care of randomly masking the tokens.
381 | data_collator = DataCollatorForLanguageModeling(
382 | tokenizer=tokenizer, mlm_probability=args.mlm_probability
383 | )
384 |
385 | # DataLoaders creation:
386 | eval_dataloader = DataLoader(
387 | eval_dataset,
388 | collate_fn=data_collator,
389 | batch_size=args.per_device_eval_batch_size,
390 | num_workers=args.preprocessing_num_workers,
391 | )
392 |
393 | # Prepare everything with our `accelerator`.
394 | model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
395 |
396 | logger.info("FP model:")
397 | logger.info(model)
398 |
399 | # Quantize:
400 | if args.quantize:
401 | click_config = get_quant_config()
402 |
403 | # override number of batches
404 | click_config.act_quant.num_batches = args.est_num_batches
405 | click_config.quant.n_bits = args.n_bits
406 | click_config.quant.n_bits_act = args.n_bits_act
407 | if args.no_weight_quant:
408 | click_config.quant.weight_quant = False
409 | if args.no_act_quant:
410 | click_config.quant.act_quant = False
411 |
412 | # Weight Ranges
413 | if args.ranges_weights == "minmax":
414 | pass
415 | elif args.ranges_weights in ("mse", "MSE"):
416 | click_config.quant.weight_quant_method = RangeEstimators.MSE
417 | click_config.quant.weight_opt_method = OptMethod.grid
418 | else:
419 | raise ValueError(f"Unknown weight range estimation: {args.ranges_weights}")
420 |
421 | # Acts ranges
422 | if args.percentile is not None:
423 | click_config.act_quant.options["percentile"] = args.percentile
424 |
425 | if args.ranges_acts == "running_minmax":
426 | click_config.act_quant.quant_method = RangeEstimators.running_minmax
427 |
428 | elif args.ranges_acts == "MSE":
429 | click_config.act_quant.quant_method = RangeEstimators.MSE
430 | if args.qmethod_acts == "symmetric_uniform":
431 | click_config.act_quant.options = dict(opt_method=OptMethod.grid)
432 | elif args.qmethod_acts == "asymmetric_uniform":
433 | click_config.act_quant.options = dict(opt_method=OptMethod.golden_section)
434 |
435 | elif args.ranges_acts.startswith("L"):
436 | click_config.act_quant.quant_method = RangeEstimators.Lp
437 | p_norm = float(args.ranges_acts.replace("L", ""))
438 | options = dict(p_norm=p_norm)
439 | if args.qmethod_acts == "symmetric_uniform":
440 | options["opt_method"] = OptMethod.grid
441 | elif args.qmethod_acts == "asymmetric_uniform":
442 | options["opt_method"] = OptMethod.golden_section
443 | click_config.act_quant.options = options
444 |
445 | else:
446 | raise NotImplementedError(f"Unknown act range estimation setting, '{args.ranges_acts}'")
447 |
448 | qparams = val_qparams(click_config)
449 | qparams["quant_dict"] = {}
450 |
451 | model = QuantizedBertForMaskedLM(model, **qparams)
452 | model.set_quant_state(
453 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant
454 | )
455 |
456 | logger.info("Quantized model:")
457 | logger.info(model)
458 |
459 | # Range estimation
460 | logger.info("** Estimate quantization ranges on training data **")
461 | pass_data_for_range_estimation(
462 | loader=eval_dataloader,
463 | model=model,
464 | act_quant=click_config.quant.act_quant,
465 | max_num_batches=click_config.act_quant.num_batches,
466 | )
467 | model.fix_ranges()
468 | model.set_quant_state(
469 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant
470 | )
471 |
472 | # attach hooks for activation stats (if needed)
473 | act_dict = {}
474 | if EXTRA_METRICS:
475 | act_dict = attach_act_hooks(model)
476 |
477 | num_layers = len(model.bert.encoder.layer)
478 | act_inf_norms = OrderedDict()
479 | act_kurtoses = OrderedDict()
480 |
481 | # *** Evaluation ***
482 | model.eval()
483 | losses = []
484 | for batch_idx, batch in enumerate(tqdm(eval_dataloader)):
485 | with torch.no_grad():
486 | outputs = model(**batch)
487 |
488 | loss = outputs.loss
489 | loss_ = accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))
490 | losses.append(loss_)
491 |
492 | # compute inf norms
493 | if EXTRA_METRICS:
494 | for j in range(num_layers):
495 | for name in (
496 | f"bert.encoder.layer.{j}.output.dense", # FFN output
497 | f"bert.encoder.layer.{j}.output.LayerNorm", # LN(FFN output + input)
498 | ):
499 | x_inp, x_out = act_dict[name]
500 |
501 | x = x_out
502 |
503 | # inf-norm
504 | x = x.view(x.size(0), -1)
505 | inf_norms = x.norm(dim=1, p=np.inf)
506 |                     if name not in act_inf_norms:
507 | act_inf_norms[name] = AverageMeter()
508 | for v in inf_norms:
509 | act_inf_norms[name].update(v.item())
510 |
511 | # kurtosis
512 | if batch_idx <= 256:
513 | kurt = kurtosis(x)
514 |                         if name not in act_kurtoses:
515 | act_kurtoses[name] = AverageMeter()
516 | for v in kurt:
517 | act_kurtoses[name].update(v.item())
518 |
519 | # compute inf norm also for input
520 | if "LayerNorm" in name:
521 | x = x_inp
522 | x = x.view(x.size(0), -1)
523 | inf_norms = x.norm(dim=1, p=np.inf)
524 | name += ".input"
525 |                         if name not in act_inf_norms:
526 | act_inf_norms[name] = AverageMeter()
527 | for v in inf_norms:
528 | act_inf_norms[name].update(v.item())
529 |
530 | losses = torch.cat(losses)
531 | try:
532 | eval_loss = torch.mean(losses)
533 | perplexity = math.exp(eval_loss)
534 | except OverflowError:
535 | perplexity = float("inf")
536 | logger.info(f"perplexity: {perplexity:.4f}")
537 |
538 | # metrics
539 | metrics = OrderedDict([("perplexity", perplexity)])
540 |
541 | if EXTRA_METRICS:
542 | for name, v in act_inf_norms.items():
543 | metrics[name] = v.avg
544 |
545 | max_ffn_out_inf_norm = max(v.avg for k, v in act_inf_norms.items() if "dense" in k)
546 | max_LN_out_inf_norm = max(
547 | v.avg for k, v in act_inf_norms.items() if k.endswith("LayerNorm")
548 | )
549 | max_LN_inp_inf_norm = max(v.avg for k, v in act_inf_norms.items() if "input" in k)
550 | avg_kurtosis = sum(v.avg for v in act_kurtoses.values()) / len(act_kurtoses.values())
551 | max_kurtosis = max(v.avg for v in act_kurtoses.values())
552 |
553 | metrics["max_ffn_out_inf_norm"] = max_ffn_out_inf_norm
554 | metrics["max_LN_out_inf_norm"] = max_LN_out_inf_norm
555 | metrics["max_LN_inp_inf_norm"] = max_LN_inp_inf_norm
556 | metrics["avg_kurtosis"] = avg_kurtosis
557 | metrics["max_kurtosis"] = max_kurtosis
558 |
559 | logger.info(f"max FFN output inf norm: {max_ffn_out_inf_norm:.1f}")
560 | logger.info(f"max FFN input + output inf norm: {max_LN_inp_inf_norm:.1f}")
561 | logger.info(f"max LN(FFN i + o) inf norm: {max_LN_out_inf_norm:.1f}")
562 | logger.info(f"Avg Kurtosis: {avg_kurtosis:.2f}")
563 | logger.info(f"Max Kurtosis: {max_kurtosis:.1f}")
564 |
565 | if args.output_dir is not None:
566 | os.makedirs(args.output_dir, exist_ok=True)
567 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
568 | json.dump(metrics, f)
569 |
570 |
571 | if __name__ == "__main__":
572 | main()
573 |
--------------------------------------------------------------------------------
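The evaluation loop above uses `attach_act_hooks()` to record every module's (input, output) pair and then aggregates per-sample infinity norms and kurtosis with `AverageMeter`. Below is a small self-contained sketch of the same hook-and-norm pattern on a toy model; the toy `nn.Sequential` and the layer key "0" are illustrative placeholders, not anything from the repository.

from collections import OrderedDict

import torch
import torch.nn as nn


def attach_act_hooks(model):
    # Same pattern as in validate_mlm.py: store (input, output) per module name.
    act_dict = OrderedDict()

    def _make_hook(name):
        def _hook(mod, inp, out):
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            act_dict[name] = (inp, out)

        return _hook

    for name, module in model.named_modules():
        module.register_forward_hook(_make_hook(name))
    return act_dict


toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))  # illustrative only
acts = attach_act_hooks(toy)
_ = toy(torch.randn(2, 8))

# Per-sample infinity norm of the first Linear layer's output, mirroring the
# values accumulated into act_inf_norms in the evaluation loop above.
x_out = acts["0"][1].view(2, -1)
print(x_out.norm(dim=1, p=float("inf")))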
/validate_clm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
4 | # All Rights Reserved.
5 | import json
6 | import logging
7 | import math
8 | import os
9 | import random
10 | from collections import OrderedDict
11 | from itertools import chain
12 | from pathlib import Path
13 | from pprint import pformat
14 |
15 | import datasets
16 | import numpy as np
17 | import torch
18 | import transformers
19 | from accelerate import Accelerator
20 | from accelerate.utils import set_seed
21 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
22 | from timm.utils import AverageMeter
23 | from torch.utils.data import DataLoader
24 | from tqdm.auto import tqdm
25 | from transformers import (
26 | CONFIG_MAPPING,
27 | MODEL_MAPPING,
28 | AutoConfig,
29 | AutoModelForCausalLM,
30 | AutoTokenizer,
31 | default_data_collator,
32 | )
33 |
34 | from quantization.quantizers import QMethods
35 | from quantization.range_estimators import OptMethod, RangeEstimators
36 | from transformers_language.args import parse_args
37 | from transformers_language.dataset_setups import DatasetSetups
38 | from transformers_language.models.opt_attention import (
39 | AttentionGateType,
40 | OPTAttentionWithExtras,
41 | )
42 | from transformers_language.models.quantized_opt import QuantizedOPTForCausalLM
43 | from transformers_language.models.softmax import SOFTMAX_MAPPING
44 | from transformers_language.quant_configs import get_quant_config
45 | from transformers_language.utils import (
46 | count_params,
47 | kurtosis,
48 | pass_data_for_range_estimation,
49 | val_qparams,
50 | )
51 |
52 | logger = logging.getLogger("validate_clm")
53 | logging.basicConfig(
54 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
55 | datefmt="%m/%d/%Y %H:%M:%S",
56 | level=logging.INFO,
57 | )
58 |
59 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
60 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
61 |
62 |
63 | def main():
64 | args = parse_args()
65 | logger.info(args)
66 |
67 | # convert dataset setup to an enum
68 | dataset_setup = DatasetSetups[args.dataset_setup]
69 |
70 | # Initialize the accelerator. We will let the accelerator handle device placement for us in
71 | # this example.
72 | # If we're using tracking, we also need to initialize it here and it will by default pick up
73 | # all supported trackers in the environment
74 | accelerator_log_kwargs = {}
75 |
76 | if args.with_tracking:
77 | accelerator_log_kwargs["log_with"] = args.report_to
78 | accelerator_log_kwargs["logging_dir"] = args.output_dir
79 |
80 | accelerator = Accelerator(
81 | gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
82 | )
83 |
84 | logger.info(accelerator.state)
85 | if accelerator.is_local_main_process:
86 | datasets.utils.logging.set_verbosity_warning()
87 | transformers.utils.logging.set_verbosity_info()
88 | else:
89 | datasets.utils.logging.set_verbosity_error()
90 | transformers.utils.logging.set_verbosity_error()
91 |
92 | # If passed along, set the training seed now.
93 | if args.seed is not None:
94 | set_seed(args.seed)
95 |
96 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
97 | # https://huggingface.co/docs/datasets/loading_datasets.html.
98 |
99 | # Load pretrained model and tokenizer
100 | #
101 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
102 | # download model & vocab.
103 | config_kwargs = {
104 | "cache_dir": args.model_cache_dir,
105 | }
106 | if args.config_name:
107 | config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
108 | elif args.model_name_or_path:
109 | config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs)
110 | else:
111 | config = CONFIG_MAPPING[args.model_type]()
112 | logger.warning("You are instantiating a new config instance from scratch.")
113 |
114 | # Display config after changes
115 | logger.info("HuggingFace config after user changes:")
116 | logger.info(str(config))
117 |
118 | # Load tokenizer
119 | tokenizer_kwargs = {
120 | # 'cache_dir': args.model_cache_dir,
121 | }
122 | if args.model_name_or_path:
123 | tokenizer = AutoTokenizer.from_pretrained(
124 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs
125 | )
126 | else:
127 | raise ValueError(
128 |             "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
129 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
130 | )
131 |
132 | # Load and prepare model
133 | if args.model_name_or_path:
134 | model = AutoModelForCausalLM.from_pretrained(
135 | args.model_name_or_path,
136 | from_tf=bool(".ckpt" in args.model_name_or_path),
137 | config=config,
138 | low_cpu_mem_usage=args.low_cpu_mem_usage,
139 | cache_dir=args.model_cache_dir,
140 | )
141 | else:
142 | logger.info("Training new model from scratch")
143 | model = AutoModelForCausalLM.from_config(config)
144 |
145 | # >> replace self-attention module with ours
146 | # NOTE: currently assumes OPT
147 | for layer_idx in range(len(model.model.decoder.layers)):
148 | old_attn = model.model.decoder.layers[layer_idx].self_attn
149 | new_attn = OPTAttentionWithExtras(
150 | embed_dim=old_attn.embed_dim,
151 | num_heads=old_attn.num_heads,
152 | dropout=old_attn.dropout,
153 | is_decoder=old_attn.is_decoder,
154 | bias=True,
155 | # new
156 | softmax_fn=SOFTMAX_MAPPING[args.attn_softmax],
157 | alpha=args.alpha,
158 | max_seq_length=args.block_size,
159 | skip_attn=args.skip_attn,
160 | attn_gate_type=AttentionGateType[args.attn_gate_type],
161 | attn_gate_init=args.attn_gate_init,
162 | attn_gate_mlp=args.attn_gate_mlp,
163 | attn_gate_mlp2=args.attn_gate_mlp2,
164 | attn_gate_linear_all_features=args.attn_gate_linear_all_features,
165 | )
166 | # copy loaded weights
167 | new_attn.load_state_dict(old_attn.state_dict(), strict=False)
168 | model.model.decoder.layers[layer_idx].self_attn = new_attn
169 |
170 | # Gating -> load the model again to load missing alpha
171 | if args.attn_gate_type != "none":
172 | state_dict = torch.load(str(Path(args.model_name_or_path) / "pytorch_model.bin"))
173 | new_state_dict = {}
174 | for name, val in state_dict.items():
175 | if "alpha" in name:
176 | new_state_dict[name] = val
177 | model.load_state_dict(new_state_dict, strict=False)
178 |
179 | # Print model
180 | logger.info("Model:")
181 | logger.info(model)
182 |
183 | # Display num params
184 | n_embeddings = count_params(model.model.decoder.embed_tokens) + count_params(
185 | model.model.decoder.embed_positions
186 | )
187 | n_decoder = count_params(model.model.decoder) - n_embeddings
188 | n_head = count_params(model.lm_head)
189 | logger.info(
190 | f"\nNumber of parameters:\n"
191 | f"\t* Embeddings:\t{n_embeddings}\n"
192 | f"\t* Decoder:\t{n_decoder}\n"
193 | f"\t* Head:\t{n_head}\n"
194 | f"\t= Total (pre-training):\t{n_embeddings + n_decoder + n_head}\n"
195 | f"\t= Total (decoder only):\t{n_embeddings + n_decoder}\n"
196 | )
197 |
198 | # -----------------------------------------------------------------
199 |
200 | # Get the datasets.
201 |     # In distributed training, the load_dataset function guarantees that only one local process can
202 | # concurrently download the dataset.
203 |
204 | # (data_setup, block_size, train_percentage, validation_percentage) -> train_dirname
205 | pre_tokenized_path_map = OrderedDict(
206 | [
207 | ((DatasetSetups.wikitext_103, 512, None, None), ("tokenized_wikitext_103_OPT_512")),
208 | (
209 | (DatasetSetups.wikitext_103, 512, None, 10),
210 | ("tokenized_wikitext_103_OPT_512_val_10%"),
211 | ),
212 | (
213 | (DatasetSetups.wikitext_103, 512, 10, None),
214 | ("tokenized_wikitext_103_OPT_512_train_10%"),
215 | ),
216 | (
217 | (DatasetSetups.bookcorpus_and_wiki, 512, 1, 5),
218 | ("tokenized_book_wiki_OPT_512_train_1%_val_5%"),
219 | ),
220 | (
221 | (DatasetSetups.bookcorpus_and_wiki, 512, 1, 1),
222 | ("tokenized_book_wiki_OPT_512_train_1%_val_1%"),
223 | ),
224 | ]
225 | )
226 | for k, v in pre_tokenized_path_map.items():
227 | pre_tokenized_path_map[k] = Path(args.data_cache_dir) / v
228 |
229 | tokenized_configuration = (
230 | dataset_setup,
231 | args.block_size,
232 | args.train_percentage,
233 | args.validation_percentage,
234 | )
235 | pre_tokenized_path = pre_tokenized_path_map.get(tokenized_configuration, None)
236 |
237 | if pre_tokenized_path is not None and pre_tokenized_path.exists():
238 | pre_tokenized_path = str(pre_tokenized_path)
239 |
240 | accelerator.print(f"Loading pre-tokenized dataset from {pre_tokenized_path}")
241 | tokenized_datasets = load_from_disk(pre_tokenized_path)
242 |
243 | else: # do tokenization
244 | train_split = (
245 | "train" if args.train_percentage is None else f"train[:{args.train_percentage}%]"
246 | )
247 | val_split = (
248 | "validation"
249 | if args.validation_percentage is None
250 | else f"validation[:{args.validation_percentage}%]"
251 | )
252 |
253 | if dataset_setup == DatasetSetups.wikitext_2:
254 | raw_datasets = DatasetDict()
255 | raw_datasets["train"] = load_dataset(
256 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=train_split
257 | )
258 | raw_datasets["validation"] = load_dataset(
259 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=val_split
260 | )
261 |
262 | elif dataset_setup == DatasetSetups.wikitext_103:
263 | raw_datasets = DatasetDict()
264 | raw_datasets["train"] = load_dataset(
265 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=train_split
266 | )
267 | raw_datasets["validation"] = load_dataset(
268 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=val_split
269 | )
270 |
271 | elif dataset_setup == DatasetSetups.bookcorpus_and_wiki:
272 | bookcorpus = load_dataset(
273 | "bookcorpus", cache_dir=args.data_cache_dir, split=train_split
274 | )
275 |
276 | wiki_train = load_dataset(
277 | "wiki40b", "en", cache_dir=args.data_cache_dir, split=train_split
278 | )
279 | wiki_val = load_dataset("wiki40b", "en", cache_dir=args.data_cache_dir, split=val_split)
280 |
281 | # only keep the 'text' column
282 | wiki_train = wiki_train.remove_columns(
283 | [c for c in wiki_train.column_names if c != "text"]
284 | )
285 | wiki_val = wiki_val.remove_columns(
286 | [col for col in wiki_val.column_names if col != "text"]
287 | )
288 | assert bookcorpus.features.type == wiki_train.features.type
289 |
290 | raw_datasets = DatasetDict()
291 | raw_datasets["train_book"] = bookcorpus
292 | raw_datasets["train_wiki"] = wiki_train
293 | raw_datasets["validation"] = wiki_val
294 |
295 | else:
296 | raise ValueError(f"Unknown dataset, {dataset_setup}")
297 |
298 | # Preprocessing the datasets.
299 | # Check sequence length
300 | if args.block_size is None:
301 | block_size = tokenizer.model_max_length
302 | if block_size > 1024:
303 | logger.warning(
304 | "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
305 | " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
306 | " override this default with `--block_size xxx`."
307 | )
308 | block_size = 1024
309 | else:
310 | if args.block_size > tokenizer.model_max_length:
311 | logger.warning(
312 |                 f"The block_size passed ({args.block_size}) is larger than the maximum length for the model "
313 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
314 | )
315 | block_size = min(args.block_size, tokenizer.model_max_length)
316 |
317 | # Tokenize all the texts.
318 | column_names = raw_datasets["validation"].column_names
319 | text_column_name = "text" if "text" in column_names else column_names[0]
320 |
321 | def tokenize_function(examples):
322 | return tokenizer(examples[text_column_name])
323 |
324 | # YB: make the default bs for text pre-processing explicit
325 | tokenizer_map_batch_size = 1000
326 | with accelerator.main_process_first():
327 | tokenized_datasets = raw_datasets.map(
328 | tokenize_function,
329 | batched=True,
330 | batch_size=tokenizer_map_batch_size,
331 | writer_batch_size=tokenizer_map_batch_size,
332 | num_proc=args.preprocessing_num_workers,
333 | remove_columns=column_names,
334 | load_from_cache_file=not args.overwrite_cache,
335 | desc="Running tokenizer on dataset",
336 | )
337 |
338 | # Main data processing function that will concatenate all texts from our dataset and generate
339 | # chunks of max_seq_length.
340 | def group_texts(examples):
341 | # Concatenate all texts.
342 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
343 | total_length = len(concatenated_examples[list(examples.keys())[0]])
344 |         # We drop the small remainder; we could pad instead of dropping if the model supported it.
345 |         # You can customize this part to your needs.
346 | if total_length >= block_size:
347 | total_length = (total_length // block_size) * block_size
348 | else:
349 | total_length = 0
350 | # Split by chunks of max_len.
351 | result = {
352 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
353 | for k, t in concatenated_examples.items()
354 | }
355 | result["labels"] = result["input_ids"].copy()
356 | return result
357 |
358 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
359 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
360 | # to preprocess.
361 | #
362 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
363 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
364 |
365 | with accelerator.main_process_first():
366 | tokenized_datasets = tokenized_datasets.map(
367 | group_texts,
368 | batched=True,
369 | batch_size=tokenizer_map_batch_size,
370 | num_proc=args.preprocessing_num_workers,
371 | load_from_cache_file=not args.overwrite_cache,
372 | desc=f"Grouping texts in chunks of {block_size}",
373 | )
374 |
375 | #
376 |
377 | if dataset_setup == DatasetSetups.bookcorpus_and_wiki:
378 | train_dataset = concatenate_datasets(
379 | [tokenized_datasets["train_book"], tokenized_datasets["train_wiki"]]
380 | )
381 | eval_dataset = tokenized_datasets["validation"]
382 | else:
383 | train_dataset = tokenized_datasets["train"]
384 | eval_dataset = tokenized_datasets["validation"]
385 |
386 | # Log a few random samples from the training set:
387 | if len(train_dataset) > 3:
388 | for index in random.sample(range(len(train_dataset)), 3):
389 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
390 |
391 | # DataLoaders creation:
392 | train_dataloader = DataLoader(
393 | train_dataset,
394 | shuffle=True,
395 | collate_fn=default_data_collator,
396 | batch_size=args.per_device_train_batch_size,
397 | num_workers=args.preprocessing_num_workers,
398 | )
399 | eval_dataloader = DataLoader(
400 | eval_dataset,
401 | collate_fn=default_data_collator,
402 | batch_size=args.per_device_eval_batch_size,
403 | num_workers=args.preprocessing_num_workers,
404 | )
405 |
406 | # Prepare everything with our `accelerator`.
407 | model, train_dataloader, eval_dataloader = accelerator.prepare(
408 | model, train_dataloader, eval_dataloader
409 | )
410 |
411 | logger.info("FP model:")
412 | logger.info(model)
413 |
414 | #
415 | ## Quantize:
416 | #
417 | if args.quantize:
418 | click_config = get_quant_config()
419 |
420 | # override number of batches
421 | click_config.act_quant.num_batches = args.est_num_batches
422 | click_config.quant.n_bits = args.n_bits
423 | click_config.quant.n_bits_act = args.n_bits_act
424 | click_config.quant.quant_setup = args.quant_setup
425 | if args.no_weight_quant:
426 | click_config.quant.weight_quant = False
427 | if args.no_act_quant:
428 | click_config.quant.act_quant = False
429 |
430 | # use MSE for weights (ignore `args.ranges_weights`)
431 | # click_config.quant.weight_quant_method = RangeEstimators.current_minmax
432 | click_config.quant.weight_quant_method = RangeEstimators.MSE
433 | click_config.quant.weight_opt_method = OptMethod.grid
434 |
435 | # qmethod acts
436 | if args.qmethod_acts == "symmetric_uniform":
437 | click_config.quant.qmethod_act = QMethods.symmetric_uniform
438 | elif args.qmethod_acts == "asymmetric_uniform":
439 | click_config.quant.qmethod_act = QMethods.asymmetric_uniform
440 | else:
441 | raise NotImplementedError(f"Unknown qmethod_act setting, '{args.qmethod_acts}'")
442 |
443 | # Acts ranges
444 | if args.percentile is not None:
445 | click_config.act_quant.options["percentile"] = args.percentile
446 |
447 | if args.ranges_acts == "running_minmax":
448 | click_config.act_quant.quant_method = RangeEstimators.running_minmax
449 |
450 | elif args.ranges_acts == "MSE":
451 | click_config.act_quant.quant_method = RangeEstimators.MSE
452 | if args.qmethod_acts == "symmetric_uniform":
453 | click_config.act_quant.options = dict(opt_method=OptMethod.grid)
454 | elif args.qmethod_acts == "asymmetric_uniform":
455 | click_config.act_quant.options = dict(opt_method=OptMethod.golden_section)
456 |
457 | elif args.ranges_acts.startswith("L"):
458 | click_config.act_quant.quant_method = RangeEstimators.Lp
459 | p_norm = float(args.ranges_acts.replace("L", ""))
460 | options = dict(p_norm=p_norm)
461 | if args.qmethod_acts == "symmetric_uniform":
462 | options["opt_method"] = OptMethod.grid
463 | elif args.qmethod_acts == "asymmetric_uniform":
464 | options["opt_method"] = OptMethod.golden_section
465 | click_config.act_quant.options = options
466 |
467 | else:
468 | raise NotImplementedError(f"Unknown range estimation setting, '{args.ranges_acts}'")
469 |
470 | qparams = val_qparams(click_config)
471 | qparams["quant_dict"] = {}
472 |
473 | model = QuantizedOPTForCausalLM(model, **qparams)
474 | model.set_quant_state(
475 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant
476 | )
477 |
478 | logger.info("Quantized model:")
479 | logger.info(model)
480 |
481 | # Range estimation
482 | logger.info("** Estimate quantization ranges on training data **")
483 | pass_data_for_range_estimation(
484 | loader=train_dataloader,
485 | model=model,
486 | act_quant=click_config.quant.act_quant,
487 | max_num_batches=click_config.act_quant.num_batches,
488 | )
489 | model.fix_ranges()
490 |
491 | model.set_quant_state(
492 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant
493 | )
494 |
495 | # attach hooks for activation stats
496 | def attach_act_hooks(model):
497 | act_dict = OrderedDict()
498 |
499 | def _make_hook(name):
500 | def _hook(mod, inp, out):
501 | if isinstance(inp, tuple) and len(inp) > 0:
502 | inp = inp[0]
503 | if isinstance(out, tuple) and len(out) > 0:
504 | out = out[0]
505 | act_dict[name] = (inp, out)
506 |
507 | return _hook
508 |
509 | for name, module in model.named_modules():
510 | module.register_forward_hook(_make_hook(name))
511 | return act_dict
512 |
513 | if args.quantize:
514 | act_dict = {}
515 | else:
516 | act_dict = attach_act_hooks(model)
517 | num_layers = len(model.model.decoder.layers)
518 |
519 | ACT_KEYS = [
520 | "model.decoder.final_layer_norm",
521 | *[f"model.decoder.layers.{j}" for j in range(num_layers)],
522 | *[f"model.decoder.layers.{j}.fc2" for j in range(num_layers)],
523 | *[f"model.decoder.layers.{j}.final_layer_norm" for j in range(num_layers)],
524 | *[f"model.decoder.layers.{j}.self_attn.out_proj" for j in range(num_layers)],
525 | *[f"model.decoder.layers.{j}.self_attn_layer_norm" for j in range(num_layers)],
526 | ]
527 |
528 | act_inf_norms = OrderedDict()
529 | act_kurtoses = OrderedDict()
530 |
531 | # -----------------------------------------------------------------
532 |
533 | # *** Evaluation ***
534 | model.eval()
535 | losses = []
536 | for batch_idx, batch in enumerate(tqdm(eval_dataloader)):
537 | with torch.no_grad():
538 | outputs = model(**batch)
539 |
540 | loss = outputs.loss
541 | loss_ = accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))
542 | losses.append(loss_)
543 |
544 | # compute inf norms
545 | if not args.quantize:
546 | for name in ACT_KEYS:
547 | if name in act_dict:
548 | x_inp, x_out = act_dict[name]
549 | x = x_out
550 | x = x.view(x.size(0), -1)
551 |
552 | # compute inf norm
553 | inf_norms = x.norm(dim=1, p=np.inf)
554 |                     if name not in act_inf_norms:
555 | act_inf_norms[name] = AverageMeter()
556 | for v in inf_norms:
557 | act_inf_norms[name].update(v.item())
558 |
559 | # compute kurtosis
560 | if batch_idx <= 100:
561 | kurt = kurtosis(x)
562 |                         if name not in act_kurtoses:
563 | act_kurtoses[name] = AverageMeter()
564 | for v in kurt:
565 | act_kurtoses[name].update(v.item())
566 |
567 | losses = torch.cat(losses)
568 | try:
569 | eval_loss = torch.mean(losses)
570 | perplexity = math.exp(eval_loss)
571 | except OverflowError:
572 | perplexity = float("inf")
573 | logger.info(f"perplexity: {perplexity:.4f}")
574 |
575 | # metrics
576 | metrics = OrderedDict([("perplexity", perplexity)])
577 |
578 | if not args.quantize:
579 | for name, v in act_inf_norms.items():
580 | metrics[name] = v.avg
581 |
582 | max_inf_norm = max(v.avg for v in act_inf_norms.values())
583 | max_ffn_inf_norm = max(v.avg for k, v in act_inf_norms.items() if ".fc" in k)
584 | max_layer_inf_norm = max(
585 | act_inf_norms[f"model.decoder.layers.{j}"].avg for j in range(num_layers)
586 | )
587 |
588 |         avg_kurtosis = sum(v.avg for v in act_kurtoses.values()) / len(act_kurtoses)
589 | max_kurtosis = max(v.avg for v in act_kurtoses.values())
590 | max_kurtosis_layers = max(
591 | act_kurtoses[f"model.decoder.layers.{j}"].avg for j in range(num_layers)
592 | )
593 |
594 | metrics["max_inf_norm"] = max_inf_norm
595 | metrics["max_ffn_inf_norm"] = max_ffn_inf_norm
596 | metrics["max_layer_inf_norm"] = max_layer_inf_norm
597 |
598 | metrics["avg_kurtosis"] = avg_kurtosis
599 | metrics["max_kurtosis"] = max_kurtosis
600 | metrics["max_kurtosis_layers"] = max_kurtosis_layers
601 |
602 | logger.info(f"Max inf norm: {max_inf_norm:.1f}")
603 | logger.info(f"Max FFN inf norm: {max_ffn_inf_norm:.1f}")
604 | logger.info(f"Max layer inf norm: {max_layer_inf_norm:.1f}")
605 |
606 | logger.info(f"Avg Kurtosis: {avg_kurtosis:.2f}")
607 | logger.info(f"Max Kurtosis: {max_kurtosis:.1f}")
608 | logger.info(f"Max Kurtosis layers: {max_kurtosis_layers:.1f}")
609 |
610 | logger.info(f"\nAll metrics:\n{pformat(metrics)}")
611 |
612 | if args.output_dir is not None:
613 | os.makedirs(args.output_dir, exist_ok=True)
614 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
615 | json.dump(metrics, f)
616 |
617 |
618 | if __name__ == "__main__":
619 | main()
620 |
--------------------------------------------------------------------------------
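
Note on the activation-statistics pass in validate_clm.py above: forward hooks record each module's (input, output) pair, and the evaluation loop then reduces the outputs to per-module inf-norms and kurtosis values averaged over batches. The snippet below is a minimal standalone sketch of that hook pattern under simplified assumptions; `attach_stat_hooks`, the toy model, and the local `kurtosis` helper are illustrative stand-ins (the repository imports its own `kurtosis` and `AverageMeter` from transformers_language.utils), not the script's exact code.

```python
import torch
import torch.nn as nn


def kurtosis(x, eps=1e-6):
    # Per-row kurtosis E[(x - mu)^4] / E[(x - mu)^2]^2, computed over the last dimension.
    mu = x.mean(dim=-1, keepdim=True)
    centered = x - mu
    var = centered.pow(2).mean(dim=-1)
    return centered.pow(4).mean(dim=-1) / (var.pow(2) + eps)


def attach_stat_hooks(model):
    # Map module name -> statistics of that module's output on the most recent forward pass.
    stats = {}

    def make_hook(name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            flat = out.detach().float().view(out.size(0), -1)
            stats[name] = {
                "inf_norm": flat.norm(p=float("inf"), dim=1).mean().item(),
                "kurtosis": kurtosis(flat).mean().item(),
            }
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))
    return stats


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
    stats = attach_stat_hooks(toy)
    with torch.no_grad():
        toy(torch.randn(4, 16))
    for module_name, s in stats.items():
        print(module_name or "<root>", s)
```

Unlike the script, which keeps the raw activation tensors and averages the statistics over many evaluation batches with an AverageMeter, this sketch reduces inside the hook and only keeps the latest batch; that is a memory-saving simplification, not the repository's behaviour.
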
/run_mlm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
4 | # All Rights Reserved.
5 | import json
6 | import logging
7 | import math
8 | import os
9 | import random
10 | import warnings
11 | from collections import OrderedDict
12 | from itertools import chain
13 | from pathlib import Path
14 |
15 | import datasets
16 | import torch
17 | import transformers
18 | import yaml
19 | from accelerate import Accelerator
20 | from accelerate.logging import get_logger
21 | from accelerate.utils import set_seed
22 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
23 | from torch.utils.data import DataLoader
24 | from tqdm.auto import tqdm
25 | from transformers import (
26 | CONFIG_MAPPING,
27 | AutoConfig,
28 | AutoModelForMaskedLM,
29 | AutoTokenizer,
30 | DataCollatorForLanguageModeling,
31 | get_scheduler,
32 | )
33 |
34 | from transformers_language.args import parse_args
35 | from transformers_language.dataset_setups import DatasetSetups
36 | from transformers_language.models.bert_attention import (
37 | AttentionGateType,
38 | BertSelfAttentionWithExtras,
39 | )
40 | from transformers_language.models.softmax import SOFTMAX_MAPPING
41 | from transformers_language.utils import count_params
42 |
43 | logger = get_logger("run_mlm")
44 |
45 |
46 | def attach_tb_act_hooks(model):
47 | act_dict = OrderedDict()
48 |
49 | def _make_hook(name):
50 | def _hook(mod, inp, out):
51 | act_dict[name] = out[0]
52 |
53 | return _hook
54 |
55 | for name, module in model.named_modules():
56 | module.register_forward_hook(_make_hook(name))
57 | return act_dict
58 |
59 |
60 | def main():
61 | args = parse_args()
62 |
63 | # convert dataset setup to an enum
64 | dataset_setup = DatasetSetups[args.dataset_setup]
65 |
66 | # Initialize the accelerator. We will let the accelerator handle device placement for us in
67 | # this example.
68 | # If we're using tracking, we also need to initialize it here and it will by default pick up
69 | # all supported trackers in the environment
70 | accelerator_log_kwargs = {}
71 |
72 | if args.with_tracking:
73 | accelerator_log_kwargs["log_with"] = args.report_to
74 | accelerator_log_kwargs["project_dir"] = args.output_dir
75 |
76 | accelerator = Accelerator(
77 | gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
78 | )
79 | accelerator.project_configuration.total_limit = 1
80 | accelerator.project_configuration.automatic_checkpoint_naming = True
81 |
82 | # log passed args
83 | logger.info(args)
84 |
85 | # Make one log on every process with the configuration for debugging.
86 | logging.basicConfig(
87 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
88 | datefmt="%m/%d/%Y %H:%M:%S",
89 | level=logging.INFO,
90 | )
91 | logger.info(accelerator.state, main_process_only=False)
92 | if accelerator.is_local_main_process:
93 | datasets.utils.logging.set_verbosity_warning()
94 | transformers.utils.logging.set_verbosity_info()
95 | else:
96 | datasets.utils.logging.set_verbosity_error()
97 | transformers.utils.logging.set_verbosity_error()
98 |
99 | # If passed along, set the training seed now.
100 | if args.seed is not None:
101 | set_seed(args.seed)
102 |
103 | # Prepare HuggingFace config
104 | # In distributed training, the .from_pretrained methods guarantee that only one local process
105 | # can concurrently download model & vocab.
106 | config_kwargs = {
107 | "cache_dir": args.model_cache_dir,
108 | }
109 | if args.config_name:
110 | config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
111 | elif args.model_name_or_path:
112 | config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs)
113 | else:
114 | config = CONFIG_MAPPING[args.model_type]()
115 | logger.warning("You are instantiating a new config instance from scratch.")
116 |
117 | # Load model config changes from file, if provided
118 | if args.config_path is not None:
119 | logger.info(f"Loading model config changes from {args.config_path}")
120 | with open(args.config_path) as f:
121 | config_changes = yaml.safe_load(f)
122 |
123 | for key, value in config_changes.items():
124 | setattr(config, key, value)
125 |
126 | # Set dropout rates, if specified
127 | if args.attn_dropout is not None:
128 | logger.info(f"Setting attention dropout rate to {args.attn_dropout}")
129 | config.attention_probs_dropout_prob = args.attn_dropout
130 |
131 | if args.hidden_dropout is not None:
132 | logger.info(f"Setting hidden dropout rate to {args.hidden_dropout}")
133 | config.hidden_dropout_prob = args.hidden_dropout
134 |
135 | # Display config after changes
136 | logger.info("HuggingFace config after user changes:")
137 | logger.info(str(config))
138 |
139 | # Load tokenizer
140 | tokenizer_kwargs = {
141 | "cache_dir": args.model_cache_dir,
142 | }
143 | if args.tokenizer_name:
144 | tokenizer = AutoTokenizer.from_pretrained(
145 | args.tokenizer_name, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs
146 | )
147 | elif args.model_name_or_path:
148 | tokenizer = AutoTokenizer.from_pretrained(
149 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs
150 | )
151 | else:
152 | raise ValueError(
153 | "You are instantiating a new tokenizer from scratch. This is not supported by this "
154 | "script. You can do it from another script, save it, and load it from here, "
155 | "using --tokenizer_name."
156 | )
157 |
158 | # Load and prepare model
159 | if args.model_name_or_path:
160 | model = AutoModelForMaskedLM.from_pretrained(
161 | args.model_name_or_path,
162 | from_tf=bool(".ckpt" in args.model_name_or_path),
163 | config=config,
164 | cache_dir=args.model_cache_dir,
165 | )
166 | else:
167 | logger.info("Training new model from scratch")
168 | model = AutoModelForMaskedLM.from_config(config)
169 |
170 | # >> replace Self-attention module with ours
171 | # NOTE: currently assumes BERT
172 | for layer_idx in range(len(model.bert.encoder.layer)):
173 | old_self = model.bert.encoder.layer[layer_idx].attention.self
174 | new_self = BertSelfAttentionWithExtras(
175 | config,
176 | softmax_fn=SOFTMAX_MAPPING[args.attn_softmax],
177 | alpha=args.alpha,
178 | max_seq_length=args.max_seq_length,
179 | skip_attn=args.skip_attn,
180 | attn_gate_type=AttentionGateType[args.attn_gate_type],
181 | attn_gate_init=args.attn_gate_init,
182 | attn_gate_mlp=args.attn_gate_mlp,
183 | attn_gate_mlp2=args.attn_gate_mlp2,
184 | attn_gate_linear_all_features=args.attn_gate_linear_all_features,
185 | fine_tuning=args.fine_tuning,
186 | )
187 |
188 | # copy loaded weights
189 | if args.model_name_or_path is not None:
190 | new_self.load_state_dict(old_self.state_dict(), strict=False)
191 | model.bert.encoder.layer[layer_idx].attention.self = new_self
192 |
193 | # Gating -> load the model again to load missing alpha
194 | if args.model_name_or_path is not None and args.attn_gate_type != "none":
195 | state_dict = torch.load(str(Path(args.model_name_or_path) / "pytorch_model.bin"))
196 | new_state_dict = {}
197 | for name, val in state_dict.items():
198 | if "alpha" in name:
199 | new_state_dict[name] = val
200 | model.load_state_dict(new_state_dict, strict=False)
201 |
202 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a
203 | # model from scratch on a small vocab and want a smaller embedding size, remove this test.
204 | embedding_size = model.get_input_embeddings().weight.shape[0] # = vocab size
205 | if len(tokenizer) > embedding_size:
206 | model.resize_token_embeddings(len(tokenizer))
207 |
208 | # Display num params
209 | n_embeddings = count_params(model.bert.embeddings)
210 | n_encoder = count_params(model.bert.encoder)
211 | n_head = count_params(model.cls)
212 | logger.info(
213 | f"\nNumber of parameters:\n"
214 | f"\t* Embeddings:\t{n_embeddings}\n"
215 | f"\t* Encoder:\t{n_encoder}\n"
216 | f"\t* Head:\t{n_head}\n"
217 | f"\t= Total (pre-training):\t{n_embeddings + n_encoder + n_head}\n"
218 | f"\t= Total (encoder):\t{n_embeddings + n_encoder}\n"
219 | )
220 |
221 | # Get the datasets.
222 |     # In distributed training, the load_dataset function guarantees that only one local process can
223 | # concurrently download the dataset.
224 | tokenized_book_wiki_path = (
225 | Path(args.data_cache_dir) / f"tokenized_book_wiki_{args.max_seq_length}"
226 | )
227 | if dataset_setup == DatasetSetups.bookcorpus_and_wiki and tokenized_book_wiki_path.exists():
228 | accelerator.print(f"Loading tokenized dataset from {str(tokenized_book_wiki_path)}")
229 |
230 | tokenized_datasets = load_from_disk(str(tokenized_book_wiki_path))
231 |
232 | else: # do tokenization
233 | train_split = (
234 | "train" if args.train_percentage is None else f"train[:{args.train_percentage}%]"
235 | )
236 | val_split = (
237 | "validation"
238 | if args.validation_percentage is None
239 | else f"validation[:{args.validation_percentage}%]"
240 | )
241 |
242 | if dataset_setup == DatasetSetups.wikitext_2:
243 | raw_datasets = DatasetDict()
244 | raw_datasets["train"] = load_dataset(
245 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=train_split
246 | )
247 | raw_datasets["validation"] = load_dataset(
248 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=val_split
249 | )
250 |
251 | elif dataset_setup == DatasetSetups.wikitext_103:
252 | raw_datasets = DatasetDict()
253 | raw_datasets["train"] = load_dataset(
254 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=train_split
255 | )
256 | raw_datasets["validation"] = load_dataset(
257 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=val_split
258 | )
259 |
260 | elif dataset_setup == DatasetSetups.bookcorpus_and_wiki:
261 | bookcorpus = load_dataset(
262 | "bookcorpus", cache_dir=args.data_cache_dir, split=train_split
263 | )
264 |
265 | wiki_train = load_dataset(
266 | "wiki40b", "en", cache_dir=args.data_cache_dir, split=train_split
267 | )
268 | wiki_val = load_dataset("wiki40b", "en", cache_dir=args.data_cache_dir, split=val_split)
269 |
270 | # only keep the 'text' column
271 | wiki_train = wiki_train.remove_columns(
272 | [c for c in wiki_train.column_names if c != "text"]
273 | )
274 | wiki_val = wiki_val.remove_columns(
275 | [col for col in wiki_val.column_names if col != "text"]
276 | )
277 | assert bookcorpus.features.type == wiki_train.features.type
278 |
279 | raw_datasets = DatasetDict()
280 | raw_datasets["train_book"] = bookcorpus
281 | raw_datasets["train_wiki"] = wiki_train
282 | raw_datasets["validation"] = wiki_val
283 |
284 | else:
285 | raise ValueError(f"Unknown dataset, {dataset_setup}")
286 |
287 | # Preprocessing the datasets.
288 |
289 | # Check sequence length
290 | if args.max_seq_length is None:
291 | max_seq_length = tokenizer.model_max_length
292 | if max_seq_length > 1024:
293 | logger.warning(
294 | f"The tokenizer picked seems to have a very large `model_max_length` "
295 | f"({tokenizer.model_max_length}). Picking 1024 instead. You can change that "
296 | f"default value by passing --max_seq_length xxx."
297 | )
298 | max_seq_length = 1024
299 | else:
300 | if args.max_seq_length > tokenizer.model_max_length:
301 | logger.warning(
302 | f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum "
303 | f"length for the model ({tokenizer.model_max_length}). Using "
304 | f"max_seq_length={tokenizer.model_max_length}."
305 | )
306 | max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
307 |
308 | # Tokenize all the texts.
309 | # YB: removed line-by-line option as we'll likely never use it
310 | column_names = raw_datasets["validation"].column_names
311 | text_column_name = "text" if "text" in column_names else column_names[0]
312 |
313 | # ... we tokenize every text, then concatenate them together before splitting them in smaller
314 | # parts. We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling
315 | # (see below) is more efficient when it receives the `special_tokens_mask`.
316 | def tokenize_function(examples):
317 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
318 |
319 | # YB: make the default bs for text pre-processing explicit
320 | tokenizer_map_batch_size = 1000
321 | with accelerator.main_process_first():
322 | tokenized_datasets = raw_datasets.map(
323 | tokenize_function,
324 | batched=True,
325 | batch_size=tokenizer_map_batch_size,
326 | writer_batch_size=tokenizer_map_batch_size,
327 | num_proc=args.preprocessing_num_workers,
328 | remove_columns=column_names,
329 | load_from_cache_file=not args.overwrite_cache,
330 | desc="Running tokenizer on every text in dataset",
331 | )
332 |
333 | # Main data processing function that will concatenate all texts from our dataset and generate
334 | # chunks of max_seq_length.
335 | def group_texts(examples):
336 | # Concatenate all texts.
337 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
338 | total_length = len(concatenated_examples[list(examples.keys())[0]])
339 |         # We drop the small remainder; we could pad instead of dropping if the model supported it.
340 |         # You can customize this part to your needs.
341 | if total_length >= max_seq_length:
342 | total_length = (total_length // max_seq_length) * max_seq_length
343 | # Split by chunks of max_len.
344 | result = {
345 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
346 | for k, t in concatenated_examples.items()
347 | }
348 | return result
349 |
350 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts
351 | # throws away a remainder for each of those groups of 1,000 texts. You can adjust that
352 | # batch_size here but a higher value might be slower to preprocess.
353 | #
354 | # To speed up this part, we use multiprocessing. See the documentation of the map method for
355 | # more information:
356 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
357 |
358 | with accelerator.main_process_first():
359 | tokenized_datasets = tokenized_datasets.map(
360 | group_texts,
361 | batched=True,
362 | batch_size=tokenizer_map_batch_size,
363 | num_proc=args.preprocessing_num_workers,
364 | load_from_cache_file=not args.overwrite_cache,
365 | desc=f"Grouping texts in chunks of {max_seq_length}",
366 | )
367 |
368 | #
369 |
370 | if dataset_setup == DatasetSetups.bookcorpus_and_wiki:
371 | train_dataset = concatenate_datasets(
372 | [tokenized_datasets["train_book"], tokenized_datasets["train_wiki"]]
373 | )
374 | eval_dataset = tokenized_datasets["validation"]
375 | else:
376 | train_dataset = tokenized_datasets["train"]
377 | eval_dataset = tokenized_datasets["validation"]
378 |
379 | # Conditional for small test subsets
380 | if len(train_dataset) > 3:
381 | # Log a few random samples from the training set:
382 | for index in random.sample(range(len(train_dataset)), 3):
383 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
384 |
385 | # Data collator
386 | # This one will take care of randomly masking the tokens.
387 | data_collator = DataCollatorForLanguageModeling(
388 | tokenizer=tokenizer, mlm_probability=args.mlm_probability
389 | )
390 |
391 | # DataLoaders creation:
392 | train_dataloader = DataLoader(
393 | train_dataset,
394 | shuffle=True,
395 | collate_fn=data_collator,
396 | batch_size=args.per_device_train_batch_size,
397 | num_workers=args.preprocessing_num_workers,
398 | )
399 | eval_dataloader = DataLoader(
400 | eval_dataset,
401 | collate_fn=data_collator,
402 | batch_size=args.per_device_eval_batch_size,
403 | num_workers=args.preprocessing_num_workers,
404 | )
405 |
406 | # Optimizer
407 | # Split weights in two groups, one with weight decay and the other not.
408 | no_decay = ["bias", "LayerNorm.weight"]
409 | optimizer_grouped_parameters = [
410 | {
411 | "params": [
412 | p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
413 | ],
414 | "weight_decay": args.weight_decay,
415 | },
416 | {
417 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
418 | "weight_decay": 0.0,
419 | },
420 | ]
421 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
422 |
423 | # LR Scheduler and math around the number of training steps.
424 | overrode_max_train_steps = False
425 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
426 | if args.max_train_steps is None:
427 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
428 | overrode_max_train_steps = True
429 |
430 | lr_scheduler = get_scheduler(
431 | name=args.lr_scheduler_type,
432 | optimizer=optimizer,
433 | num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
434 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
435 | )
436 |
437 | # Prepare everything with our `accelerator`.
438 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
439 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
440 | )
441 |
442 | # We need to recalculate our total training steps as the size of the training dataloader may
443 | # have changed.
444 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
445 | if overrode_max_train_steps:
446 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
447 | # Afterwards we recalculate our number of training epochs
448 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
449 |
450 |     # Figure out after how many steps we should save the Accelerator states
451 | checkpointing_steps = args.checkpointing_steps
452 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
453 | checkpointing_steps = int(checkpointing_steps)
454 |
455 | # We need to initialize the trackers we use, and also store our configuration.
456 |     # The trackers initialize automatically on the main process.
457 | if args.with_tracking:
458 | experiment_config = vars(args)
459 | # TensorBoard cannot log Enums, need the raw value
460 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
461 | accelerator.init_trackers("tb_logs", experiment_config)
462 |
463 | # Train!
464 | total_batch_size = (
465 | args.per_device_train_batch_size
466 | * accelerator.num_processes
467 | * args.gradient_accumulation_steps
468 | )
469 |
470 | logger.info("***** Running training *****")
471 | logger.info(f" Num examples = {len(train_dataset)}")
472 | logger.info(f" Num Epochs = {args.num_train_epochs}")
473 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
474 | logger.info(
475 | f" Total train batch size (w. parallel, distributed & accumulation) = "
476 | f"{total_batch_size}"
477 | )
478 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
479 | logger.info(f" Total optimization steps = {args.max_train_steps}")
480 |
481 | # Only show the progress bar once on each machine.
482 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
483 | completed_steps = 0
484 | starting_epoch = 0
485 |
486 | # Potentially load in the weights and states from a previous save
487 | if args.resume_from_checkpoint:
488 |         if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
489 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
490 | accelerator.load_state(args.resume_from_checkpoint)
491 | path = os.path.basename(args.resume_from_checkpoint)
492 | else:
493 | # Get the most recent checkpoint
494 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
495 | dirs.sort(key=os.path.getctime)
496 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
497 | # Extract `epoch_{i}` or `step_{i}`
498 | training_difference = os.path.splitext(path)[0]
499 |
500 | if "epoch" in training_difference:
501 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
502 | resume_step = None
503 | else:
504 |             # need to multiply by `gradient_accumulation_steps` to reflect real steps
505 | resume_step = (
506 | int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
507 | )
508 | starting_epoch = resume_step // len(train_dataloader)
509 | resume_step -= starting_epoch * len(train_dataloader)
510 |
511 | # update the progress_bar if load from checkpoint
512 | progress_bar.update(starting_epoch * num_update_steps_per_epoch)
513 | completed_steps = starting_epoch * num_update_steps_per_epoch
514 |
515 | # attach hooks for activation stats (if needed)
516 | if args.with_tracking:
517 | act_dict = attach_tb_act_hooks(model)
518 |
519 | # store the value of the FFN magnitude (second to last layer)
520 | num_layers = len(model.bert.encoder.layer)
521 | ffn_inf_norm = None
522 |
523 | # ** Training loop **
524 | for epoch in range(starting_epoch, args.num_train_epochs):
525 | model.train()
526 | if args.with_tracking:
527 | total_loss = 0
528 |
529 | for step, batch in enumerate(train_dataloader):
530 | # We need to skip steps until we reach the resumed step
531 | if args.resume_from_checkpoint and epoch == starting_epoch:
532 | if resume_step is not None and step < resume_step:
533 | if step % args.gradient_accumulation_steps == 0:
534 | progress_bar.update(1)
535 | completed_steps += 1
536 | continue
537 |
538 | with accelerator.accumulate(model):
539 | outputs = model(**batch)
540 | loss = outputs.loss
541 |
542 | # We keep track of the loss at each epoch
543 | if args.with_tracking:
544 | total_loss += loss.detach().float()
545 | accelerator.backward(loss)
546 |
547 | # grad clipping
548 | if (
549 | args.max_grad_norm is not None
550 | and args.max_grad_norm > 0
551 | and accelerator.sync_gradients
552 | ):
553 | accelerator.clip_grad_norm_(
554 | model.parameters(),
555 | max_norm=args.max_grad_norm,
556 | norm_type=args.grad_norm_type,
557 | )
558 |
559 | optimizer.step()
560 |
561 | if not accelerator.optimizer_step_was_skipped:
562 | # do not update LR if the grad update was skipped (because of overflow in grad
563 |                     # computation caused by mixed-precision)
564 | lr_scheduler.step()
565 |
566 | optimizer.zero_grad()
567 |
568 | # Checks if the accelerator has performed an optimization step behind the scenes
569 | if accelerator.sync_gradients:
570 | completed_steps += 1
571 |
572 | tqdm_update_interval = args.tqdm_update_interval
573 | if completed_steps % tqdm_update_interval == 0:
574 | progress_bar.update(tqdm_update_interval)
575 |
576 | if isinstance(checkpointing_steps, int):
577 | if completed_steps % checkpointing_steps == 0:
578 | output_dir = f"step_{completed_steps}"
579 | if args.output_dir is not None:
580 | output_dir = os.path.join(args.output_dir, output_dir)
581 | accelerator.save_state(output_dir)
582 |
583 | # TB log scalars
584 | if args.with_tracking and completed_steps % args.tb_scalar_log_interval == 0:
585 | # weights inf-norm
586 | for name, module in model.named_modules():
587 | if hasattr(module, "weight"):
588 | w = module.weight
589 | w_inf_norm = max(w.max().item(), -w.min().item())
590 | accelerator.log(
591 | {f"{name}.weight_inf_norm": w_inf_norm}, step=completed_steps
592 | )
593 |
594 | # act inf norm
595 | for name, x in act_dict.items():
596 | x_inf_norm = max(x.max().item(), -x.min().item())
597 | accelerator.log({f"{name}.act_inf_norm": x_inf_norm}, step=completed_steps)
598 |
599 | # gate probs (if present)
600 | for layer_idx in range(len(model.bert.encoder.layer)):
601 | self_attn_layer = model.bert.encoder.layer[layer_idx].attention.self
602 | if self_attn_layer.last_gate_avg_prob is not None:
603 | for head_idx in range(self_attn_layer.num_attention_heads):
604 | gate_prob = self_attn_layer.last_gate_avg_prob[head_idx].item()
605 | accelerator.log(
606 | {f"layer{layer_idx}.head{head_idx}.avg_prob": gate_prob},
607 | step=completed_steps,
608 | )
609 |
610 | # TB log histograms
611 | if (
612 | args.with_tracking
613 | and accelerator.is_main_process
614 | and completed_steps % args.tb_hist_log_interval == 0
615 | ):
616 | tb_writer = accelerator.trackers[0].writer
617 |
618 | # weight histograms
619 | for name, module in model.named_modules():
620 | if hasattr(module, "weight"):
621 | w = module.weight
622 | try:
623 | with warnings.catch_warnings():
624 | warnings.filterwarnings("ignore", category=DeprecationWarning)
625 | tb_writer.add_histogram(
626 | f"{name}.weight_hist", w, global_step=completed_steps
627 | )
628 |                         except Exception:
629 |                             logger.warning(
630 | f"Could not log weight histogram for {name} at step {completed_steps}"
631 | )
632 |
633 | # act histograms
634 | for name, x in act_dict.items():
635 | try:
636 | with warnings.catch_warnings():
637 | warnings.filterwarnings("ignore", category=DeprecationWarning)
638 | tb_writer.add_histogram(
639 | f"{name}.act_hist", x, global_step=completed_steps
640 | )
641 |                     except Exception:
642 |                         logger.warning(
643 | f"Could not log act histogram for {name} at step {completed_steps}"
644 | )
645 |
646 | # gate probs (if present)
647 | for layer_idx in range(len(model.bert.encoder.layer)):
648 | self_attn_layer = model.bert.encoder.layer[layer_idx].attention.self
649 | if self_attn_layer.last_gate_all_probs is not None:
650 | for head_idx in range(self_attn_layer.num_attention_heads):
651 | gate_prob_head = self_attn_layer.last_gate_all_probs[:, head_idx, ...]
652 | try:
653 | with warnings.catch_warnings():
654 | warnings.filterwarnings("ignore", category=DeprecationWarning)
655 | tb_writer.add_histogram(
656 | f"layer{layer_idx}.head{head_idx}.probs",
657 | gate_prob_head,
658 | global_step=completed_steps,
659 | )
660 |                             except Exception:
661 |                                 logger.warning(
662 |                                     f"Could not log gate prob histogram for layer {layer_idx}, head {head_idx} at step {completed_steps}"
663 |                                 )
664 |
665 | if completed_steps >= args.max_train_steps:
666 | break
667 |
668 | # ** Evaluation **
669 | model.eval()
670 | losses = []
671 | for step, batch in enumerate(eval_dataloader):
672 | with torch.no_grad():
673 | outputs = model(**batch)
674 |
675 | loss = outputs.loss
676 | loss_ = accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size))
677 | losses.append(loss_)
678 |
679 | losses = torch.cat(losses)
680 | try:
681 | eval_loss = torch.mean(losses)
682 | perplexity = math.exp(eval_loss)
683 | except OverflowError:
684 | perplexity = float("inf")
685 |
686 | logger.info(f"epoch {epoch}: perplexity: {perplexity}")
687 |
688 | if args.with_tracking:
689 | accelerator.log(
690 | {
691 | "perplexity": perplexity,
692 | "eval_loss": eval_loss,
693 | "train_loss": total_loss.item() / len(train_dataloader),
694 | "epoch": epoch,
695 | "step": completed_steps,
696 | },
697 | step=completed_steps,
698 | )
699 |
700 | if args.checkpointing_steps == "epoch":
701 | output_dir = f"epoch_{epoch}"
702 | if args.output_dir is not None:
703 | output_dir = os.path.join(args.output_dir, output_dir)
704 | accelerator.save_state(output_dir)
705 |
706 | if args.with_tracking:
707 | accelerator.end_training()
708 |
709 | if args.output_dir is not None:
710 | accelerator.wait_for_everyone()
711 | unwrapped_model = accelerator.unwrap_model(model)
712 | unwrapped_model.save_pretrained(
713 | args.output_dir,
714 | is_main_process=accelerator.is_main_process,
715 | save_function=accelerator.save,
716 | )
717 | if accelerator.is_main_process:
718 | tokenizer.save_pretrained(args.output_dir)
719 |
720 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
721 | json.dump({"perplexity": perplexity, "ffn_inf_norm": ffn_inf_norm}, f)
722 |
723 |
724 | if __name__ == "__main__":
725 | main()
726 |
--------------------------------------------------------------------------------
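
For reference, the attention swap in run_mlm.py (replacing each `BertSelfAttention` with `BertSelfAttentionWithExtras` and restoring weights with `load_state_dict(strict=False)`) follows a generic pattern: build a drop-in replacement module, copy the overlapping pretrained weights, and leave any newly introduced parameters at their initial values. Below is a minimal sketch of that pattern, assuming a stock BERT checkpoint; `GatedSelfAttention` and its `gate_alpha` parameter are hypothetical placeholders, not the classes defined in transformers_language.models.bert_attention.

```python
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForMaskedLM
from transformers.models.bert.modeling_bert import BertSelfAttention


class GatedSelfAttention(BertSelfAttention):
    """Drop-in BertSelfAttention variant with one extra, newly initialized parameter."""

    def __init__(self, config):
        super().__init__(config)
        # hypothetical per-head gate, absent from any pretrained checkpoint
        self.gate_alpha = nn.Parameter(torch.zeros(config.num_attention_heads))


def swap_self_attention(model, config):
    for layer in model.bert.encoder.layer:
        old_self = layer.attention.self
        new_self = GatedSelfAttention(config)
        # strict=False copies query/key/value weights and skips the missing `gate_alpha`
        missing, unexpected = new_self.load_state_dict(old_self.state_dict(), strict=False)
        assert not unexpected, unexpected
        layer.attention.self = new_self
    return model


if __name__ == "__main__":
    config = AutoConfig.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased", config=config)
    model = swap_self_attention(model, config)
    print(model.bert.encoder.layer[0].attention.self)
```

The same idea carries over to the OPT scripts, where the decoder's self-attention modules are replaced instead; only the module paths and the replacement class differ.
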