├── quantization ├── __init__.py ├── quantizers │ ├── __init__.py │ ├── quantizer_utils.py │ ├── base_quantizers.py │ └── uniform_quantizers.py ├── qstates.py ├── utils.py ├── quantization_manager.py ├── hijacker.py ├── base_quantized_model.py ├── base_quantized_classes.py ├── autoquant_utils.py └── range_estimators.py ├── transformers_language ├── __init__.py ├── models │ ├── __init__.py │ ├── softmax.py │ ├── opt_attention.py │ └── bert_attention.py ├── dataset_setups.py ├── quant_configs.py ├── utils.py └── args.py ├── img └── bert_attention_patterns.png ├── model_configs ├── opt-1.3b.yaml ├── opt-12L12H.yaml ├── opt-350m.yaml ├── opt-6L12H.yaml └── bert-6L12H.yaml ├── docker ├── requirements.txt └── Dockerfile ├── .gitignore ├── accelerate_configs ├── 1gpu_fp16.yaml └── 1gpu_no_mp.yaml ├── scripts ├── opt_1.3b_vanilla.sh ├── opt_350m_vanilla.sh ├── opt_350m_gated_attention.sh ├── opt_1.3b_gated_attention.sh ├── bert_base_vanilla.sh ├── bert_base_clipped_softmax.sh ├── opt_125m_clipped_softmax.sh ├── bert_base_gated_attention.sh ├── opt_125m_vanilla.sh └── opt_125m_gated_attention.sh ├── setup.py ├── LICENSE ├── README.md ├── validate_mlm.py ├── validate_clm.py └── run_mlm.py /quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | -------------------------------------------------------------------------------- /transformers_language/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | -------------------------------------------------------------------------------- /transformers_language/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | -------------------------------------------------------------------------------- /img/bert_attention_patterns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qualcomm-AI-research/outlier-free-transformers/HEAD/img/bert_attention_patterns.png -------------------------------------------------------------------------------- /model_configs/opt-1.3b.yaml: -------------------------------------------------------------------------------- 1 | max_position_embeddings: 1024 2 | ffn_dim: 8192 3 | hidden_size: 2048 4 | num_hidden_layers: 24 5 | num_attention_heads: 32 6 | init_std: 0.006 7 | dropout: 0.1 8 | -------------------------------------------------------------------------------- /model_configs/opt-12L12H.yaml: -------------------------------------------------------------------------------- 1 | max_position_embeddings: 512 2 | ffn_dim: 3072 3 | hidden_size: 768 4 | num_hidden_layers: 12 5 | num_attention_heads: 12 6 | init_std: 0.006 7 | dropout: 0.1 8 | -------------------------------------------------------------------------------- /model_configs/opt-350m.yaml: -------------------------------------------------------------------------------- 1 | max_position_embeddings: 1024 2 | ffn_dim: 4096 3 | hidden_size: 1024 4 | num_hidden_layers: 24 5 | num_attention_heads: 16 6 | init_std: 0.006 7 | dropout: 0.1 8 | -------------------------------------------------------------------------------- /model_configs/opt-6L12H.yaml: -------------------------------------------------------------------------------- 1 | max_position_embeddings: 512 2 | ffn_dim: 3072 3 | hidden_size: 768 4 | num_hidden_layers: 6 5 | num_attention_heads: 12 6 | init_std: 0.006 7 | dropout: 0.1 8 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate~=0.21.0 2 | datasets~=2.14.0 3 | numpy~=1.23.5 4 | pyyaml 5 | scipy~=1.10.1 6 | tensorboard~=2.14.0 7 | timm~=0.4.12 8 | tqdm 9 | transformers~=4.31.0 10 | -------------------------------------------------------------------------------- /model_configs/bert-6L12H.yaml: -------------------------------------------------------------------------------- 1 | max_position_embeddings: 256 2 | num_hidden_layers: 6 3 | num_attention_heads: 12 4 | hidden_size: 768 5 | intermediate_size: 3072 6 | gradient_checkpointing: False 7 | 8 | -------------------------------------------------------------------------------- /transformers_language/dataset_setups.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | from enum import auto 4 | 5 | from quantization.utils import BaseEnumOptions 6 | 7 | 8 | class DatasetSetups(BaseEnumOptions): 9 | wikitext_2 = auto() 10 | wikitext_103 = auto() 11 | bookcorpus_and_wiki = auto() 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm 2 | .idea 3 | 4 | # MAC OS 5 | .DS_Store 6 | 7 | # pytest 8 | .coverage 9 | .pytest 10 | .pytest_cache 11 | 12 | # Python 13 | *__pycache__* 14 | *.py[cod] 15 | *.cpython-38.pyc 16 | *.egg-info 17 | 18 | # Patches (not to be committed by mistake) 19 | *.patch 20 | 21 | # Notebooks 22 | *.ipynb 23 | 24 | # Plots 25 | *.pdf 26 | *.png 27 | *.jpg 28 | 29 | # Hidden files 30 | .* 31 | !.gitignore 32 | 33 | -------------------------------------------------------------------------------- /quantization/quantizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | from quantization.quantizers.uniform_quantizers import ( 4 | AsymmetricUniformQuantizer, 5 | SymmetricUniformQuantizer, 6 | ) 7 | from quantization.utils import ClassEnumOptions, MethodMap 8 | 9 | 10 | class QMethods(ClassEnumOptions): 11 | symmetric_uniform = MethodMap(SymmetricUniformQuantizer) 12 | asymmetric_uniform = MethodMap(AsymmetricUniformQuantizer) 13 | -------------------------------------------------------------------------------- /accelerate_configs/1gpu_fp16.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: {} 5 | distributed_type: 'NO' 6 | downcast_bf16: 'no' 7 | fsdp_config: {} 8 | gpu_ids: all 9 | machine_rank: 0 10 | main_process_ip: null 11 | main_process_port: null 12 | main_training_function: main 13 | megatron_lm_config: {} 14 | mixed_precision: fp16 15 | num_machines: 1 16 | num_processes: 1 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_name: null 20 | tpu_zone: null 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /accelerate_configs/1gpu_no_mp.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: {} 5 | distributed_type: 'NO' 6 | downcast_bf16: 'no' 7 | fsdp_config: {} 8 | gpu_ids: all 9 | machine_rank: 0 10 | main_process_ip: null 11 | main_process_port: null 12 | main_training_function: main 13 | megatron_lm_config: {} 14 | mixed_precision: 'no' 15 | num_machines: 1 16 | num_processes: 1 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_name: null 20 | tpu_zone: null 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /quantization/qstates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | from enum import auto 4 | 5 | from quantization.utils import BaseEnumOptions 6 | 7 | 8 | class Qstates(BaseEnumOptions): 9 | estimate_ranges = auto() # ranges are updated in eval and train mode 10 | fix_ranges = auto() # quantization ranges are fixed for train and eval 11 | learn_ranges = auto() # quantization params are nn.Parameters 12 | estimate_ranges_train = ( 13 | auto() 14 | ) # quantization ranges are updated during train and fixed for eval 15 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:11.3.1-base-ubuntu20.04 2 | 3 | WORKDIR /app 4 | ENV LC_ALL=C.UTF-8 5 | ENV LANG=C.UTF-8 6 | 7 | RUN apt update && \ 8 | DEBIAN_FRONTEND=noninteractive apt install --yes --no-install-recommends \ 9 | git \ 10 | less \ 11 | python-is-python3 \ 12 | python3 \ 13 | python3-pip \ 14 | tree \ 15 | vim \ 16 | && \ 17 | rm -rf /var/lib/apt/lists/* 18 | 19 | COPY docker/requirements.txt requirements.txt 20 | 21 | RUN python3 -m pip install --no-cache-dir --upgrade pip && \ 22 | python3 -m pip install --no-cache-dir --upgrade --requirement requirements.txt 23 | 24 | COPY . /app 25 | RUN python3 -m pip install --no-cache-dir /app 26 | -------------------------------------------------------------------------------- /scripts/opt_1.3b_vanilla.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \ 3 | --pad_to_max_length \ 4 | --with_tracking \ 5 | --report_to tensorboard \ 6 | --seed 1000 \ 7 | --dataset_setup bookcorpus_and_wiki \ 8 | --preprocessing_num_workers 4 \ 9 | --data_cache_dir ~/.hf_data \ 10 | --model_cache_dir ~/.hf_cache \ 11 | --model_type opt \ 12 | --tokenizer_name facebook/opt-350m \ 13 | --max_seq_length 2048 \ 14 | --block_size 512 \ 15 | --learning_rate 0.0004 \ 16 | --lr_scheduler_type linear \ 17 | --max_train_steps 100000 \ 18 | --num_warmup_steps 2000 \ 19 | --per_device_train_batch_size 14 \ 20 | --per_device_eval_batch_size 14 \ 21 | --gradient_accumulation_steps 18 \ 22 | --max_grad_norm 1.0 \ 23 | --weight_decay 0.1 \ 24 | --checkpointing_steps 1000 \ 25 | --config_path model_configs/opt-1.3b.yaml \ 26 | --attn_softmax vanilla \ 27 | --output_dir output 28 | -------------------------------------------------------------------------------- /scripts/opt_350m_vanilla.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \ 3 | --pad_to_max_length \ 4 | --with_tracking \ 5 | --report_to tensorboard \ 6 | --seed 1000 \ 7 | --dataset_setup bookcorpus_and_wiki \ 8 | --preprocessing_num_workers 4 \ 9 | --data_cache_dir ~/.hf_data \ 10 | --model_cache_dir ~/.hf_cache \ 11 | --model_type opt \ 12 | --tokenizer_name facebook/opt-350m \ 13 | --max_seq_length 2048 \ 14 | --block_size 512 \ 15 | --learning_rate 0.0004 \ 16 | --lr_scheduler_type linear \ 17 | --max_train_steps 100000 \ 18 | --num_warmup_steps 2000 \ 19 | --per_device_train_batch_size 24 \ 20 | --per_device_eval_batch_size 24 \ 21 | --gradient_accumulation_steps 11 \ 22 | --max_grad_norm 1.0 \ 23 | --weight_decay 0.1 \ 24 | --checkpointing_steps 1000 \ 25 | --config_path model_configs/opt-350m.yaml \ 26 | --attn_softmax vanilla \ 27 | --output_dir output 28 | 
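
For orientation, the effective batch per optimizer step implied by the two vanilla OPT scripts above follows directly from `per_device_train_batch_size`, `gradient_accumulation_steps`, and `block_size`; a quick plain-Python check:

```python
# Effective batch size per optimizer step for the vanilla OPT scripts above
# (values taken directly from opt_1.3b_vanilla.sh and opt_350m_vanilla.sh).
configs = {
    "opt-350m": dict(per_device=24, grad_accum=11, block=512),
    "opt-1.3b": dict(per_device=14, grad_accum=18, block=512),
}
for name, c in configs.items():
    seqs = c["per_device"] * c["grad_accum"]
    print(f"{name}: {seqs} sequences x {c['block']} tokens = {seqs * c['block']} tokens/step")
# opt-350m: 264 sequences x 512 tokens = 135168 tokens/step
# opt-1.3b: 252 sequences x 512 tokens = 129024 tokens/step
```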
-------------------------------------------------------------------------------- /scripts/opt_350m_gated_attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \ 3 | --pad_to_max_length \ 4 | --with_tracking \ 5 | --report_to tensorboard \ 6 | --seed 1000 \ 7 | --dataset_setup bookcorpus_and_wiki \ 8 | --preprocessing_num_workers 4 \ 9 | --data_cache_dir ~/.hf_data \ 10 | --model_cache_dir ~/.hf_cache \ 11 | --model_type opt \ 12 | --tokenizer_name facebook/opt-350m \ 13 | --max_seq_length 2048 \ 14 | --block_size 512 \ 15 | --learning_rate 0.0004 \ 16 | --lr_scheduler_type linear \ 17 | --max_train_steps 100000 \ 18 | --num_warmup_steps 2000 \ 19 | --per_device_train_batch_size 24 \ 20 | --per_device_eval_batch_size 24 \ 21 | --gradient_accumulation_steps 11 \ 22 | --max_grad_norm 1.0 \ 23 | --weight_decay 0.1 \ 24 | --checkpointing_steps 1000 \ 25 | --config_path model_configs/opt-350m.yaml \ 26 | --attn_gate_type conditional_per_token \ 27 | --attn_gate_init 0.25 \ 28 | --output_dir output 29 | -------------------------------------------------------------------------------- /scripts/opt_1.3b_gated_attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \ 3 | --pad_to_max_length \ 4 | --with_tracking \ 5 | --report_to tensorboard \ 6 | --seed 1000 \ 7 | --dataset_setup bookcorpus_and_wiki \ 8 | --preprocessing_num_workers 4 \ 9 | --data_cache_dir ~/.hf_data \ 10 | --model_cache_dir ~/.hf_cache \ 11 | --model_type opt \ 12 | --tokenizer_name facebook/opt-350m \ 13 | --max_seq_length 2048 \ 14 | --block_size 512 \ 15 | --learning_rate 0.0004 \ 16 | --lr_scheduler_type linear \ 17 | --max_train_steps 100000 \ 18 | --num_warmup_steps 2000 \ 19 | --per_device_train_batch_size 14 \ 20 | --per_device_eval_batch_size 14 \ 21 | --gradient_accumulation_steps 18 \ 22 | --max_grad_norm 1.0 \ 23 | --weight_decay 0.1 \ 24 | --checkpointing_steps 1000 \ 25 | --config_path model_configs/opt-1.3b.yaml \ 26 | --attn_gate_type conditional_per_token \ 27 | --attn_gate_linear_all_features \ 28 | --output_dir output 29 | -------------------------------------------------------------------------------- /scripts/bert_base_vanilla.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_mlm.py \ 3 | --with_tracking \ 4 | --report_to tensorboard \ 5 | --extra_tb_stats \ 6 | --seed 1000 \ 7 | --dataset_setup bookcorpus_and_wiki \ 8 | --preprocessing_num_workers 4 \ 9 | --data_cache_dir ~/.hf_data \ 10 | --model_cache_dir ~/.hf_cache \ 11 | --model_type bert \ 12 | --tokenizer_name bert-base-uncased \ 13 | --max_seq_length 128 \ 14 | --mlm_probability 0.15 \ 15 | --learning_rate 0.0001 \ 16 | --lr_scheduler_type linear \ 17 | --max_train_steps 1000000 \ 18 | --num_warmup_steps 10000 \ 19 | --per_device_train_batch_size 256 \ 20 | --per_device_eval_batch_size 256 \ 21 | --gradient_accumulation_steps 1 \ 22 | --max_grad_norm 1.0 \ 23 | --weight_decay 0.01 \ 24 | --config_name bert-base-uncased \ 25 | --checkpointing_steps 100000 \ 26 | --tb_scalar_log_interval 2000 \ 27 | --tb_hist_log_interval 100000 \ 28 | --attn_softmax vanilla \ 29 | --output_dir output 30 | 
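
The gated-attention scripts above enable the method with `--attn_gate_type conditional_per_token` (plus `--attn_gate_init 0.25` or `--attn_gate_linear_all_features`). The actual gate is implemented in `transformers_language/models/opt_attention.py` and `bert_attention.py`, which are not included in this excerpt; the sketch below is only a conceptual illustration of per-token output gating, and the class name `TokenGate` and the mapping from `gate_init` to a bias value are assumptions rather than the repository's API:

```python
# Conceptual sketch only (names and init scheme are assumptions, not repo code):
# a per-token sigmoid gate scales the attention output, so a head that wants to
# perform a "no-op" can emit a gate near zero instead of pushing the softmax
# inputs to extreme values.
import math

import torch
from torch import nn


class TokenGate(nn.Module):
    def __init__(self, hidden_size: int, gate_init: float = 0.25):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 1)
        nn.init.zeros_(self.proj.weight)
        # pick the bias so that sigmoid(bias) == gate_init at initialization
        nn.init.constant_(self.proj.bias, math.log(gate_init / (1.0 - gate_init)))

    def forward(self, hidden_states: torch.Tensor, attn_output: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.proj(hidden_states))  # (batch, seq_len, 1)
        return gate * attn_output
```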
-------------------------------------------------------------------------------- /scripts/bert_base_clipped_softmax.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_mlm.py \ 3 | --with_tracking \ 4 | --report_to tensorboard \ 5 | --extra_tb_stats \ 6 | --seed 1000 \ 7 | --dataset_setup bookcorpus_and_wiki \ 8 | --preprocessing_num_workers 4 \ 9 | --data_cache_dir ~/.hf_data \ 10 | --model_cache_dir ~/.hf_cache \ 11 | --model_type bert \ 12 | --tokenizer_name bert-base-uncased \ 13 | --max_seq_length 128 \ 14 | --mlm_probability 0.15 \ 15 | --learning_rate 0.0001 \ 16 | --lr_scheduler_type linear \ 17 | --max_train_steps 1000000 \ 18 | --num_warmup_steps 10000 \ 19 | --per_device_train_batch_size 256 \ 20 | --per_device_eval_batch_size 256 \ 21 | --gradient_accumulation_steps 1 \ 22 | --max_grad_norm 1.0 \ 23 | --weight_decay 0.01 \ 24 | --config_name bert-base-uncased \ 25 | --checkpointing_steps 100000 \ 26 | --tb_scalar_log_interval 2000 \ 27 | --tb_hist_log_interval 100000 \ 28 | --attn_softmax "clipped(-.025:1)" \ 29 | --output_dir output 30 | -------------------------------------------------------------------------------- /scripts/opt_125m_clipped_softmax.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \ 3 | --pad_to_max_length \ 4 | --wd_LN_gamma \ 5 | --with_tracking \ 6 | --report_to tensorboard \ 7 | --extra_tb_stats \ 8 | --seed 1000 \ 9 | --dataset_setup bookcorpus_and_wiki \ 10 | --preprocessing_num_workers 4 \ 11 | --data_cache_dir ~/.hf_data \ 12 | --model_cache_dir ~/.hf_cache \ 13 | --model_type opt \ 14 | --tokenizer_name facebook/opt-350m \ 15 | --max_seq_length 2048 \ 16 | --block_size 512 \ 17 | --learning_rate 0.0004 \ 18 | --lr_scheduler_type linear \ 19 | --max_train_steps 125000 \ 20 | --num_warmup_steps 2000 \ 21 | --per_device_train_batch_size 48 \ 22 | --per_device_eval_batch_size 48 \ 23 | --gradient_accumulation_steps 4 \ 24 | --max_grad_norm 1.0 \ 25 | --weight_decay 0.1 \ 26 | --checkpointing_steps 10000 \ 27 | --tb_scalar_log_interval 2000 \ 28 | --tb_hist_log_interval 10000 \ 29 | --config_path model_configs/opt-12L12H.yaml \ 30 | --alpha 12 \ 31 | --output_dir output 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from pathlib import Path 3 | 4 | scripts = ["run_mlm", "run_clm", "validate_mlm", "validate_clm"] 5 | bash_scripts = Path("scripts").glob("*.sh") 6 | 7 | setup( 8 | name="outlier_free_transformers", 9 | version="1.0.0", 10 | packages=[ 11 | "quantization", 12 | "quantization.quantizers", 13 | "transformers_language", 14 | "transformers_language.models", 15 | ], 16 | py_modules=scripts, 17 | scripts=[str(path) for path in bash_scripts], 18 | entry_points={"console_scripts": [f"{script} = {script}:main" for script in scripts]}, 19 | url="https://github.com/Qualcomm-AI-research/outlier-free-transformers", 20 | license="BSD 3-Clause Clear License", 21 | author="Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort", 22 | author_email="{ybond, markusn, tijmen}@qti.qualcomm.com", 23 | description='Code for "Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing"', 24 | ) 25 | 
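
For reference, the `--attn_softmax "clipped(-.025:1)"` argument used in `scripts/bert_base_clipped_softmax.sh` above maps, via `SOFTMAX_MAPPING` in `transformers_language/models/softmax.py` (shown further below), to a softmax whose output is stretched to [gamma, eta] = [-0.025, 1.0] and clipped back to [0, 1]. A minimal numeric sketch of the effect, assuming only a PyTorch install:

```python
# Minimal sketch of the stretched-and-clipped softmax selected by the
# clipped-softmax scripts: small attention probabilities become exactly zero.
import torch


def clipped_softmax(data, dim=-1, eta=1.0, gamma=-0.025):
    sm_out = torch.nn.functional.softmax(data, dim=dim)
    stretched_out = sm_out * (eta - gamma) + gamma
    return torch.clip(stretched_out, 0, 1)


logits = torch.tensor([[6.0, 0.0, -4.0]])
print(torch.nn.functional.softmax(logits, dim=-1))  # ~[0.9975, 0.0025, 4.5e-05], never exactly 0
print(clipped_softmax(logits, dim=-1))              # ~[0.9974, 0.0000, 0.0000], exact zeros
```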
-------------------------------------------------------------------------------- /scripts/bert_base_gated_attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_mlm.py \ 3 | --with_tracking \ 4 | --report_to tensorboard \ 5 | --extra_tb_stats \ 6 | --seed 1000 \ 7 | --dataset_setup bookcorpus_and_wiki \ 8 | --preprocessing_num_workers 4 \ 9 | --data_cache_dir ~/.hf_data \ 10 | --model_cache_dir ~/.hf_cache \ 11 | --model_type bert \ 12 | --tokenizer_name bert-base-uncased \ 13 | --max_seq_length 128 \ 14 | --mlm_probability 0.15 \ 15 | --learning_rate 0.0001 \ 16 | --lr_scheduler_type linear \ 17 | --max_train_steps 1000000 \ 18 | --num_warmup_steps 10000 \ 19 | --per_device_train_batch_size 256 \ 20 | --per_device_eval_batch_size 256 \ 21 | --gradient_accumulation_steps 1 \ 22 | --max_grad_norm 1.0 \ 23 | --weight_decay 0.01 \ 24 | --config_name bert-base-uncased \ 25 | --checkpointing_steps 100000 \ 26 | --tb_scalar_log_interval 2000 \ 27 | --tb_hist_log_interval 100000 \ 28 | --attn_gate_type conditional_per_token \ 29 | --attn_gate_mlp \ 30 | --output_dir output 31 | -------------------------------------------------------------------------------- /scripts/opt_125m_vanilla.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \ 3 | --pad_to_max_length \ 4 | --wd_LN_gamma \ 5 | --with_tracking \ 6 | --report_to tensorboard \ 7 | --extra_tb_stats \ 8 | --seed 1000 \ 9 | --dataset_setup bookcorpus_and_wiki \ 10 | --preprocessing_num_workers 4 \ 11 | --data_cache_dir ~/.hf_data \ 12 | --model_cache_dir ~/.hf_cache \ 13 | --model_type opt \ 14 | --tokenizer_name facebook/opt-350m \ 15 | --max_seq_length 2048 \ 16 | --block_size 512 \ 17 | --learning_rate 0.0004 \ 18 | --lr_scheduler_type linear \ 19 | --max_train_steps 125000 \ 20 | --num_warmup_steps 2000 \ 21 | --per_device_train_batch_size 48 \ 22 | --per_device_eval_batch_size 48 \ 23 | --gradient_accumulation_steps 4 \ 24 | --max_grad_norm 1.0 \ 25 | --weight_decay 0.1 \ 26 | --checkpointing_steps 10000 \ 27 | --tb_scalar_log_interval 2000 \ 28 | --tb_hist_log_interval 10000 \ 29 | --config_path model_configs/opt-12L12H.yaml \ 30 | --attn_softmax vanilla \ 31 | --output_dir output 32 | -------------------------------------------------------------------------------- /quantization/quantizers/quantizer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | import torch 4 | from torch.autograd import Function 5 | 6 | 7 | class RoundStraightThrough(Function): 8 | @staticmethod 9 | def forward(ctx, x): 10 | return torch.round(x) 11 | 12 | @staticmethod 13 | def backward(ctx, output_grad): 14 | return output_grad 15 | 16 | 17 | class ScaleGradient(Function): 18 | @staticmethod 19 | def forward(ctx, x, scale): 20 | ctx.scale = scale 21 | return x 22 | 23 | @staticmethod 24 | def backward(ctx, output_grad): 25 | return output_grad * ctx.scale, None 26 | 27 | 28 | round_ste_func = RoundStraightThrough.apply 29 | scale_grad_func = ScaleGradient.apply 30 | 31 | 32 | class QuantizerNotInitializedError(Exception): 33 | """Raised when a quantizer has not initialized""" 34 | 35 | def __init__(self): 36 | super(QuantizerNotInitializedError, self).__init__( 37 | "Quantizer has not been initialized yet" 38 | ) 39 | -------------------------------------------------------------------------------- /scripts/opt_125m_gated_attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | accelerate launch --config_file accelerate_configs/1gpu_fp16.yaml run_clm.py \ 3 | --pad_to_max_length \ 4 | --wd_LN_gamma \ 5 | --with_tracking \ 6 | --report_to tensorboard \ 7 | --extra_tb_stats \ 8 | --seed 1000 \ 9 | --dataset_setup bookcorpus_and_wiki \ 10 | --preprocessing_num_workers 4 \ 11 | --data_cache_dir ~/.hf_data \ 12 | --model_cache_dir ~/.hf_cache \ 13 | --model_type opt \ 14 | --tokenizer_name facebook/opt-350m \ 15 | --max_seq_length 2048 \ 16 | --block_size 512 \ 17 | --learning_rate 0.0004 \ 18 | --lr_scheduler_type linear \ 19 | --max_train_steps 125000 \ 20 | --num_warmup_steps 2000 \ 21 | --per_device_train_batch_size 48 \ 22 | --per_device_eval_batch_size 48 \ 23 | --gradient_accumulation_steps 4 \ 24 | --max_grad_norm 1.0 \ 25 | --weight_decay 0.1 \ 26 | --checkpointing_steps 10000 \ 27 | --tb_scalar_log_interval 2000 \ 28 | --tb_hist_log_interval 10000 \ 29 | --config_path model_configs/opt-12L12H.yaml \ 30 | --attn_gate_type conditional_per_token \ 31 | --attn_gate_init 0.25 \ 32 | --output_dir output 33 | -------------------------------------------------------------------------------- /transformers_language/quant_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | from quantization.quantizers import QMethods 4 | from quantization.range_estimators import RangeEstimators 5 | from transformers_language.utils import DotDict 6 | 7 | 8 | def get_quant_config(): 9 | config = DotDict() 10 | config.act_quant = DotDict( 11 | { 12 | "cross_entropy_layer": None, 13 | "num_batches": 16, 14 | "options": {}, 15 | "quant_method": RangeEstimators.running_minmax, 16 | "std_dev": None, 17 | } 18 | ) 19 | config.quant = DotDict( 20 | { 21 | "act_quant": True, 22 | "n_bits": 8, 23 | "n_bits_act": 8, 24 | "num_candidates": None, 25 | "per_channel": False, 26 | "percentile": None, 27 | "quant_setup": "all", 28 | "qmethod": QMethods.symmetric_uniform, 29 | "qmethod_act": QMethods.asymmetric_uniform, 30 | "weight_quant": True, 31 | "weight_quant_method": RangeEstimators.current_minmax, 32 | } 33 | ) 34 | return config 35 | -------------------------------------------------------------------------------- /quantization/quantizers/base_quantizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 
2 | # All Rights Reserved. 3 | from torch import nn 4 | 5 | 6 | class QuantizerBase(nn.Module): 7 | def __init__(self, n_bits, *args, per_channel=False, act_quant=False, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self.n_bits = n_bits 10 | self.act_quant = act_quant 11 | self.per_channel = per_channel 12 | self.state = None 13 | self.x_min_fp32 = self.x_max_fp32 = None 14 | 15 | @property 16 | def is_initialized(self): 17 | raise NotImplementedError() 18 | 19 | @property 20 | def x_max(self): 21 | raise NotImplementedError() 22 | 23 | @property 24 | def symmetric(self): 25 | raise NotImplementedError() 26 | 27 | @property 28 | def x_min(self): 29 | raise NotImplementedError() 30 | 31 | def forward(self, x_float): 32 | raise NotImplementedError() 33 | 34 | def _adjust_params_per_channel(self, x): 35 | raise NotImplementedError() 36 | 37 | def set_quant_range(self, x_min, x_max): 38 | raise NotImplementedError() 39 | 40 | def extra_repr(self): 41 | return "n_bits={}, per_channel={}, is_initalized={}".format( 42 | self.n_bits, self.per_channel, self.is_initialized 43 | ) 44 | 45 | def reset(self): 46 | self._delta = None 47 | 48 | def fix_ranges(self): 49 | raise NotImplementedError() 50 | 51 | def make_range_trainable(self): 52 | raise NotImplementedError() 53 | -------------------------------------------------------------------------------- /quantization/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | from collections import namedtuple 4 | from enum import Flag, auto 5 | from functools import partial 6 | 7 | import numpy as np 8 | 9 | 10 | def to_numpy(tensor): 11 | """ 12 | Helper function that turns the given tensor into a numpy array 13 | 14 | Parameters 15 | ---------- 16 | tensor : torch.Tensor 17 | 18 | Returns 19 | ------- 20 | tensor : float or np.array 21 | 22 | """ 23 | if isinstance(tensor, np.ndarray): 24 | return tensor 25 | if hasattr(tensor, "is_cuda"): 26 | if tensor.is_cuda: 27 | return tensor.cpu().detach().numpy() 28 | if hasattr(tensor, "detach"): 29 | return tensor.detach().numpy() 30 | if hasattr(tensor, "numpy"): 31 | return tensor.numpy() 32 | 33 | return np.array(tensor) 34 | 35 | 36 | class BaseEnumOptions(Flag): 37 | def __str__(self): 38 | return self.name 39 | 40 | @classmethod 41 | def list_names(cls): 42 | return [m.name for m in cls] 43 | 44 | 45 | class ClassEnumOptions(BaseEnumOptions): 46 | @property 47 | def cls(self): 48 | return self.value.cls 49 | 50 | def __call__(self, *args, **kwargs): 51 | return self.value.cls(*args, **kwargs) 52 | 53 | 54 | class StopForwardException(Exception): 55 | """Used to throw and catch an exception to stop traversing the graph.""" 56 | 57 | pass 58 | 59 | 60 | MethodMap = partial(namedtuple("MethodMap", ["value", "cls"]), auto()) 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted 6 | (subject to the limitations in the disclaimer below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this list of conditions 9 | and the following disclaimer: 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions 12 | and the following disclaimer in the documentation and/or other materials provided with the 13 | istribution. 14 | 15 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to 16 | endorse or promote products derived from this software without specific prior written permission. 17 | 18 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS 19 | SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED 20 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF 26 | THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /transformers_language/models/softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | from functools import partial 4 | 5 | import torch 6 | 7 | 8 | def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw): 9 | sm_out = torch.nn.functional.softmax(data, dim=dim, **kw) 10 | stretched_out = sm_out * (eta - gamma) + gamma 11 | return torch.clip(stretched_out, 0, 1) 12 | 13 | 14 | SOFTMAX_MAPPING = { 15 | "vanilla": torch.nn.functional.softmax, 16 | # Clipped softmax 17 | "clipped(0:1.0003)": partial(clipped_softmax, gamma=0, eta=1.0003), 18 | "clipped(0:1.001)": partial(clipped_softmax, gamma=0, eta=1.001), 19 | "clipped(0:1.002)": partial(clipped_softmax, gamma=0, eta=1.002), 20 | "clipped(0:1.003)": partial(clipped_softmax, gamma=0, eta=1.003), 21 | "clipped(0:1.004)": partial(clipped_softmax, gamma=0, eta=1.004), 22 | "clipped(0:1.01)": partial(clipped_softmax, gamma=0, eta=1.01), 23 | "clipped(0:1.02)": partial(clipped_softmax, gamma=0, eta=1.02), 24 | "clipped(0:1.03)": partial(clipped_softmax, gamma=0, eta=1.03), 25 | "clipped(0:1.1)": partial(clipped_softmax, gamma=0, eta=1.1), 26 | "clipped(-.1:1)": partial(clipped_softmax, gamma=-0.1, eta=1.0), 27 | "clipped(-.00001:1)": partial(clipped_softmax, gamma=-0.00001, eta=1.0), 28 | "clipped(-.00003:1)": partial(clipped_softmax, gamma=-0.00003, eta=1.0), 29 | "clipped(-.0001:1)": partial(clipped_softmax, gamma=-0.0001, eta=1.0), 30 | "clipped(-.0003:1)": partial(clipped_softmax, gamma=-0.0003, eta=1.0), 31 | "clipped(-.0005:1)": partial(clipped_softmax, gamma=-0.0005, eta=1.0), 32 | "clipped(-.001:1)": partial(clipped_softmax, gamma=-0.001, eta=1.0), 33 | "clipped(-.002:1)": partial(clipped_softmax, gamma=-0.002, eta=1.0), 34 | "clipped(-.0025:1)": partial(clipped_softmax, gamma=-0.0025, eta=1.0), 35 | "clipped(-.003:1)": partial(clipped_softmax, gamma=-0.003, eta=1.0), 36 | "clipped(-.004:1)": partial(clipped_softmax, gamma=-0.004, eta=1.0), 37 | "clipped(-.005:1)": partial(clipped_softmax, gamma=-0.005, eta=1.0), 38 | "clipped(-.01:1)": partial(clipped_softmax, gamma=-0.01, eta=1.0), 39 | "clipped(-.015:1)": partial(clipped_softmax, gamma=-0.015, eta=1.0), 40 | "clipped(-.02:1)": partial(clipped_softmax, gamma=-0.02, eta=1.0), 41 | "clipped(-.025:1)": partial(clipped_softmax, gamma=-0.025, eta=1.0), 42 | "clipped(-.03:1)": partial(clipped_softmax, gamma=-0.03, eta=1.0), 43 | "clipped(-.04:1)": partial(clipped_softmax, gamma=-0.04, eta=1.0), 44 | "clipped(-.001:1.001)": partial(clipped_softmax, gamma=-0.001, eta=1.001), 45 | "clipped(-.002:1.002)": partial(clipped_softmax, gamma=-0.002, eta=1.002), 46 | "clipped(-.003:1.003)": partial(clipped_softmax, gamma=-0.003, eta=1.003), 47 | "clipped(-.005:1.005)": partial(clipped_softmax, gamma=-0.003, eta=1.005), 48 | "clipped(-.01:1.01)": partial(clipped_softmax, gamma=-0.01, eta=1.01), 49 | "clipped(-.03:1.03)": partial(clipped_softmax, gamma=-0.03, eta=1.03), 50 | "clipped(-.1:1.1)": partial(clipped_softmax, gamma=-0.1, eta=1.1), 51 | } 52 | -------------------------------------------------------------------------------- /transformers_language/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | import torch 4 | import torch.nn as nn 5 | 6 | from quantization.range_estimators import RangeEstimators 7 | from quantization.utils import StopForwardException 8 | 9 | 10 | def kurtosis(x, eps=1e-6): 11 | """x - (B, d)""" 12 | mu = x.mean(dim=1, keepdims=True) 13 | s = x.std(dim=1) 14 | mu4 = ((x - mu) ** 4.0).mean(dim=1) 15 | k = mu4 / (s**4.0 + eps) 16 | return k 17 | 18 | 19 | def count_params(module): 20 | return len(nn.utils.parameters_to_vector(module.parameters())) 21 | 22 | 23 | def val_qparams(config): 24 | weight_range_options = {} 25 | if config.quant.weight_quant_method == RangeEstimators.MSE: 26 | weight_range_options = dict(opt_method=config.quant.weight_opt_method) 27 | if config.quant.num_candidates is not None: 28 | weight_range_options["num_candidates"] = config.quant.num_candidates 29 | 30 | params = { 31 | "method": config.quant.qmethod.cls, 32 | "n_bits": config.quant.n_bits, 33 | "n_bits_act": config.quant.n_bits_act, 34 | "act_method": config.quant.qmethod_act.cls, 35 | "per_channel_weights": config.quant.per_channel, 36 | "percentile": config.quant.percentile, 37 | "quant_setup": config.quant.quant_setup, 38 | "weight_range_method": config.quant.weight_quant_method.cls, 39 | "weight_range_options": weight_range_options, 40 | "act_range_method": config.act_quant.quant_method.cls, 41 | "act_range_options": config.act_quant.options, 42 | } 43 | return params 44 | 45 | 46 | def pass_data_for_range_estimation(loader, model, act_quant=None, max_num_batches=20, inp_idx=0): 47 | model.eval() 48 | batches = [] 49 | device = next(model.parameters()).device 50 | with torch.no_grad(): 51 | for i, data in enumerate(loader): 52 | try: 53 | if isinstance(data, (tuple, list)): 54 | x = data[inp_idx].to(device=device) 55 | batches.append(x.data.cpu().numpy()) 56 | model(x) 57 | print(f"proccesed step={i}") 58 | else: 59 | x = {k: v.to(device=device) for k, v in data.items()} 60 | model(**x) 61 | print(f"proccesed step={i}") 62 | 63 | if i >= max_num_batches - 1 or not act_quant: 64 | break 65 | except StopForwardException: 66 | pass 67 | return batches 68 | 69 | 70 | class DotDict(dict): 71 | """ 72 | This class enables access to its attributes as both ['attr'] and .attr . 73 | Its advantage is that content of its `instance` can be accessed with `.` 74 | and still passed to functions as `**instance` (as dictionaries) for 75 | implementing variable-length arguments. 76 | """ 77 | 78 | def __setattr__(self, key, value): 79 | self.__setitem__(key, value) 80 | 81 | def __delattr__(self, key): 82 | self.__delitem__(key) 83 | 84 | def __getattr__(self, key): 85 | if key in self: 86 | return self.__getitem__(key) 87 | raise AttributeError(f"DotDict instance has no key '{key}' ({self.keys()})") 88 | -------------------------------------------------------------------------------- /quantization/quantization_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | from torch import nn 4 | 5 | from quantization.qstates import Qstates 6 | from quantization.quantizers import QMethods 7 | from quantization.quantizers.base_quantizers import QuantizerBase 8 | from quantization.quantizers.quantizer_utils import QuantizerNotInitializedError 9 | from quantization.range_estimators import RangeEstimatorBase, RangeEstimators 10 | 11 | 12 | class QuantizationManager(nn.Module): 13 | """Implementation of Quantization and Quantization Range Estimation 14 | 15 | Parameters 16 | ---------- 17 | n_bits: int 18 | Number of bits for the quantization. 19 | qmethod: QMethods member (Enum) 20 | The quantization scheme to use, e.g. symmetric_uniform, asymmetric_uniform, 21 | qmn_uniform etc. 22 | init: RangeEstimators member (Enum) 23 | Initialization method for the grid from 24 | per_channel: bool 25 | If true, will use a separate quantization grid for each kernel/channel. 26 | x_min: float or PyTorch Tensor 27 | The minimum value which needs to be represented. 28 | x_max: float or PyTorch Tensor 29 | The maximum value which needs to be represented. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | qmethod: QuantizerBase = QMethods.symmetric_uniform.cls, 35 | init: RangeEstimatorBase = RangeEstimators.current_minmax.cls, 36 | per_channel=False, 37 | x_min=None, 38 | x_max=None, 39 | qparams=None, 40 | init_params=None, 41 | ): 42 | super().__init__() 43 | self.state = Qstates.estimate_ranges 44 | self.qmethod = qmethod 45 | self.init = init 46 | self.per_channel = per_channel 47 | self.qparams = qparams if qparams else {} 48 | self.init_params = init_params if init_params else {} 49 | self.range_estimator = None 50 | 51 | # define quantizer 52 | self.quantizer = self.qmethod(per_channel=self.per_channel, **qparams) 53 | self.quantizer.state = self.state 54 | 55 | # define range estimation method for quantizer initialisation 56 | if x_min is not None and x_max is not None: 57 | self.set_quant_range(x_min, x_max) 58 | self.fix_ranges() 59 | else: 60 | # set up the collector function to set the ranges 61 | self.range_estimator = self.init( 62 | per_channel=self.per_channel, quantizer=self.quantizer, **self.init_params 63 | ) 64 | 65 | @property 66 | def n_bits(self): 67 | return self.quantizer.n_bits 68 | 69 | def estimate_ranges(self): 70 | self.state = Qstates.estimate_ranges 71 | self.quantizer.state = self.state 72 | 73 | def fix_ranges(self): 74 | if self.quantizer.is_initialized: 75 | self.state = Qstates.fix_ranges 76 | self.quantizer.state = self.state 77 | self.quantizer.fix_ranges() 78 | else: 79 | raise QuantizerNotInitializedError() 80 | 81 | def learn_ranges(self): 82 | self.quantizer.make_range_trainable() 83 | self.state = Qstates.learn_ranges 84 | self.quantizer.state = self.state 85 | 86 | def estimate_ranges_train(self): 87 | self.state = Qstates.estimate_ranges_train 88 | self.quantizer.state = self.state 89 | 90 | def reset_ranges(self): 91 | self.range_estimator.reset() 92 | self.quantizer.reset() 93 | self.estimate_ranges() 94 | 95 | def forward(self, x): 96 | if self.state == Qstates.estimate_ranges or ( 97 | self.state == Qstates.estimate_ranges_train and self.training 98 | ): 99 | # Note this can be per tensor or per channel 100 | cur_xmin, cur_xmax = self.range_estimator(x) 101 | self.set_quant_range(cur_xmin, cur_xmax) 102 | 103 | return self.quantizer(x) 104 | 105 | def set_quant_range(self, x_min, x_max): 106 | self.quantizer.set_quant_range(x_min, x_max) 107 | 108 | def extra_repr(self): 109 | return "state={}".format(self.state.name) 110 | 
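
A minimal usage sketch for the `QuantizationManager` defined above; it is illustrative only and assumes the symmetric uniform quantizer and the current-min-max estimator accept the same constructor arguments they receive elsewhere in this repository (e.g., in `hijacker.py` and `base_quantized_classes.py`):

```python
# Illustrative sketch (not repository code): fake-quantize a tensor with an
# 8-bit symmetric uniform grid whose range is first estimated from the data.
import torch

from quantization.quantization_manager import QuantizationManager
from quantization.quantizers import QMethods
from quantization.range_estimators import RangeEstimators

qm = QuantizationManager(
    qmethod=QMethods.symmetric_uniform.cls,
    init=RangeEstimators.current_minmax.cls,
    qparams={"n_bits": 8, "scale_domain": "linear", "act_quant": False},
)

w = torch.randn(64, 64)
w_q = qm(w)      # state == Qstates.estimate_ranges: the grid is set from w's min/max
qm.fix_ranges()  # freeze the estimated range for subsequent passes
```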
-------------------------------------------------------------------------------- /quantization/hijacker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | import copy 4 | 5 | import torch 6 | from timm.models.layers import Swish 7 | from timm.models.layers.activations_me import SwishMe 8 | from torch import nn 9 | 10 | from quantization.base_quantized_classes import QuantizedModule 11 | from quantization.quantization_manager import QuantizationManager 12 | from quantization.range_estimators import RangeEstimators 13 | from quantization.utils import to_numpy 14 | 15 | activations_set = [ 16 | nn.ReLU, 17 | nn.ReLU6, 18 | nn.Hardtanh, 19 | nn.Sigmoid, 20 | nn.Tanh, 21 | nn.GELU, 22 | nn.PReLU, 23 | Swish, 24 | SwishMe, 25 | ] 26 | 27 | 28 | class QuantizationHijacker(QuantizedModule): 29 | """Mixin class that 'hijacks' the forward pass in a module to perform quantization and 30 | dequantization on the weights and output distributions. 31 | 32 | Usage: 33 | To make a quantized nn.Linear layer: 34 | class HijackedLinear(QuantizationHijacker, nn.Linear): 35 | pass 36 | 37 | It is vital that QSchemeForwardHijacker is the first parent class, and that the second parent 38 | class derives from nn.Module, otherwise it will not be reached by a super(., .) call. 39 | 40 | NB: this implementation (for now) assumes that there will always be some training involved, 41 | e.g. to estimate the activation ranges. 42 | """ 43 | 44 | def __init__(self, *args, activation: nn.Module = None, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | if activation: 47 | assert isinstance(activation, tuple(activations_set)), str(activation) 48 | 49 | self.activation_function = copy.deepcopy(activation) if activation else None 50 | 51 | self.activation_quantizer = QuantizationManager( 52 | qmethod=self.act_method, 53 | init=self.act_range_method, 54 | qparams=self.act_qparams, 55 | init_params=self.act_range_options, 56 | ) 57 | 58 | if self.weight_range_method == RangeEstimators.current_minmax: 59 | weight_init_params = dict(percentile=self.percentile) 60 | else: 61 | weight_init_params = self.weight_range_options 62 | 63 | self.weight_quantizer = QuantizationManager( 64 | qmethod=self.method, 65 | init=self.weight_range_method, 66 | per_channel=self.per_channel_weights, 67 | qparams=self.weight_qparams, 68 | init_params=weight_init_params, 69 | ) 70 | 71 | self.prune_manager = None 72 | if hasattr(self, "prune_method") and self.prune_method is not None: 73 | self.prune_manager = self.prune_method(self, **self.prune_kwargs) 74 | 75 | self.activation_save_target = None 76 | self.activation_save_name = None 77 | 78 | def forward(self, x, offsets=None): 79 | weight, bias = self.get_params() 80 | res = self.run_forward(x, weight, bias, offsets=offsets) 81 | res = self.quantize_activations(res) 82 | return res 83 | 84 | def get_params(self): 85 | if not self.training and self.cached_params: 86 | return self.cached_params 87 | 88 | weight, bias = self.get_weight_bias() 89 | 90 | if self.prune_manager is not None: 91 | weight, bias = self.prune_manager(weight, bias) 92 | 93 | if self._quant_w: 94 | weight = self.quantize_weights(weight) 95 | 96 | if self._caching and not self.training and self.cached_params is None: 97 | self.cached_params = ( 98 | torch.Tensor(to_numpy(weight)).to(weight.device), 99 | torch.Tensor(to_numpy(bias)).to(bias.device) if bias is not None else None, 100 | ) 101 | return weight, bias 
102 | 103 | def quantize_weights(self, weights): 104 | return self.weight_quantizer(weights) 105 | 106 | def get_weight_bias(self): 107 | bias = None 108 | if hasattr(self, "bias"): 109 | bias = self.bias 110 | return self.weight, bias 111 | 112 | def run_forward(self, x, weight, bias, offsets=None): 113 | # Performs the actual linear operation of the layer 114 | raise NotImplementedError() 115 | 116 | def quantize_activations(self, activations): 117 | """Quantize a single activation tensor or all activations from a layer. I'm assuming that 118 | we should quantize all outputs for a layer with the same quantization scheme. 119 | """ 120 | if self.activation_function is not None: 121 | activations = self.activation_function(activations) 122 | 123 | if self.activation_save_target is not None: 124 | self.activation_save_target[self.activation_save_name] = activations.data.cpu().numpy() 125 | 126 | if self._quant_a: 127 | activations = self.activation_quantizer(activations) 128 | 129 | if self.activation_save_target is not None: 130 | self.activation_save_target[ 131 | self.activation_save_name + "_Q" 132 | ] = activations.data.cpu().numpy() 133 | 134 | return activations 135 | -------------------------------------------------------------------------------- /quantization/base_quantized_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | import warnings 4 | from typing import Dict, Union 5 | 6 | import torch 7 | from torch import Tensor, nn 8 | 9 | from quantization.base_quantized_classes import ( 10 | QuantizedModule, 11 | _set_layer_estimate_ranges, 12 | _set_layer_estimate_ranges_train, 13 | _set_layer_fix_ranges, 14 | _set_layer_learn_ranges, 15 | ) 16 | from quantization.quantizers.base_quantizers import QuantizerBase 17 | 18 | 19 | class QuantizedModel(nn.Module): 20 | """ 21 | Parent class for a quantized model. This allows you to have convenience functions to put the 22 | whole model into quantization or full precision or to freeze BN. Otherwise it does not add any 23 | further functionality, so it is not a necessity that a quantized model uses this class. 24 | """ 25 | 26 | def __init__(self, input_size=(1, 3, 224, 224)): 27 | """ 28 | Parameters 29 | ---------- 30 | input_size: Tuple with the input dimension for the model (including batch dimension) 31 | """ 32 | super().__init__() 33 | self.input_size = input_size 34 | 35 | def load_state_dict( 36 | self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True 37 | ): 38 | """ 39 | This function overwrites the load_state_dict of nn.Module to ensure that quantization 40 | parameters are loaded correctly for quantized model. 41 | 42 | """ 43 | quant_state_dict = { 44 | k: v for k, v in state_dict.items() if k.endswith("_quant_a") or k.endswith("_quant_w") 45 | } 46 | 47 | if quant_state_dict: 48 | # Case 1: the quantization states are stored in the state_dict 49 | super().load_state_dict(quant_state_dict, strict=False) 50 | 51 | else: 52 | # Case 2 (older models): the quantization states are NOT stored in the state_dict but 53 | # only the scale factor _delta. 54 | warnings.warn( 55 | "Old state_dict without quantization state included. 
Checking for " "_delta instead" 56 | ) 57 | # Add quantization flags to the state_dict 58 | for name, module in self.named_modules(): 59 | if isinstance(module, QuantizedModule): 60 | state_dict[".".join((name, "_quant_a"))] = torch.BoolTensor([False]) 61 | state_dict[".".join((name, "_quant_w"))] = torch.BoolTensor([False]) 62 | if ( 63 | ".".join((name, "activation_quantizer", "quantizer", "_delta")) 64 | in state_dict.keys() 65 | ): 66 | module.quantized_acts() 67 | state_dict[".".join((name, "_quant_a"))] = torch.BoolTensor([True]) 68 | if ( 69 | ".".join((name, "weight_quantizer", "quantizer", "_delta")) 70 | in state_dict.keys() 71 | ): 72 | module.quantized_weights() 73 | state_dict[".".join((name, "_quant_w"))] = torch.BoolTensor([True]) 74 | 75 | # Pass dummy data through quantized model to ensure all quantization parameters are 76 | # initialized with the correct dimensions (None tensors will lead to issues in state dict loading) 77 | device = next(self.parameters()).device 78 | dummy_input = torch.rand(*self.input_size, device=device) 79 | with torch.no_grad(): 80 | self.forward(dummy_input) 81 | 82 | # Load state dict 83 | super().load_state_dict(state_dict, strict) 84 | 85 | def disable_caching(self): 86 | def _fn(layer): 87 | if isinstance(layer, QuantizedModule): 88 | layer.disable_caching() 89 | 90 | self.apply(_fn) 91 | 92 | def quantized_weights(self): 93 | def _fn(layer): 94 | if isinstance(layer, QuantizedModule): 95 | layer.quantized_weights() 96 | 97 | self.apply(_fn) 98 | 99 | def full_precision_weights(self): 100 | def _fn(layer): 101 | if isinstance(layer, QuantizedModule): 102 | layer.full_precision_weights() 103 | 104 | self.apply(_fn) 105 | 106 | def quantized_acts(self): 107 | def _fn(layer): 108 | if isinstance(layer, QuantizedModule): 109 | layer.quantized_acts() 110 | 111 | self.apply(_fn) 112 | 113 | def full_precision_acts(self): 114 | def _fn(layer): 115 | if isinstance(layer, QuantizedModule): 116 | layer.full_precision_acts() 117 | 118 | self.apply(_fn) 119 | 120 | def quantized(self): 121 | def _fn(layer): 122 | if isinstance(layer, QuantizedModule): 123 | layer.quantized() 124 | 125 | self.apply(_fn) 126 | 127 | def full_precision(self): 128 | def _fn(layer): 129 | if isinstance(layer, QuantizedModule): 130 | layer.full_precision() 131 | 132 | self.apply(_fn) 133 | 134 | # Methods for switching quantizer quantization states 135 | def learn_ranges(self): 136 | self.apply(_set_layer_learn_ranges) 137 | 138 | def fix_ranges(self): 139 | self.apply(_set_layer_fix_ranges) 140 | 141 | def estimate_ranges(self): 142 | self.apply(_set_layer_estimate_ranges) 143 | 144 | def estimate_ranges_train(self): 145 | self.apply(_set_layer_estimate_ranges_train) 146 | 147 | def set_quant_state(self, weight_quant, act_quant): 148 | if act_quant: 149 | self.quantized_acts() 150 | else: 151 | self.full_precision_acts() 152 | 153 | if weight_quant: 154 | self.quantized_weights() 155 | else: 156 | self.full_precision_weights() 157 | 158 | def grad_scaling(self, grad_scaling=True): 159 | def _fn(module): 160 | if isinstance(module, QuantizerBase): 161 | module.grad_scaling = grad_scaling 162 | 163 | self.apply(_fn) 164 | -------------------------------------------------------------------------------- /quantization/base_quantized_classes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | import torch 4 | from torch import nn 5 | 6 | from quantization.quantization_manager import QuantizationManager 7 | from quantization.quantizers.base_quantizers import QuantizerBase 8 | from quantization.quantizers.uniform_quantizers import AsymmetricUniformQuantizer 9 | from quantization.range_estimators import ( 10 | CurrentMinMaxEstimator, 11 | RangeEstimatorBase, 12 | RunningMinMaxEstimator, 13 | ) 14 | 15 | 16 | def _set_layer_learn_ranges(layer): 17 | if isinstance(layer, QuantizationManager): 18 | if layer.quantizer.is_initialized: 19 | layer.learn_ranges() 20 | 21 | 22 | def _set_layer_fix_ranges(layer): 23 | if isinstance(layer, QuantizationManager): 24 | if layer.quantizer.is_initialized: 25 | layer.fix_ranges() 26 | 27 | 28 | def _set_layer_estimate_ranges(layer): 29 | if isinstance(layer, QuantizationManager): 30 | layer.estimate_ranges() 31 | 32 | 33 | def _set_layer_estimate_ranges_train(layer): 34 | if isinstance(layer, QuantizationManager): 35 | if layer.quantizer.is_initialized: 36 | layer.estimate_ranges_train() 37 | 38 | 39 | class QuantizedModule(nn.Module): 40 | """ 41 | Parent class for a quantized module. It adds the basic functionality of switching the module 42 | between quantized and full precision mode. It also defines the cached parameters and handles 43 | the reset of the cache properly. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | *args, 49 | method: QuantizerBase = AsymmetricUniformQuantizer, 50 | act_method=None, 51 | weight_range_method: RangeEstimatorBase = CurrentMinMaxEstimator, 52 | act_range_method: RangeEstimatorBase = RunningMinMaxEstimator, 53 | n_bits=8, 54 | n_bits_act=None, 55 | per_channel_weights=False, 56 | percentile=None, 57 | weight_range_options=None, 58 | act_range_options=None, 59 | scale_domain="linear", 60 | bayesian_bits_kwargs=None, 61 | prune_method=None, 62 | prune_kwargs=None, 63 | **kwargs 64 | ): 65 | kwargs.pop("act_quant_dict", None) 66 | kwargs.pop("quant_dict", None) 67 | 68 | super().__init__(*args, **kwargs) 69 | 70 | self.method = method 71 | self.act_method = act_method or method 72 | self.n_bits = n_bits 73 | self.n_bits_act = n_bits_act or n_bits 74 | self.per_channel_weights = per_channel_weights 75 | self.percentile = percentile 76 | self.weight_range_method = weight_range_method 77 | self.weight_range_options = weight_range_options if weight_range_options else {} 78 | self.act_range_method = act_range_method 79 | self.act_range_options = act_range_options if act_range_options else {} 80 | self.scale_domain = scale_domain 81 | 82 | self.bayesian_bits_kwargs = bayesian_bits_kwargs or {} 83 | self.prune_method = prune_method 84 | self.prune_kwargs = prune_kwargs 85 | 86 | self.cached_params = None 87 | self._caching = True 88 | 89 | self.quant_params = None 90 | self.register_buffer("_quant_w", torch.BoolTensor([False])) 91 | self.register_buffer("_quant_a", torch.BoolTensor([False])) 92 | 93 | self.act_qparams = dict( 94 | n_bits=self.n_bits_act, 95 | scale_domain=self.scale_domain, 96 | act_quant=True, 97 | **self.bayesian_bits_kwargs 98 | ) 99 | self.weight_qparams = dict( 100 | n_bits=self.n_bits, 101 | scale_domain=self.scale_domain, 102 | act_quant=False, 103 | **self.bayesian_bits_kwargs 104 | ) 105 | 106 | @property 107 | def caching(self): 108 | return self._caching 109 | 110 | @caching.setter 111 | def caching(self, value: bool): 112 | self._caching = value 113 | if not value: 114 | self.cached_params = None 115 | 116 | def quantized_weights(self): 117 | self.cached_params = None 118 | self._quant_w = 
torch.BoolTensor([True]) 119 | 120 | def full_precision_weights(self): 121 | self.cached_params = None 122 | self._quant_w = torch.BoolTensor([False]) 123 | 124 | def quantized_acts(self): 125 | self._quant_a = torch.BoolTensor([True]) 126 | 127 | def full_precision_acts(self): 128 | self._quant_a = torch.BoolTensor([False]) 129 | 130 | def quantized(self): 131 | self.quantized_weights() 132 | self.quantized_acts() 133 | 134 | def full_precision(self): 135 | self.full_precision_weights() 136 | self.full_precision_acts() 137 | 138 | def get_quantizer_status(self): 139 | return dict(quant_a=self._quant_a.item(), quant_w=self._quant_w.item()) 140 | 141 | def set_quantizer_status(self, quantizer_status): 142 | if quantizer_status["quant_a"]: 143 | self.quantized_acts() 144 | else: 145 | self.full_precision_acts() 146 | 147 | if quantizer_status["quant_w"]: 148 | self.quantized_weights() 149 | else: 150 | self.full_precision_weights() 151 | 152 | def learn_ranges(self): 153 | self.apply(_set_layer_learn_ranges) 154 | 155 | def fix_ranges(self): 156 | self.apply(_set_layer_fix_ranges) 157 | 158 | def estimate_ranges(self): 159 | self.apply(_set_layer_estimate_ranges) 160 | 161 | def estimate_ranges_train(self): 162 | self.apply(_set_layer_estimate_ranges_train) 163 | 164 | def train(self, mode=True): 165 | super().train(mode) 166 | if mode: 167 | self.cached_params = None 168 | return self 169 | 170 | def _apply(self, *args, **kwargs): 171 | self.cached_params = None 172 | return super(QuantizedModule, self)._apply(*args, **kwargs) 173 | 174 | def extra_repr(self): 175 | quant_state = "weight_quant={}, act_quant={}".format( 176 | self._quant_w.item(), self._quant_a.item() 177 | ) 178 | parent_repr = super().extra_repr() 179 | return "{},\n{}".format(parent_repr, quant_state) if parent_repr else quant_state 180 | 181 | 182 | class QuantizedActivation(QuantizedModule): 183 | def __init__(self, *args, **kwargs): 184 | super().__init__(*args, **kwargs) 185 | self.activation_quantizer = QuantizationManager( 186 | qmethod=self.act_method, 187 | qparams=self.act_qparams, 188 | init=self.act_range_method, 189 | init_params=self.act_range_options, 190 | ) 191 | 192 | def quantize_activations(self, x): 193 | if self._quant_a: 194 | return self.activation_quantizer(x) 195 | else: 196 | return x 197 | 198 | def forward(self, x): 199 | return self.quantize_activations(x) 200 | 201 | 202 | class FP32Acts(nn.Module): 203 | def forward(self, x): 204 | return x 205 | 206 | def reset_ranges(self): 207 | pass 208 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Outlier-free Transformers 2 | This repository contains the implementation and LLM experiments for the paper presented in 3 | 4 | **Yelysei Bondarenko1, Markus Nagel1, Tijmen Blankevoort1, 5 | "Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing", NeurIPS 2023.** [[ArXiv]](https://arxiv.org/abs/2306.12929) 6 | 7 | 1 Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.) 
8 | 9 | 10 | ## Reference 11 | If you find our work useful, please cite 12 | ``` 13 | @article{bondarenko2023quantizable, 14 | title={Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing}, 15 | author={Bondarenko, Yelysei and Nagel, Markus and Blankevoort, Tijmen}, 16 | journal={arXiv preprint arXiv:2306.12929}, 17 | year={2023} 18 | } 19 | ``` 20 | 21 | ## Abstract 22 |

![BERT attention patterns](img/bert_attention_patterns.png) 23 | 24 |

25 | 26 | Many studies have shown that modern transformer models tend to learn strong outliers in their activations, 27 | making them difficult to quantize. We show that strong outliers are related to very specific behavior of attention 28 | heads that try to learn a _"no-op"_ or just a _partial update_ of the residual. To achieve the exact zeros needed in the 29 | attention matrix for a no-update, the input to the softmax is pushed to be larger and larger during training, causing 30 | outliers in other parts of the network. 31 | 32 | Based on these observations, we propose two simple (independent) modifications to the attention mechanism - 33 | **clipped softmax** and **gated attention**. We empirically show that models pre-trained using our methods learn 34 | significantly smaller outliers while maintaining and sometimes even improving the floating-point task performance. 35 | This enables us to quantize transformers to full INT8 quantization without any additional effort. 36 | 37 | 38 | ## Repository structure 39 | ```bash 40 | . 41 | ├── accelerate_configs # HuggingFace accelerate configs 42 | ├── docker # dockerfile and requirements files 43 | ├── img 44 | ├── model_configs # YAML configs for different model sizes 45 | ├── quantization # quantization tools and functionality 46 | │   └── quantizers 47 | ├── scripts # preset scripts for pre-training 48 | └── transformers_language # source code for quantized LLMs and our methods 49 | └── models 50 | └── run_clm.py # train a Causal Language Model (e.g., OPT) 51 | └── run_mlm.py # train a Masked Language Model (e.g., BERT) 52 | └── validate_clm.py # validate a Causal Language Model (e.g., OPT) 53 | └── validate_mlm.py # validate a Masked Language Model (e.g., BERT) 54 | ``` 55 | 56 | ## How to install 57 | ### using docker (recommended) 58 | You can build and run the docker container as follows 59 | ```bash 60 | docker build -f docker/Dockerfile --tag outlier_free_transformers:latest . 61 | docker run -ti outlier_free_transformers:latest 62 | ``` 63 | 64 | 65 | ### without docker 66 | Set locale variables and add the project root directory to your pythonpath: 67 | ```bash 68 | export LC_ALL=C.UTF-8 69 | export LANG=C.UTF-8 70 | export PYTHONPATH=${PYTHONPATH}:$(realpath "$PWD") 71 | ``` 72 | 73 | Make sure to have Python ≥3.6 (tested with Python 3.8.10) and 74 | ensure the latest version of `pip` (**tested** with 23.2.1): 75 | ```bash 76 | pip install --upgrade --no-deps pip 77 | ``` 78 | 79 | Next, install PyTorch 1.11 with the appropriate CUDA version (tested with CUDA 11.3, CuDNN 8.2.0): 80 | ```bash 81 | pip install torch==1.11.0 torchvision==0.12.0 82 | ``` 83 | 84 | Install the remaining dependencies using pip: 85 | ```bash 86 | pip install -r docker/requirements.txt 87 | ``` 88 | 89 | Finally, make all scripts executable 90 | ```bash 91 | chmod +x scripts/*.sh 92 | ``` 93 | 94 | ## Pre-training commands 95 | All the training scripts (batch size, etc.) are set up to fit on a single A100 80GB GPU. 
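For example, any of the presets listed in the table below can be launched directly from the repository root once the scripts have been made executable (see the install steps above); batch size and the other hyperparameters are set inside each script:
```bash
# Pre-train BERT-base with gated attention using the preset configuration
./scripts/bert_base_gated_attention.sh
```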
96 | 97 | | Model | Softmax | Script | 98 | |:----------|:----------------|:---------------------------------------------------------------------| 99 | | BERT-base | vanilla | [scripts/bert_base_vanilla.sh](scripts/bert_base_vanilla.sh) | 100 | | BERT-base | clipped softmax | [scripts/bert_base_clipped_softmax.sh](scripts/bert_base_clipped_softmax.sh) | 101 | | BERT-base | gated attention | [scripts/bert_base_gated_attention.sh](scripts/bert_base_gated_attention.sh) | 102 | | OPT-125m | vanilla | [scripts/opt_125m_vanilla.sh](scripts/opt_125m_vanilla.sh) | 103 | | OPT-125m | clipped softmax | [scripts/opt_125m_clipped_softmax.sh](scripts/opt_125m_clipped_softmax.sh) | 104 | | OPT-125m | gated attention | [scripts/opt_125m_gated_attention.sh](scripts/opt_125m_gated_attention.sh) | 105 | | OPT-350m | vanilla | [scripts/opt_350m_vanilla.sh](scripts/opt_350m_vanilla.sh) | 106 | | OPT-350m | gated attention | [scripts/opt_350m_gated_attention.sh](scripts/opt_350m_gated_attention.sh) | 107 | | OPT-1.3B | vanilla | [scripts/opt_1.3b_vanilla.sh](scripts/opt_1.3b_vanilla.sh) | 108 | | OPT-1.3B | gated attention | [scripts/opt_1.3b_gated_attention.sh](scripts/opt_1.3b_gated_attention.sh) | 109 | 110 | ## Validation commands 111 | After the model is trained, you can run evaluation (both floating point, and quantized) using 112 | the following commands. 113 | Make sure to pass the same softmax method arguments that were used for pre-training (e.g., `--attn_softmax vanilla`, `--attn_softmax "clipped(-.025:1)"`, `--alpha 12`, `--attn_gate_type conditional_per_token --attn_gate_mlp`, `--attn_gate_type conditional_per_token --attn_gate_init 0.25` etc.) 114 | 115 | ### FP validation for BERT models 116 | Run command: 117 | ```bash 118 | accelerate launch validate_mlm.py \ 119 | --seed 1000 \ 120 | --dataset_setup bookcorpus_and_wiki \ 121 | --preprocessing_num_workers 8 \ 122 | --model_type bert \ 123 | --max_seq_length 128 \ 124 | --mlm_probability 0.15 \ 125 | --per_device_eval_batch_size 32 \ 126 | --data_cache_dir ~/.hf_data \ 127 | --model_cache_dir ~/.hf_cache \ 128 | --model_name_or_path /path/to/saved/checkpoint \ 129 | --output_dir output_metrics 130 | ``` 131 | Expected (example) output: 132 | ``` 133 | INFO - validate_mlm - perplexity: 4.5438 134 | INFO - validate_mlm - max FFN output inf norm: 20.8 135 | INFO - validate_mlm - max FFN input + output inf norm: 47.0 136 | INFO - validate_mlm - max LN(FFN i + o) inf norm: 25.9 137 | INFO - validate_mlm - Avg Kurtosis: 72.19 138 | INFO - validate_mlm - Max Kurtosis: 197.6 139 | ``` 140 | 141 | 142 | ### INT8 validation for BERT models 143 | Run command: 144 | ```bash 145 | accelerate launch validate_mlm.py \ 146 | --quantize \ 147 | --est_num_batches 16 \ 148 | --seed 2000 \ 149 | --dataset_setup bookcorpus_and_wiki \ 150 | --preprocessing_num_workers 8 \ 151 | --model_type bert \ 152 | --max_seq_length 128 \ 153 | --mlm_probability 0.15 \ 154 | --per_device_eval_batch_size 32 \ 155 | --data_cache_dir ~/.hf_data \ 156 | --model_cache_dir ~/.hf_cache \ 157 | --model_name_or_path /path/to/saved/checkpoint \ 158 | --output_dir output_metrics 159 | ``` 160 | Expected (example) output: 161 | ``` 162 | INFO - validate_mlm - perplexity: 4.6550 163 | INFO - validate_mlm - max FFN output inf norm: 20.6 164 | INFO - validate_mlm - max FFN input + output inf norm: 47.0 165 | INFO - validate_mlm - max LN(FFN i + o) inf norm: 25.7 166 | INFO - validate_mlm - Avg Kurtosis: 70.32 167 | INFO - validate_mlm - Max Kurtosis: 188.4 168 | ``` 169 | 170 | ### FP 
validation for OPT models 171 | Run command: 172 | ```bash 173 | accelerate launch validate_clm.py \ 174 | --seed 1000 \ 175 | --dataset_setup bookcorpus_and_wiki \ 176 | --preprocessing_num_workers 4 \ 177 | --model_type opt \ 178 | --block_size 512 \ 179 | --per_device_eval_batch_size 4 \ 180 | --data_cache_dir ~/.hf_data \ 181 | --model_cache_dir ~/.hf_cache \ 182 | --model_name_or_path /path/to/saved/checkpoint \ 183 | --output_dir output_metrics 184 | ``` 185 | Expected (example) output: 186 | ```bash 187 | INFO - validate_clm - perplexity: 15.5449 188 | INFO - validate_clm - Max inf norm: 122.4 189 | INFO - validate_clm - Max FFN inf norm: 0.5 190 | INFO - validate_clm - Max layer inf norm: 9.1 191 | INFO - validate_clm - Avg Kurtosis: 18.26 192 | INFO - validate_clm - Max Kurtosis: 151.5 193 | INFO - validate_clm - Max Kurtosis layers: 151.5 194 | INFO - validate_clm - 195 | # (...) 196 | # detailed per-layer stats 197 | ``` 198 | 199 | ### INT8 validation for OPT models 200 | Run command: 201 | ```bash 202 | accelerate launch validate_clm.py \ 203 | --quantize \ 204 | --quant_setup fp32_head \ 205 | --ranges_acts running_minmax \ 206 | --qmethod_acts asymmetric_uniform \ 207 | --percentile 99.999 \ 208 | --est_num_batches 4 \ 209 | --seed 2000 \ 210 | --dataset_setup bookcorpus_and_wiki \ 211 | --preprocessing_num_workers 4 \ 212 | --model_type opt \ 213 | --block_size 512 \ 214 | --per_device_eval_batch_size 1 \ 215 | --data_cache_dir ~/.hf_data \ 216 | --model_cache_dir ~/.hf_cache \ 217 | --model_name_or_path /path/to/saved/checkpoint \ 218 | --output_dir output_metrics 219 | ``` 220 | Expected (example) output: 221 | ``` 222 | perplexity: 16.0132 223 | ``` 224 | -------------------------------------------------------------------------------- /quantization/autoquant_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | import copy 4 | import warnings 5 | 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AvgPoolNd 9 | 10 | from quantization.base_quantized_classes import ( 11 | FP32Acts, 12 | QuantizedActivation, 13 | QuantizedModule, 14 | ) 15 | from quantization.hijacker import QuantizationHijacker, activations_set 16 | from quantization.quantization_manager import QuantizationManager 17 | 18 | 19 | class QuantLinear(QuantizationHijacker, nn.Linear): 20 | def run_forward(self, x, weight, bias, offsets=None): 21 | return F.linear(x.contiguous(), weight.contiguous(), bias=bias) 22 | 23 | 24 | class QuantizedActivationWrapper(QuantizedActivation): 25 | """ 26 | Wraps over a layer and quantized the activation. 
27 | It also allow for tying the input and output quantizer which is helpful 28 | for layers such Average Pooling 29 | """ 30 | 31 | def __init__( 32 | self, 33 | layer, 34 | *args, 35 | tie_activation_quantizers=False, 36 | input_quantizer: QuantizationManager = None, 37 | **kwargs, 38 | ): 39 | super().__init__(*args, **kwargs) 40 | self.tie_activation_quantizers = tie_activation_quantizers 41 | if input_quantizer: 42 | assert isinstance(input_quantizer, QuantizationManager) 43 | self.activation_quantizer = input_quantizer 44 | self.layer = layer 45 | 46 | def quantize_activations_no_range_update(self, x): 47 | if self._quant_a: 48 | return self.activation_quantizer.quantizer(x) 49 | else: 50 | return x 51 | 52 | def forward(self, x): 53 | x = self.layer(x) 54 | if self.tie_activation_quantizers: 55 | # The input activation quantizer is used to quantize the activation 56 | # but without updating the quantization range 57 | return self.quantize_activations_no_range_update(x) 58 | else: 59 | return self.quantize_activations(x) 60 | 61 | def extra_repr(self): 62 | return f"tie_activation_quantizers={self.tie_activation_quantizers}" 63 | 64 | 65 | class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm): 66 | def run_forward(self, x, weight, bias, offsets=None): 67 | return F.layer_norm( 68 | input=x.contiguous(), 69 | normalized_shape=self.normalized_shape, 70 | weight=weight.contiguous(), 71 | bias=bias.contiguous(), 72 | eps=self.eps, 73 | ) 74 | 75 | 76 | class QuantEmbedding(QuantizationHijacker, nn.Embedding): 77 | def __init__(self, *args, activation=None, **kwargs): 78 | super().__init__(*args, activation=activation, **kwargs) 79 | # NB: We should not (re-)quantize activations of this module, as it is a 80 | # lookup table (=weights), which is already quantized 81 | self.activation_quantizer = FP32Acts() 82 | 83 | def run_forward(self, x, weight, bias, offsets=None): 84 | return F.embedding( 85 | input=x.contiguous(), 86 | weight=weight.contiguous(), 87 | padding_idx=self.padding_idx, 88 | max_norm=self.max_norm, 89 | norm_type=self.norm_type, 90 | scale_grad_by_freq=self.scale_grad_by_freq, 91 | sparse=self.sparse, 92 | ) 93 | 94 | 95 | # Modules Map 96 | module_map = {nn.Linear: QuantLinear, nn.LayerNorm: QuantLayerNorm, nn.Embedding: QuantEmbedding} 97 | 98 | 99 | non_param_modules = (_AdaptiveAvgPoolNd, _AvgPoolNd) 100 | 101 | 102 | def next_bn(module, i): 103 | return len(module) > i + 1 and isinstance(module[i + 1], (nn.BatchNorm2d, nn.BatchNorm1d)) 104 | 105 | 106 | def get_act(module, i): 107 | # Case 1: conv + act 108 | if len(module) - i > 1 and isinstance(module[i + 1], tuple(activations_set)): 109 | return module[i + 1], i + 1 110 | 111 | # Case 2: conv + bn + act 112 | if ( 113 | len(module) - i > 2 114 | and next_bn(module, i) 115 | and isinstance(module[i + 2], tuple(activations_set)) 116 | ): 117 | return module[i + 2], i + 2 118 | 119 | # Case 3: conv + bn + X -> return false 120 | # Case 4: conv + X -> return false 121 | return None, None 122 | 123 | 124 | def get_linear_args(module): 125 | args = dict( 126 | in_features=module.in_features, 127 | out_features=module.out_features, 128 | bias=module.bias is not None, 129 | ) 130 | return args 131 | 132 | 133 | def get_layernorm_args(module): 134 | args = dict(normalized_shape=module.normalized_shape, eps=module.eps) 135 | return args 136 | 137 | 138 | def get_embedding_args(module): 139 | args = dict( 140 | num_embeddings=module.num_embeddings, 141 | embedding_dim=module.embedding_dim, 142 | 
padding_idx=module.padding_idx, 143 | max_norm=module.max_norm, 144 | norm_type=module.norm_type, 145 | scale_grad_by_freq=module.scale_grad_by_freq, 146 | sparse=module.sparse, 147 | ) 148 | return args 149 | 150 | 151 | def get_module_args(mod, act): 152 | if isinstance(mod, nn.Linear): 153 | kwargs = get_linear_args(mod) 154 | elif isinstance(mod, nn.LayerNorm): 155 | kwargs = get_layernorm_args(mod) 156 | elif isinstance(mod, nn.Embedding): 157 | kwargs = get_embedding_args(mod) 158 | else: 159 | raise ValueError 160 | 161 | kwargs["activation"] = act 162 | 163 | return kwargs 164 | 165 | 166 | def quant_module(module, i, **quant_params): 167 | act, _ = get_act(module, i) 168 | modtype = module_map[type(module[i])] 169 | 170 | kwargs = get_module_args(module[i], act) 171 | new_module = modtype(**kwargs, **quant_params) 172 | new_module.weight.data = module[i].weight.data.clone() 173 | 174 | if module[i].bias is not None: 175 | new_module.bias.data = module[i].bias.data.clone() 176 | 177 | return new_module, i + int(bool(act)) + 1 178 | 179 | 180 | def quantize_sequential(model, specials=None, tie_activation_quantizers=False, **quant_params): 181 | specials = specials or dict() 182 | 183 | i = 0 184 | quant_modules = [] 185 | while i < len(model): 186 | if isinstance(model[i], QuantizedModule): 187 | quant_modules.append(model[i]) 188 | 189 | elif type(model[i]) in module_map: 190 | new_module, new_i = quant_module(model, i, **quant_params) 191 | quant_modules.append(new_module) 192 | i = new_i 193 | continue 194 | 195 | elif type(model[i]) in specials: 196 | quant_modules.append(specials[type(model[i])](model[i], **quant_params)) 197 | 198 | elif isinstance(model[i], non_param_modules): 199 | # Check for last quantizer 200 | input_quantizer = None 201 | if quant_modules and isinstance(quant_modules[-1], QuantizedModule): 202 | last_layer = quant_modules[-1] 203 | input_quantizer = quant_modules[-1].activation_quantizer 204 | elif ( 205 | quant_modules 206 | and isinstance(quant_modules[-1], nn.Sequential) 207 | and isinstance(quant_modules[-1][-1], QuantizedModule) 208 | ): 209 | last_layer = quant_modules[-1][-1] 210 | input_quantizer = quant_modules[-1][-1].activation_quantizer 211 | 212 | if input_quantizer and tie_activation_quantizers: 213 | # If input quantizer is found the tie input/output act quantizers 214 | print( 215 | f"Tying input quantizer {i-1}^th layer of type {type(last_layer)} to the " 216 | f"quantized {type(model[i])} following it" 217 | ) 218 | quant_modules.append( 219 | QuantizedActivationWrapper( 220 | model[i], 221 | tie_activation_quantizers=tie_activation_quantizers, 222 | input_quantizer=input_quantizer, 223 | **quant_params, 224 | ) 225 | ) 226 | else: 227 | # Input quantizer not found 228 | quant_modules.append(QuantizedActivationWrapper(model[i], **quant_params)) 229 | if tie_activation_quantizers: 230 | warnings.warn("Input quantizer not found, so we do not tie quantizers") 231 | else: 232 | quant_modules.append(quantize_model(model[i], specials=specials, **quant_params)) 233 | i += 1 234 | return nn.Sequential(*quant_modules) 235 | 236 | 237 | def quantize_model(model, specials=None, tie_activation_quantizers=False, **quant_params): 238 | specials = specials or dict() 239 | 240 | if isinstance(model, nn.Sequential): 241 | quant_model = quantize_sequential( 242 | model, specials, tie_activation_quantizers, **quant_params 243 | ) 244 | 245 | elif type(model) in specials: 246 | quant_model = specials[type(model)](model, **quant_params) 247 | 248 | elif 
isinstance(model, non_param_modules): 249 | quant_model = QuantizedActivationWrapper(model, **quant_params) 250 | 251 | elif type(model) in module_map: 252 | # If we do isinstance() then we might run into issues with modules that inherit from 253 | # one of these classes, for whatever reason 254 | modtype = module_map[type(model)] 255 | kwargs = get_module_args(model, None) 256 | quant_model = modtype(**kwargs, **quant_params) 257 | 258 | quant_model.weight.data = model.weight.data 259 | if getattr(model, "bias", None) is not None: 260 | quant_model.bias.data = model.bias.data 261 | 262 | else: 263 | # Unknown type, try to quantize all child modules 264 | quant_model = copy.deepcopy(model) 265 | for name, module in quant_model._modules.items(): 266 | new_model = quantize_model(module, specials=specials, **quant_params) 267 | if new_model is not None: 268 | setattr(quant_model, name, new_model) 269 | 270 | return quant_model 271 | -------------------------------------------------------------------------------- /quantization/quantizers/uniform_quantizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | import torch 4 | 5 | from quantization.quantizers.base_quantizers import QuantizerBase 6 | from quantization.quantizers.quantizer_utils import ( 7 | QuantizerNotInitializedError, 8 | round_ste_func, 9 | scale_grad_func, 10 | ) 11 | 12 | 13 | class AsymmetricUniformQuantizer(QuantizerBase): 14 | """ 15 | PyTorch Module that implements Asymmetric Uniform Quantization using STE. 16 | Quantizes its argument in the forward pass, passes the gradient 'straight 17 | through' on the backward pass, ignoring the quantization that occurred. 18 | 19 | Parameters 20 | ---------- 21 | n_bits: int 22 | Number of bits for quantization. 
23 | scale_domain: str ('log', 'linear) with default='linear' 24 | Domain of scale factor 25 | per_channel: bool 26 | If True: allows for per-channel quantization 27 | """ 28 | 29 | def __init__(self, n_bits, scale_domain="linear", grad_scaling=False, eps=1e-8, **kwargs): 30 | super().__init__(n_bits=n_bits, **kwargs) 31 | 32 | assert scale_domain in ("linear", "log") 33 | self.register_buffer("_delta", None) 34 | self.register_buffer("_zero_float", None) 35 | self.scale_domain = scale_domain 36 | self.grad_scaling = grad_scaling 37 | self.eps = eps 38 | 39 | # A few useful properties 40 | @property 41 | def delta(self): 42 | if self._delta is not None: 43 | return self._delta 44 | else: 45 | raise QuantizerNotInitializedError() 46 | 47 | @property 48 | def zero_float(self): 49 | if self._zero_float is not None: 50 | return self._zero_float 51 | else: 52 | raise QuantizerNotInitializedError() 53 | 54 | @property 55 | def is_initialized(self): 56 | return self._delta is not None 57 | 58 | @property 59 | def symmetric(self): 60 | return False 61 | 62 | @property 63 | def int_min(self): 64 | # integer grid minimum 65 | return 0.0 66 | 67 | @property 68 | def int_max(self): 69 | # integer grid maximum 70 | return 2.0**self.n_bits - 1 71 | 72 | @property 73 | def scale(self): 74 | if self.scale_domain == "linear": 75 | return torch.clamp(self.delta, min=self.eps) 76 | elif self.scale_domain == "log": 77 | return torch.exp(self.delta) 78 | 79 | @property 80 | def zero_point(self): 81 | zero_point = round_ste_func(self.zero_float) 82 | zero_point = torch.clamp(zero_point, self.int_min, self.int_max) 83 | return zero_point 84 | 85 | @property 86 | def x_max(self): 87 | return self.scale * (self.int_max - self.zero_point) 88 | 89 | @property 90 | def x_min(self): 91 | return self.scale * (self.int_min - self.zero_point) 92 | 93 | def to_integer_forward(self, x_float): 94 | """ 95 | Qunatized input to its integer represantion 96 | Parameters 97 | ---------- 98 | x_float: PyTorch Float Tensor 99 | Full-precision Tensor 100 | 101 | Returns 102 | ------- 103 | x_int: PyTorch Float Tensor of integers 104 | """ 105 | if self.grad_scaling: 106 | grad_scale = self.calculate_grad_scale(x_float) 107 | scale = scale_grad_func(self.scale, grad_scale) 108 | zero_point = ( 109 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale) 110 | ) 111 | else: 112 | scale = self.scale 113 | zero_point = self.zero_point 114 | 115 | x_int = round_ste_func(x_float / scale) + zero_point 116 | x_int = torch.clamp(x_int, self.int_min, self.int_max) 117 | 118 | return x_int 119 | 120 | def forward(self, x_float): 121 | """ 122 | Quantizes (quantized to integer and the scales back to original domain) 123 | Parameters 124 | ---------- 125 | x_float: PyTorch Float Tensor 126 | Full-precision Tensor 127 | 128 | Returns 129 | ------- 130 | x_quant: PyTorch Float Tensor 131 | Quantized-Dequantized Tensor 132 | """ 133 | if self.per_channel: 134 | self._adjust_params_per_channel(x_float) 135 | 136 | if self.grad_scaling: 137 | grad_scale = self.calculate_grad_scale(x_float) 138 | scale = scale_grad_func(self.scale, grad_scale) 139 | zero_point = ( 140 | self.zero_point if self.symmetric else scale_grad_func(self.zero_point, grad_scale) 141 | ) 142 | else: 143 | scale = self.scale 144 | zero_point = self.zero_point 145 | 146 | x_int = self.to_integer_forward(x_float) 147 | x_quant = scale * (x_int - zero_point) 148 | 149 | return x_quant 150 | 151 | def calculate_grad_scale(self, quant_tensor): 152 | 
num_pos_levels = self.int_max # Qp in LSQ paper 153 | num_elements = quant_tensor.numel() # nfeatures or nweights in LSQ paper 154 | if self.per_channel: 155 | # In the per tensor case we do not sum the gradients over the output channel dimension 156 | num_elements /= quant_tensor.shape[0] 157 | 158 | return (num_pos_levels * num_elements) ** -0.5 # 1 / sqrt (Qn * nfeatures) 159 | 160 | def _adjust_params_per_channel(self, x): 161 | """ 162 | Adjusts the quantization parameter tensors (delta, zero_float) 163 | to the input tensor shape if they don't match 164 | Parameters 165 | ---------- 166 | x: input tensor 167 | """ 168 | if x.ndim != self.delta.ndim: 169 | new_shape = [-1] + [1] * (len(x.shape) - 1) 170 | self._delta = self.delta.view(new_shape) 171 | if self._zero_float is not None: 172 | self._zero_float = self._zero_float.view(new_shape) 173 | 174 | def _tensorize_min_max(self, x_min, x_max): 175 | """ 176 | Converts provided min max range into tensors 177 | Parameters 178 | ---------- 179 | x_min: float or PyTorch 1D tensor 180 | x_max: float or PyTorch 1D tensor 181 | 182 | Returns 183 | ------- 184 | x_min: PyTorch Tensor 0 or 1-D 185 | x_max: PyTorch Tensor 0 or 1-D 186 | """ 187 | # Ensure a torch tensor 188 | if not torch.is_tensor(x_min): 189 | x_min = torch.tensor(x_min).float() 190 | x_max = torch.tensor(x_max).float() 191 | 192 | if x_min.dim() > 0 and len(x_min) > 1 and not self.per_channel: 193 | print(x_min) 194 | print(self.per_channel) 195 | raise ValueError( 196 | "x_min and x_max must be a float or 1-D Tensor" 197 | " for per-tensor quantization (per_channel=False)" 198 | ) 199 | # Ensure we always use zero and avoid division by zero 200 | x_min = torch.min(x_min, torch.zeros_like(x_min)) 201 | x_max = torch.max(x_max, torch.ones_like(x_max) * self.eps) 202 | 203 | return x_min, x_max 204 | 205 | def set_quant_range(self, x_min, x_max): 206 | """ 207 | Instantiates the quantization parameters based on the provided 208 | min and max range 209 | 210 | Parameters 211 | ---------- 212 | x_min: tensor or float 213 | Quantization range minimum limit 214 | x_max: tensor of float 215 | Quantization range minimum limit 216 | """ 217 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max 218 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 219 | self._delta = (x_max - x_min) / self.int_max 220 | self._zero_float = (-x_min / self.delta).detach() 221 | 222 | if self.scale_domain == "log": 223 | self._delta = torch.log(self.delta) 224 | 225 | self._delta = self._delta.detach() 226 | 227 | def make_range_trainable(self): 228 | # Converts trainable parameters to nn.Parameters 229 | if self.delta not in self.parameters(): 230 | self._delta = torch.nn.Parameter(self._delta) 231 | self._zero_float = torch.nn.Parameter(self._zero_float) 232 | 233 | def fix_ranges(self): 234 | # Removes trainable quantization params from nn.Parameters 235 | if self.delta in self.parameters(): 236 | _delta = self._delta.data 237 | _zero_float = self._zero_float.data 238 | del self._delta # delete the parameter 239 | del self._zero_float 240 | self.register_buffer("_delta", _delta) 241 | self.register_buffer("_zero_float", _zero_float) 242 | 243 | 244 | class SymmetricUniformQuantizer(AsymmetricUniformQuantizer): 245 | """ 246 | PyTorch Module that implements Symmetric Uniform Quantization using STE. 247 | Quantizes its argument in the forward pass, passes the gradient 'straight 248 | through' on the backward pass, ignoring the quantization that occurred. 
249 | 250 | Parameters 251 | ---------- 252 | n_bits: int 253 | Number of bits for quantization. 254 | scale_domain: str ('log', 'linear) with default='linear' 255 | Domain of scale factor 256 | per_channel: bool 257 | If True: allows for per-channel quantization 258 | """ 259 | 260 | def __init__(self, *args, **kwargs): 261 | super().__init__(*args, **kwargs) 262 | self.register_buffer("_signed", None) 263 | 264 | @property 265 | def signed(self): 266 | if self._signed is not None: 267 | return self._signed.item() 268 | else: 269 | raise QuantizerNotInitializedError() 270 | 271 | @property 272 | def symmetric(self): 273 | return True 274 | 275 | @property 276 | def int_min(self): 277 | return -(2.0 ** (self.n_bits - 1)) if self.signed else 0 278 | 279 | @property 280 | def int_max(self): 281 | pos_n_bits = self.n_bits - self.signed 282 | return 2.0**pos_n_bits - 1 283 | 284 | @property 285 | def zero_point(self): 286 | return 0.0 287 | 288 | def set_quant_range(self, x_min, x_max): 289 | self.x_min_fp32, self.x_max_fp32 = x_min, x_max 290 | x_min, x_max = self._tensorize_min_max(x_min, x_max) 291 | self._signed = x_min.min() < 0 292 | 293 | x_absmax = torch.max(x_min.abs(), x_max) 294 | self._delta = x_absmax / self.int_max 295 | 296 | if self.scale_domain == "log": 297 | self._delta = torch.log(self._delta) 298 | 299 | self._delta = self._delta.detach() 300 | 301 | def make_range_trainable(self): 302 | # Converts trainable parameters to nn.Parameters 303 | if self.delta not in self.parameters(): 304 | self._delta = torch.nn.Parameter(self._delta) 305 | 306 | def fix_ranges(self): 307 | # Removes trainable quantization params from nn.Parameters 308 | if self.delta in self.parameters(): 309 | _delta = self._delta.data 310 | del self._delta # delete the parameter 311 | self.register_buffer("_delta", _delta) 312 | -------------------------------------------------------------------------------- /transformers_language/args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | import argparse 4 | 5 | from transformers import MODEL_MAPPING, SchedulerType 6 | 7 | from transformers_language.dataset_setups import DatasetSetups 8 | from transformers_language.models.bert_attention import AttentionGateType 9 | from transformers_language.models.softmax import SOFTMAX_MAPPING 10 | 11 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 12 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser( 17 | description="Finetune/pre-train a transformers model on a " "MLM/CLM task" 18 | ) 19 | 20 | # *** Options from example script *** 21 | 22 | # 23 | ## Base 24 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 25 | 26 | # 27 | ## Data 28 | parser.add_argument( 29 | "--preprocessing_num_workers", 30 | type=int, 31 | default=None, 32 | help="The number of processes to use for the preprocessing.", 33 | ) 34 | parser.add_argument( 35 | "--overwrite_cache", 36 | action="store_true", 37 | help="Overwrite the cached training and evaluation sets", 38 | ) 39 | 40 | # 41 | ## Model & tokenizer 42 | parser.add_argument( 43 | "--model_type", 44 | type=str, 45 | default=None, 46 | help="Model type to use if training from scratch.", 47 | choices=MODEL_TYPES, 48 | ) 49 | parser.add_argument( 50 | "--model_name_or_path", 51 | type=str, 52 | help="Path to pretrained model or model identifier from huggingface.co/models.", 53 | required=False, 54 | ) 55 | parser.add_argument( 56 | "--config_name", 57 | type=str, 58 | default=None, 59 | help="Pretrained config name or path if not the same as model_name", 60 | ) 61 | parser.add_argument( 62 | "--tokenizer_name", 63 | type=str, 64 | default=None, 65 | help="Pretrained tokenizer name or path if not the same as model_name", 66 | ) 67 | parser.add_argument( 68 | "--use_slow_tokenizer", 69 | action="store_true", 70 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 71 | ) 72 | parser.add_argument( 73 | "--pad_to_max_length", 74 | action="store_true", 75 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", 76 | ) 77 | parser.add_argument( 78 | "--max_seq_length", 79 | type=int, 80 | default=None, 81 | help=( 82 | "The maximum total input sequence length after tokenization. Sequences longer than " 83 | "this will be truncated." 84 | ), 85 | ) 86 | parser.add_argument( 87 | "--block_size", 88 | type=int, 89 | default=None, 90 | help=( 91 | "Optional input sequence length after tokenization. The training dataset will be truncated in block of" 92 | " this size for training. Default to the model max input length for single sentence inputs (take into" 93 | " account special tokens)." 
94 | ), 95 | ) 96 | 97 | # 98 | ## Task 99 | parser.add_argument( 100 | "--mlm_probability", 101 | type=float, 102 | default=0.15, 103 | help="Ratio of tokens to mask for masked language modeling loss", 104 | ) 105 | 106 | # 107 | ## Training 108 | parser.add_argument( 109 | "--per_device_train_batch_size", 110 | type=int, 111 | default=8, 112 | help="Batch size (per device) for the training dataloader.", 113 | ) 114 | parser.add_argument( 115 | "--per_device_eval_batch_size", 116 | type=int, 117 | default=8, 118 | help="Batch size (per device) for the evaluation dataloader.", 119 | ) 120 | parser.add_argument( 121 | "--learning_rate", 122 | type=float, 123 | default=5e-5, 124 | help="Initial learning rate (after the potential warmup period) to use.", 125 | ) 126 | parser.add_argument( 127 | "--lr_scheduler_type", 128 | type=SchedulerType, 129 | default="linear", 130 | help="The scheduler type to use.", 131 | choices=[ 132 | "linear", 133 | "cosine", 134 | "cosine_with_restarts", 135 | "polynomial", 136 | "constant", 137 | "constant_with_warmup", 138 | ], 139 | ) 140 | parser.add_argument( 141 | "--num_train_epochs", 142 | type=int, 143 | default=3, 144 | help="Total number of training epochs to perform.", 145 | ) 146 | parser.add_argument( 147 | "--max_train_steps", 148 | type=int, 149 | default=None, 150 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 151 | ) 152 | parser.add_argument( 153 | "--num_warmup_steps", 154 | type=int, 155 | default=0, 156 | help="Number of steps for the warmup in the lr scheduler.", 157 | ) 158 | parser.add_argument( 159 | "--gradient_accumulation_steps", 160 | type=int, 161 | default=1, 162 | help="Number of updates steps to accumulate before performing a backward/update pass.", 163 | ) 164 | 165 | # 166 | ## Regularization 167 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 168 | 169 | # 170 | ## Saving/loading & logging 171 | parser.add_argument( 172 | "--output_dir", type=str, default=None, help="Where to store the final model." 173 | ) 174 | parser.add_argument( 175 | "--checkpointing_steps", 176 | type=str, 177 | default=None, 178 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' " 179 | "for each epoch.", 180 | ) 181 | parser.add_argument( 182 | "--resume_from_checkpoint", 183 | type=str, 184 | default=None, 185 | help="If the training should continue from a checkpoint folder.", 186 | ) 187 | parser.add_argument( 188 | "--with_tracking", 189 | action="store_true", 190 | help="Whether to enable experiment trackers for logging.", 191 | ) 192 | parser.add_argument( 193 | "--extra_tb_stats", 194 | action="store_true", 195 | help="Whether to log extra scalars and histograms to TensorBoard.", 196 | ) 197 | parser.add_argument( 198 | "--report_to", 199 | type=str, 200 | default="all", 201 | help=( 202 | "The integration to report the results and logs to. Supported platforms are " 203 | '`"tensorboard"`, `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to ' 204 | "report to all integrations. Only applicable when `--with_tracking` is passed." 205 | ), 206 | ) 207 | parser.add_argument( 208 | "--low_cpu_mem_usage", 209 | action="store_true", 210 | help=( 211 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." 212 | "If passed, LLM loading time and RAM consumption will be benefited." 
213 | ), 214 | ) 215 | 216 | # *** New options *** 217 | 218 | # 219 | ## Data 220 | parser.add_argument( 221 | "--dataset_setup", 222 | choices=DatasetSetups.list_names(), 223 | default=DatasetSetups.wikitext_103.name, 224 | help=f"The setup/preset of the datasets to use.", 225 | ) 226 | parser.add_argument( 227 | "--data_cache_dir", 228 | type=str, 229 | default="/local/mnt/workspace/.hf_data", 230 | help="Where to store data.", 231 | ) 232 | parser.add_argument( 233 | "--train_percentage", 234 | type=int, 235 | default=None, 236 | help="Percentage of training set to use.", 237 | ) 238 | parser.add_argument( 239 | "--validation_percentage", 240 | type=int, 241 | default=None, 242 | help="Percentage of validation set to use.", 243 | ) 244 | 245 | # 246 | ## Model & tokenizer 247 | parser.add_argument( 248 | "--config_path", 249 | type=str, 250 | default=None, 251 | help="Path to a yaml file with model config modifications.", 252 | ) 253 | parser.add_argument( 254 | "--model_cache_dir", 255 | type=str, 256 | default="/local/mnt/workspace/.hf_cache", 257 | help="Where to store models & tokenizers.", 258 | ) 259 | 260 | # 261 | ## Training 262 | parser.add_argument( 263 | "--final_lr_fraction", 264 | type=float, 265 | default=0.0, 266 | help="Final LR as a fraction of the maximum LR (only for CLM).", 267 | ) 268 | 269 | # 270 | ## Logging 271 | parser.add_argument( 272 | "--tqdm_update_interval", 273 | type=int, 274 | default=100, 275 | help="How often to update the progress bar. " 276 | "Note that using small value might generate large log files.", 277 | ) 278 | 279 | # 280 | ## Regularization 281 | parser.add_argument( 282 | "--max_grad_norm", 283 | type=float, 284 | default=None, 285 | help="Max gradient norm. If set to 0, no clipping will be applied.", 286 | ) 287 | parser.add_argument( 288 | "--grad_norm_type", 289 | type=float, 290 | default=2.0, 291 | help="Norm type to use for gradient clipping.", 292 | ) 293 | parser.add_argument( 294 | "--attn_dropout", 295 | type=float, 296 | default=None, 297 | help="Dropout rate to set for attention probs.", 298 | ) 299 | parser.add_argument( 300 | "--hidden_dropout", 301 | type=float, 302 | default=None, 303 | help="Dropout rate to set for hidden states.", 304 | ) 305 | 306 | # 307 | ## Logging 308 | parser.add_argument( 309 | "--tb_scalar_log_interval", 310 | type=int, 311 | default=1000, 312 | help="How often to log scalar stats of weights and activations to TensorBoard.", 313 | ) 314 | parser.add_argument( 315 | "--tb_hist_log_interval", 316 | type=int, 317 | default=10000, 318 | help="How often to log histograms of weights and activations to TensorBoard.", 319 | ) 320 | 321 | # 322 | ## Extra options 323 | parser.add_argument("--wd_LN_gamma", action="store_true") 324 | 325 | parser.add_argument( 326 | "--skip_attn", 327 | action="store_true", 328 | help="Skip attention (don't update the residual).", 329 | ) 330 | 331 | parser.add_argument( 332 | "--attn_softmax", 333 | type=str, 334 | default="vanilla", 335 | help="Softmax variation to use in attention module.", 336 | choices=SOFTMAX_MAPPING.keys(), 337 | ) 338 | parser.add_argument( 339 | "--alpha", 340 | type=float, 341 | default=None, 342 | help="If specified, use clipped softmax gamma = -alpha / seq_length.", 343 | ) 344 | parser.add_argument( 345 | "--attn_gate_type", 346 | type=str, 347 | default=AttentionGateType.none.name, 348 | help="The type of gating to use for the self-attention.", 349 | choices=AttentionGateType.list_names(), 350 | ) 351 | parser.add_argument( 352 | 
"--attn_gate_init", 353 | type=float, 354 | default=0.5, 355 | help="init bias s.t. the gate prob is approx this value", 356 | ) 357 | parser.add_argument( 358 | "--attn_gate_mlp", 359 | action="store_true", 360 | help="Use MLP instead of single linear layer to predict the gate.", 361 | ) 362 | parser.add_argument( 363 | "--attn_gate_mlp2", 364 | action="store_true", 365 | help="Use bigger MLP instead of single linear layer to predict the gate.", 366 | ) 367 | parser.add_argument( 368 | "--attn_gate_linear_all_features", 369 | action="store_true", 370 | help="Use Linear (d_model -> n_heads) instead of n_heads Linear's (d_head -> 1).", 371 | ) 372 | 373 | # 374 | ## Quantization 375 | parser.add_argument("--quantize", action="store_true") 376 | parser.add_argument("--est_num_batches", type=int, default=1) 377 | parser.add_argument("--n_bits", type=int, default=8) 378 | parser.add_argument("--n_bits_act", type=int, default=8) 379 | parser.add_argument("--no_weight_quant", action="store_true") 380 | parser.add_argument("--no_act_quant", action="store_true") 381 | parser.add_argument("--qmethod_acts", type=str, default="asymmetric_uniform") 382 | parser.add_argument("--ranges_weights", type=str, default="minmax") 383 | parser.add_argument("--ranges_acts", type=str, default="running_minmax") 384 | parser.add_argument( 385 | "--percentile", type=float, default=None, help="Percentile (in %) for range estimation." 386 | ) 387 | parser.add_argument("--quant_setup", type=str, default="all") 388 | 389 | # Fine-tuning 390 | parser.add_argument("--fine_tuning", action="store_true") 391 | 392 | # Parse options 393 | args = parser.parse_args() 394 | 395 | return args 396 | -------------------------------------------------------------------------------- /transformers_language/models/opt_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | from functools import partial 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | import torch.utils.checkpoint 8 | from torch import nn 9 | 10 | from transformers_language.models.bert_attention import AttentionGateType, logit 11 | from transformers_language.models.softmax import clipped_softmax 12 | 13 | 14 | class OPTAttentionWithExtras(nn.Module): 15 | """Multi-headed attention from 'Attention Is All You Need' paper""" 16 | 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | num_heads: int, 21 | dropout: float = 0.0, 22 | is_decoder: bool = False, 23 | bias: bool = True, 24 | ## new 25 | softmax_fn=torch.nn.functional.softmax, 26 | alpha=None, 27 | max_seq_length=None, 28 | ssm_eps=None, 29 | tau=None, 30 | skip_attn=False, 31 | attn_gate_type=AttentionGateType.none, 32 | attn_gate_init=None, 33 | attn_gate_mlp=False, 34 | attn_gate_mlp2=False, 35 | attn_gate_linear_all_features=False, 36 | fine_tuning=False, 37 | ): 38 | super().__init__() 39 | self.embed_dim = embed_dim 40 | self.num_heads = num_heads 41 | self.dropout = dropout 42 | self.head_dim = embed_dim // num_heads 43 | 44 | if (self.head_dim * num_heads) != self.embed_dim: 45 | raise ValueError( 46 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 47 | f" and `num_heads`: {num_heads})." 
48 | ) 49 | self.scaling = self.head_dim**-0.5 50 | self.is_decoder = is_decoder 51 | 52 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 53 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 54 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 55 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 56 | 57 | # YB: capture the input and output of the softmax 58 | self.attn_scores = nn.Identity() # before attention mask 59 | self.attn_probs_before_dropout = nn.Identity() 60 | self.attn_probs_after_dropout = nn.Identity() 61 | 62 | self.alpha = alpha 63 | self.max_seq_length = max_seq_length 64 | self.ssm_eps = ssm_eps 65 | self.tau = tau 66 | 67 | # define softmax function 68 | if self.alpha is not None: 69 | assert self.max_seq_length is not None 70 | gamma = -self.alpha / self.max_seq_length 71 | self.softmax_fn = partial(clipped_softmax, gamma=gamma, eta=1.0) 72 | else: 73 | self.softmax_fn = softmax_fn 74 | 75 | self.skip_attn = skip_attn 76 | 77 | # attention gating 78 | self.last_gate_avg_prob = None 79 | self.last_gate_all_probs = None 80 | 81 | self.attn_gate_type = attn_gate_type 82 | self.attn_gate_init = attn_gate_init 83 | self.attn_gate_mlp = attn_gate_mlp 84 | self.attn_gate_mlp2 = attn_gate_mlp2 85 | self.attn_gate_linear_all_features = attn_gate_linear_all_features 86 | 87 | self.alpha = None 88 | self.ssm_eps = ssm_eps 89 | self.gate_fn = torch.sigmoid 90 | self.pooling_fn = partial(torch.mean, dim=1, keepdims=True) 91 | 92 | self.fine_tuning = fine_tuning 93 | 94 | # gate scaling factor 95 | self.gate_scaling_factor = 1.0 96 | if self.fine_tuning and self.attn_gate_init is not None: 97 | self.gate_scaling_factor = 1.0 / self.attn_gate_init 98 | 99 | # define gate 100 | if self.attn_gate_type == AttentionGateType.unconditional_per_head: 101 | init_alpha = torch.zeros(size=(self.num_heads,)) 102 | self.alpha = nn.Parameter(init_alpha, requires_grad=True) 103 | 104 | elif self.attn_gate_type in ( 105 | AttentionGateType.conditional_per_head, 106 | AttentionGateType.conditional_per_token, 107 | ): 108 | if self.attn_gate_linear_all_features: 109 | self.alpha = nn.Linear(self.embed_dim, self.num_heads, bias=True) 110 | 111 | else: # separate predictors for each head 112 | module_list = [] 113 | for _ in range(self.num_heads): 114 | if self.attn_gate_mlp: 115 | fc = nn.Sequential( 116 | nn.Linear(self.head_dim, self.head_dim // 4, bias=True), 117 | nn.ReLU(), 118 | nn.Linear(self.head_dim // 4, 1, bias=True), 119 | ) 120 | elif self.attn_gate_mlp2: 121 | fc = nn.Sequential( 122 | nn.Linear(self.head_dim, self.head_dim, bias=True), 123 | nn.ReLU(), 124 | nn.Linear(self.head_dim, 1, bias=True), 125 | ) 126 | else: 127 | fc = nn.Linear(self.head_dim, 1, bias=True) 128 | 129 | if self.attn_gate_init is not None: 130 | init_bias = logit(self.attn_gate_init) 131 | torch.nn.init.constant_(fc.bias, init_bias) 132 | 133 | if self.fine_tuning: 134 | # init to a very small values 135 | torch.nn.init.normal_(fc.weight, mean=0.0, std=0.001) 136 | 137 | module_list.append(fc) 138 | self.alpha = nn.ModuleList(module_list) 139 | 140 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 141 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 142 | 143 | def forward( 144 | self, 145 | hidden_states: torch.Tensor, 146 | key_value_states: Optional[torch.Tensor] = None, 147 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 148 | attention_mask: Optional[torch.Tensor] = None, 149 | layer_head_mask: 
Optional[torch.Tensor] = None, 150 | output_attentions: bool = False, 151 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 152 | """Input shape: Batch x Time x Channel""" 153 | 154 | # if key_value_states are provided this layer is used as a cross-attention layer 155 | # for the decoder 156 | is_cross_attention = key_value_states is not None 157 | 158 | bsz, tgt_len, _ = hidden_states.size() 159 | 160 | # get query proj 161 | query_states = self.q_proj(hidden_states) * self.scaling 162 | # get key, value proj 163 | if is_cross_attention and past_key_value is not None: 164 | # reuse k,v, cross_attentions 165 | key_states = past_key_value[0] 166 | value_states = past_key_value[1] 167 | elif is_cross_attention: 168 | # cross_attentions 169 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 170 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 171 | elif past_key_value is not None: 172 | # reuse k, v, self_attention 173 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 174 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 175 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 176 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 177 | else: 178 | # self_attention 179 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 180 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 181 | 182 | if self.is_decoder: 183 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 184 | # Further calls to cross_attention layer can then reuse all cross-attention 185 | # key/value_states (first "if" case) 186 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 187 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 188 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 189 | # if encoder bi-directional self-attention `past_key_value` is always `None` 190 | past_key_value = (key_states, value_states) 191 | 192 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 193 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 194 | key_states = key_states.view(*proj_shape) 195 | value_states = value_states.view(*proj_shape) 196 | 197 | src_len = key_states.size(1) 198 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 199 | 200 | # YB: for logging softmax input 201 | attn_weights = self.attn_scores(attn_weights) 202 | 203 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 204 | raise ValueError( 205 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 206 | f" {attn_weights.size()}" 207 | ) 208 | 209 | if attention_mask is not None: 210 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 211 | raise ValueError( 212 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 213 | ) 214 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 215 | attn_weights = torch.max( 216 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 217 | ) 218 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 219 | 220 | # upcast to fp32 if the weights are in fp16. 
Please see https://github.com/huggingface/transformers/pull/17437 221 | if attn_weights.dtype == torch.float16: 222 | attn_weights = self.softmax_fn(attn_weights, dim=-1, dtype=torch.float32).to( 223 | torch.float16 224 | ) 225 | else: 226 | attn_weights = self.softmax_fn(attn_weights, dim=-1) 227 | 228 | if layer_head_mask is not None: 229 | if layer_head_mask.size() != (self.num_heads,): 230 | raise ValueError( 231 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" 232 | f" {layer_head_mask.size()}" 233 | ) 234 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( 235 | bsz, self.num_heads, tgt_len, src_len 236 | ) 237 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 238 | 239 | if output_attentions: 240 | # this operation is a bit awkward, but it's required to 241 | # make sure that attn_weights keeps its gradient. 242 | # In order to do so, attn_weights have to be reshaped 243 | # twice and have to be reused in the following 244 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 245 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 246 | else: 247 | attn_weights_reshaped = None 248 | 249 | # YB: for logging softmax output 250 | attn_weights = self.attn_probs_before_dropout(attn_weights) 251 | 252 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 253 | 254 | # YB: for logging softmax output 255 | attn_probs = self.attn_probs_after_dropout(attn_probs) 256 | 257 | attn_output = torch.bmm(attn_probs, value_states) 258 | 259 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 260 | raise ValueError( 261 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 262 | f" {attn_output.size()}" 263 | ) 264 | 265 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 266 | # attn_output - (B, H, T, d_head) 267 | 268 | # 269 | # *** Gating *** 270 | if self.attn_gate_type == AttentionGateType.unconditional_per_head: 271 | gate = self.gate_fn(self.alpha) # (H,) 272 | attn_output *= gate.view(-1, 1, 1) # (B, H, T, d_head) 273 | 274 | self.last_gate_avg_prob = gate.view(-1) 275 | 276 | elif self.attn_gate_type in ( 277 | AttentionGateType.conditional_per_head, 278 | AttentionGateType.conditional_per_token, 279 | ): 280 | x = hidden_states # (B, T, d_model) 281 | 282 | if self.attn_gate_linear_all_features: # assume per_token 283 | alpha = self.alpha(x) # (B, T, H) 284 | gate = self.gate_fn(alpha) 285 | gate = gate.permute(0, 2, 1).contiguous() # (B, H, T) 286 | gate = gate.unsqueeze(3) # (B, H, T, 1) 287 | 288 | else: 289 | # x = self.transpose_for_scores(x) # (B, H, T, d_head) 290 | x = self._shape(x, -1, bsz) # (B, H, T, d_head) 291 | 292 | alpha = [] 293 | for head_idx in range(self.num_heads): 294 | x_head = x[:, head_idx, ...] 
# (B, T, d_head) 295 | fc_head = self.alpha[head_idx] 296 | alpha_head = fc_head(x_head) # (B, T, 1) 297 | if self.attn_gate_type == AttentionGateType.conditional_per_head: 298 | alpha_head = self.pooling_fn(alpha_head) # (B, 1, 1) 299 | alpha.append(alpha_head) 300 | alpha = torch.stack(alpha, dim=1) # (B, H, *, 1) 301 | gate = self.gate_fn(alpha) 302 | 303 | attn_output *= gate * self.gate_scaling_factor 304 | 305 | self.last_gate_all_probs = gate # all gates to see the distributions 306 | avg_gate = gate.mean(dim=0) 307 | self.last_gate_avg_prob = avg_gate.view(self.num_heads, -1).mean(dim=1) 308 | 309 | # 310 | ## 311 | 312 | attn_output = attn_output.transpose(1, 2) 313 | 314 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 315 | # partitioned aross GPUs when using tensor-parallelism. 316 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 317 | 318 | attn_output = self.out_proj(attn_output) 319 | 320 | return attn_output, attn_weights_reshaped, past_key_value 321 | -------------------------------------------------------------------------------- /transformers_language/models/bert_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | import math 4 | from functools import partial 5 | from typing import Optional, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | 12 | from quantization.utils import BaseEnumOptions 13 | from transformers_language.models.softmax import clipped_softmax 14 | 15 | 16 | def logit(p, eps=1e-16): 17 | p = np.clip(p, eps, 1 - eps) 18 | return -np.log(1 / p - 1) 19 | 20 | 21 | class AttentionGateType(BaseEnumOptions): 22 | none = 0 23 | unconditional_per_head = 1 24 | conditional_per_head = 2 25 | conditional_per_token = 3 26 | 27 | 28 | class BertSelfAttentionWithExtras(nn.Module): 29 | def __init__( 30 | self, 31 | config, 32 | position_embedding_type=None, 33 | softmax_fn=torch.nn.functional.softmax, 34 | alpha=None, 35 | ssm_eps=None, 36 | tau=None, 37 | max_seq_length=None, 38 | skip_attn=False, 39 | attn_gate_type=AttentionGateType.none, 40 | attn_gate_init=None, 41 | attn_gate_mlp=False, 42 | attn_gate_mlp2=False, 43 | attn_gate_linear_all_features=False, 44 | fine_tuning=False, 45 | ): 46 | super().__init__() 47 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr( 48 | config, "embedding_size" 49 | ): 50 | raise ValueError( 51 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 52 | f"heads ({config.num_attention_heads})" 53 | ) 54 | 55 | self.num_attention_heads = config.num_attention_heads 56 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 57 | self.all_head_size = self.num_attention_heads * self.attention_head_size 58 | 59 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 60 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 61 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 62 | 63 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 64 | self.position_embedding_type = position_embedding_type or getattr( 65 | config, "position_embedding_type", "absolute" 66 | ) 67 | if ( 68 | self.position_embedding_type == "relative_key" 69 | or self.position_embedding_type == "relative_key_query" 70 | ): 71 | self.max_position_embeddings = 
config.max_position_embeddings 72 | self.distance_embedding = nn.Embedding( 73 | 2 * config.max_position_embeddings - 1, self.attention_head_size 74 | ) 75 | 76 | self.is_decoder = config.is_decoder 77 | 78 | # YB: capture the input and output of the softmax 79 | self.attn_scores = nn.Identity() # before attention mask 80 | self.attn_probs_before_dropout = nn.Identity() 81 | self.attn_probs_after_dropout = nn.Identity() 82 | 83 | self.alpha = alpha 84 | self.ssm_eps = ssm_eps 85 | self.tau = tau 86 | self.max_seq_length = max_seq_length 87 | 88 | # define softmax function 89 | if self.alpha is not None: 90 | assert self.max_seq_length is not None 91 | gamma = -self.alpha / self.max_seq_length 92 | self.softmax_fn = partial(clipped_softmax, gamma=gamma, eta=1.0) 93 | else: 94 | self.softmax_fn = softmax_fn 95 | 96 | self.skip_attn = skip_attn 97 | 98 | # attention gating 99 | self.last_gate_avg_prob = None 100 | self.last_gate_all_probs = None 101 | 102 | self.attn_gate_type = attn_gate_type 103 | self.attn_gate_init = attn_gate_init 104 | self.attn_gate_mlp = attn_gate_mlp 105 | self.attn_gate_mlp2 = attn_gate_mlp2 106 | self.attn_gate_linear_all_features = attn_gate_linear_all_features 107 | 108 | self.alpha = None 109 | self.gate_fn = torch.sigmoid 110 | self.pooling_fn = partial(torch.mean, dim=1, keepdims=True) 111 | 112 | self.fine_tuning = fine_tuning 113 | 114 | # gate scaling factor 115 | self.gate_scaling_factor = 1.0 116 | if self.fine_tuning and self.attn_gate_init is not None: 117 | self.gate_scaling_factor = 1.0 / self.attn_gate_init 118 | 119 | # define gate 120 | if self.attn_gate_type == AttentionGateType.unconditional_per_head: 121 | init_alpha = torch.zeros(size=(self.num_attention_heads,)) 122 | self.alpha = nn.Parameter(init_alpha, requires_grad=True) 123 | 124 | elif self.attn_gate_type in ( 125 | AttentionGateType.conditional_per_head, 126 | AttentionGateType.conditional_per_token, 127 | ): 128 | if self.attn_gate_linear_all_features: 129 | self.alpha = nn.Linear(self.all_head_size, self.num_attention_heads, bias=True) 130 | 131 | else: # separate predictors for each head 132 | module_list = [] 133 | for _ in range(self.num_attention_heads): 134 | if self.attn_gate_mlp: 135 | fc = nn.Sequential( 136 | nn.Linear( 137 | self.attention_head_size, self.attention_head_size // 4, bias=True 138 | ), 139 | nn.ReLU(), 140 | nn.Linear(self.attention_head_size // 4, 1, bias=True), 141 | ) 142 | elif self.attn_gate_mlp2: 143 | fc = nn.Sequential( 144 | nn.Linear( 145 | self.attention_head_size, self.attention_head_size, bias=True 146 | ), 147 | nn.ReLU(), 148 | nn.Linear(self.attention_head_size, 1, bias=True), 149 | ) 150 | else: 151 | fc = nn.Linear(self.attention_head_size, 1, bias=True) 152 | 153 | if self.attn_gate_init is not None: 154 | init_bias = logit(self.attn_gate_init) 155 | torch.nn.init.constant_(fc.bias, init_bias) 156 | 157 | if self.fine_tuning: 158 | # init to a very small values 159 | torch.nn.init.normal_(fc.weight, mean=0.0, std=0.01) 160 | 161 | module_list.append(fc) 162 | self.alpha = nn.ModuleList(module_list) 163 | 164 | def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: 165 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 166 | x = x.view(new_x_shape) 167 | return x.permute(0, 2, 1, 3) 168 | 169 | def forward( 170 | self, 171 | hidden_states: torch.Tensor, 172 | attention_mask: Optional[torch.FloatTensor] = None, 173 | head_mask: Optional[torch.FloatTensor] = None, 174 | encoder_hidden_states: 
Optional[torch.FloatTensor] = None, 175 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 176 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 177 | output_attentions: Optional[bool] = False, 178 | ) -> Tuple[torch.Tensor]: 179 | if self.skip_attn: 180 | out = torch.zeros_like(hidden_states) 181 | return (out,) 182 | 183 | mixed_query_layer = self.query(hidden_states) 184 | 185 | # If this is instantiated as a cross-attention module, the keys 186 | # and values come from an encoder; the attention mask needs to be 187 | # such that the encoder's padding tokens are not attended to. 188 | is_cross_attention = encoder_hidden_states is not None 189 | 190 | if is_cross_attention and past_key_value is not None: 191 | # reuse k,v, cross_attentions 192 | key_layer = past_key_value[0] 193 | value_layer = past_key_value[1] 194 | attention_mask = encoder_attention_mask 195 | elif is_cross_attention: 196 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 197 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 198 | attention_mask = encoder_attention_mask 199 | elif past_key_value is not None: 200 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 201 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 202 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 203 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 204 | else: 205 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 206 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 207 | 208 | query_layer = self.transpose_for_scores(mixed_query_layer) 209 | 210 | use_cache = past_key_value is not None 211 | if self.is_decoder: 212 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 213 | # Further calls to cross_attention layer can then reuse all cross-attention 214 | # key/value_states (first "if" case) 215 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 216 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 217 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 218 | # if encoder bi-directional self-attention `past_key_value` is always `None` 219 | past_key_value = (key_layer, value_layer) 220 | 221 | # Take the dot product between "query" and "key" to get the raw attention scores. 
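        # Shapes: query_layer is (batch, num_heads, q_len, head_size) and the transposed
        # key_layer is (batch, num_heads, head_size, k_len), so the matmul below yields
        # attention_scores of shape (batch, num_heads, q_len, k_len).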
222 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 223 | 224 | if ( 225 | self.position_embedding_type == "relative_key" 226 | or self.position_embedding_type == "relative_key_query" 227 | ): 228 | query_length, key_length = query_layer.shape[2], key_layer.shape[2] 229 | if use_cache: 230 | position_ids_l = torch.tensor( 231 | key_length - 1, dtype=torch.long, device=hidden_states.device 232 | ).view(-1, 1) 233 | else: 234 | position_ids_l = torch.arange( 235 | query_length, dtype=torch.long, device=hidden_states.device 236 | ).view(-1, 1) 237 | position_ids_r = torch.arange( 238 | key_length, dtype=torch.long, device=hidden_states.device 239 | ).view(1, -1) 240 | distance = position_ids_l - position_ids_r 241 | 242 | positional_embedding = self.distance_embedding( 243 | distance + self.max_position_embeddings - 1 244 | ) 245 | positional_embedding = positional_embedding.to( 246 | dtype=query_layer.dtype 247 | ) # fp16 compatibility 248 | 249 | if self.position_embedding_type == "relative_key": 250 | relative_position_scores = torch.einsum( 251 | "bhld,lrd->bhlr", query_layer, positional_embedding 252 | ) 253 | attention_scores = attention_scores + relative_position_scores 254 | elif self.position_embedding_type == "relative_key_query": 255 | relative_position_scores_query = torch.einsum( 256 | "bhld,lrd->bhlr", query_layer, positional_embedding 257 | ) 258 | relative_position_scores_key = torch.einsum( 259 | "bhrd,lrd->bhlr", key_layer, positional_embedding 260 | ) 261 | attention_scores = ( 262 | attention_scores + relative_position_scores_query + relative_position_scores_key 263 | ) 264 | 265 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 266 | 267 | # YB: for logging softmax input 268 | attention_scores = self.attn_scores(attention_scores) 269 | 270 | if attention_mask is not None: 271 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 272 | attention_scores = attention_scores + attention_mask 273 | 274 | # Normalize the attention scores to probabilities. 275 | # MN: uses our own SM function as specified in the config 276 | attention_probs = self.softmax_fn(attention_scores, dim=-1) 277 | 278 | # YB: for logging softmax output 279 | attention_probs = self.attn_probs_before_dropout(attention_probs) 280 | 281 | # This is actually dropping out entire tokens to attend to, which might 282 | # seem a bit unusual, but is taken from the original Transformer paper. 
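        # Note on `self.softmax_fn` applied above: when `alpha` is given, __init__ sets it to
        # `clipped_softmax` with eta=1.0 and gamma = -alpha / max_seq_length. Assuming the
        # definition from the paper, clipped_softmax(x) = clip((eta - gamma) * softmax(x) + gamma, 0, 1),
        # which stretches the softmax output slightly below zero before clipping, so attention
        # probabilities can reach exactly zero for heads that prefer to attend to nothing.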
283 | attention_probs = self.dropout(attention_probs) 284 | 285 | # YB: for logging softmax output 286 | attention_probs = self.attn_probs_after_dropout(attention_probs) 287 | 288 | # Mask heads if we want to 289 | if head_mask is not None: 290 | attention_probs = attention_probs * head_mask 291 | 292 | context_layer = torch.matmul(attention_probs, value_layer) 293 | 294 | # *** Gating *** 295 | if self.attn_gate_type == AttentionGateType.unconditional_per_head: 296 | gate = self.gate_fn(self.alpha) # (H,) 297 | context_layer *= gate.view(-1, 1, 1) # (B, H, T, d_head) 298 | 299 | self.last_gate_avg_prob = gate.view(-1) 300 | 301 | elif self.attn_gate_type in ( 302 | AttentionGateType.conditional_per_head, 303 | AttentionGateType.conditional_per_token, 304 | ): 305 | x = hidden_states # (B, T, d_model) 306 | 307 | if self.attn_gate_linear_all_features: # assume per_token 308 | alpha = self.alpha(x) # (B, T, H) 309 | gate = self.gate_fn(alpha) 310 | gate = gate.permute(0, 2, 1).contiguous() # (B, H, T) 311 | gate = gate.unsqueeze(3) # (B, H, T, 1) 312 | 313 | else: 314 | x = self.transpose_for_scores(x) # (B, H, T, d_head) 315 | 316 | alpha = [] 317 | for head_idx in range(self.num_attention_heads): 318 | x_head = x[:, head_idx, ...] # (B, T, d_head) 319 | fc_head = self.alpha[head_idx] 320 | alpha_head = fc_head(x_head) # (B, T, 1) 321 | if self.attn_gate_type == AttentionGateType.conditional_per_head: 322 | alpha_head = self.pooling_fn(alpha_head) # (B, 1, 1) 323 | alpha.append(alpha_head) 324 | alpha = torch.stack(alpha, dim=1) # (B, H, *, 1) 325 | gate = self.gate_fn(alpha) 326 | 327 | context_layer *= gate * self.gate_scaling_factor 328 | 329 | self.last_gate_all_probs = gate # all gates to see the distributions 330 | avg_gate = gate.mean(dim=0) 331 | self.last_gate_avg_prob = avg_gate.view(self.num_attention_heads, -1).mean(dim=1) 332 | 333 | # 334 | 335 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 336 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 337 | context_layer = context_layer.view(new_context_layer_shape) 338 | 339 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 340 | 341 | if self.is_decoder: 342 | outputs = outputs + (past_key_value,) 343 | return outputs 344 | -------------------------------------------------------------------------------- /quantization/range_estimators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | import copy 4 | from enum import auto 5 | 6 | import numpy as np 7 | import torch 8 | from scipy.optimize import minimize_scalar 9 | from torch import nn 10 | 11 | from quantization.utils import BaseEnumOptions, ClassEnumOptions, MethodMap, to_numpy 12 | 13 | 14 | class RangeEstimatorBase(nn.Module): 15 | def __init__(self, *args, per_channel=False, quantizer=None, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.register_buffer("current_xmin", None) 18 | self.register_buffer("current_xmax", None) 19 | self.per_channel = per_channel 20 | self.quantizer = quantizer 21 | 22 | def forward(self, x): 23 | """ 24 | Accepts an input tensor, updates the current estimates of x_min and x_max 25 | and retruns them. 
26 | Parameters 27 | ---------- 28 | x: Input tensor 29 | 30 | Returns 31 | ------- 32 | self.current_xmin: tensor 33 | 34 | self.current_xmax: tensor 35 | 36 | """ 37 | raise NotImplementedError() 38 | 39 | def reset(self): 40 | """ 41 | Reset the range estimator. 42 | """ 43 | self.current_xmin = None 44 | self.current_xmax = None 45 | 46 | def __repr__(self): 47 | # We overwrite this from nn.Module as we do not want to have submodules such as 48 | # self.quantizer in the reproduce. Otherwise it behaves as expected for an nn.Module. 49 | lines = self.extra_repr().split("\n") 50 | extra_str = lines[0] if len(lines) == 1 else "\n " + "\n ".join(lines) + "\n" 51 | 52 | return self._get_name() + "(" + extra_str + ")" 53 | 54 | 55 | class CurrentMinMaxEstimator(RangeEstimatorBase): 56 | def __init__(self, *args, percentile=None, **kwargs): 57 | self.percentile = percentile 58 | super().__init__(*args, **kwargs) 59 | 60 | def forward(self, x): 61 | if self.per_channel: 62 | x = x.view(x.shape[0], -1) 63 | if self.percentile: 64 | axis = -1 if self.per_channel else None 65 | data_np = to_numpy(x) 66 | x_min, x_max = np.percentile( 67 | data_np, (self.percentile, 100 - self.percentile), axis=axis 68 | ) 69 | self.current_xmin = torch.tensor(x_min).to(x.device) 70 | self.current_xmax = torch.tensor(x_max).to(x.device) 71 | else: 72 | self.current_xmin = x.min(-1)[0].detach() if self.per_channel else x.min().detach() 73 | self.current_xmax = x.max(-1)[0].detach() if self.per_channel else x.max().detach() 74 | 75 | return self.current_xmin, self.current_xmax 76 | 77 | 78 | class RunningMinMaxEstimator(RangeEstimatorBase): 79 | def __init__(self, *args, momentum=0.9, percentile=None, **kwargs): 80 | self.momentum = momentum 81 | self.percentile = percentile 82 | super().__init__(*args, **kwargs) 83 | 84 | def forward(self, x): 85 | if self.per_channel: 86 | # Along 1st dim 87 | x_flattened = x.view(x.shape[0], -1) 88 | x_min = x_flattened.min(-1)[0].detach() 89 | x_max = x_flattened.max(-1)[0].detach() 90 | elif self.percentile: 91 | data_np = to_numpy(x) 92 | x_min, x_max = np.percentile(data_np, (100 - self.percentile, self.percentile)) 93 | 94 | x_min = torch.tensor(x_min).to(x.device) 95 | x_max = torch.tensor(x_max).to(x.device) 96 | else: 97 | x_min = torch.min(x).detach() 98 | x_max = torch.max(x).detach() 99 | 100 | if self.current_xmin is None: 101 | self.current_xmin = x_min 102 | self.current_xmax = x_max 103 | else: 104 | self.current_xmin = (1 - self.momentum) * x_min + self.momentum * self.current_xmin 105 | self.current_xmax = (1 - self.momentum) * x_max + self.momentum * self.current_xmax 106 | 107 | return self.current_xmin, self.current_xmax 108 | 109 | 110 | class OptMethod(BaseEnumOptions): 111 | grid = auto() 112 | golden_section = auto() 113 | 114 | 115 | class MSE_Estimator(RangeEstimatorBase): 116 | def __init__( 117 | self, *args, num_candidates=100, opt_method=OptMethod.grid, range_margin=0.5, **kwargs 118 | ): 119 | super().__init__(*args, **kwargs) 120 | assert opt_method in OptMethod 121 | self.opt_method = opt_method 122 | self.num_candidates = num_candidates 123 | self.loss_array = None 124 | self.max_pos_thr = None 125 | self.max_neg_thr = None 126 | self.max_search_range = None 127 | self.one_sided_dist = None 128 | self.range_margin = range_margin 129 | if self.quantizer is None: 130 | raise NotImplementedError( 131 | "A Quantizer must be given as an argument to the MSE Range" "Estimator" 132 | ) 133 | self.max_int_skew = (2**self.quantizer.n_bits) // 4 # For 
asymmetric quantization 134 | 135 | def loss_fx(self, data, neg_thr, pos_thr, per_channel_loss=False): 136 | y = self.quantize(data, x_min=neg_thr, x_max=pos_thr) 137 | temp_sum = torch.sum(((data - y) ** 2).view(len(data), -1), dim=1) 138 | # if we want to return the MSE loss of each channel separately, speeds up the per-channel 139 | # grid search 140 | if per_channel_loss: 141 | return to_numpy(temp_sum) 142 | else: 143 | return to_numpy(torch.sum(temp_sum)) 144 | 145 | @property 146 | def step_size(self): 147 | if self.one_sided_dist is None: 148 | raise NoDataPassedError() 149 | 150 | return self.max_search_range / self.num_candidates 151 | 152 | @property 153 | def optimization_method(self): 154 | if self.one_sided_dist is None: 155 | raise NoDataPassedError() 156 | 157 | if self.opt_method == OptMethod.grid: 158 | # Grid search method 159 | if self.one_sided_dist or self.quantizer.symmetric: 160 | # 1-D grid search 161 | return self._perform_1D_search 162 | else: 163 | # 2-D grid_search 164 | return self._perform_2D_search 165 | elif self.opt_method == OptMethod.golden_section: 166 | # Golden section method 167 | if self.one_sided_dist or self.quantizer.symmetric: 168 | return self._golden_section_symmetric 169 | else: 170 | return self._golden_section_asymmetric 171 | else: 172 | raise NotImplementedError("Optimization Method not Implemented") 173 | 174 | def quantize(self, x_float, x_min=None, x_max=None): 175 | temp_q = copy.deepcopy(self.quantizer) 176 | # In the current implementation no optimization procedure requires temp quantizer for 177 | # loss_fx to be per-channel 178 | temp_q.per_channel = False 179 | if x_min or x_max: 180 | temp_q.set_quant_range(x_min, x_max) 181 | return temp_q(x_float) 182 | 183 | def golden_sym_loss(self, range, data): 184 | """ 185 | Loss function passed to the golden section optimizer from scipy in case of symmetric 186 | quantization 187 | """ 188 | neg_thr = 0 if self.one_sided_dist else -range 189 | pos_thr = range 190 | return self.loss_fx(data, neg_thr, pos_thr) 191 | 192 | def golden_asym_shift_loss(self, shift, range, data): 193 | """ 194 | Inner Loss function (shift) passed to the golden section optimizer from scipy 195 | in case of asymmetric quantization 196 | """ 197 | pos_thr = range + shift 198 | neg_thr = -range + shift 199 | return self.loss_fx(data, neg_thr, pos_thr) 200 | 201 | def golden_asym_range_loss(self, range, data): 202 | """ 203 | Outer Loss function (range) passed to the golden section optimizer from scipy in case of 204 | asymmetric quantization 205 | """ 206 | temp_delta = 2 * range / (2**self.quantizer.n_bits - 1) 207 | max_shift = temp_delta * self.max_int_skew 208 | result = minimize_scalar( 209 | self.golden_asym_shift_loss, 210 | args=(range, data), 211 | bounds=(-max_shift, max_shift), 212 | method="Bounded", 213 | ) 214 | return result.fun 215 | 216 | def _define_search_range(self, data): 217 | self.channel_groups = len(data) if self.per_channel else 1 218 | self.current_xmax = torch.zeros(self.channel_groups, device=data.device) 219 | self.current_xmin = torch.zeros(self.channel_groups, device=data.device) 220 | 221 | if self.one_sided_dist or self.quantizer.symmetric: 222 | # 1D search space 223 | self.loss_array = np.zeros( 224 | (self.channel_groups, self.num_candidates + 1) 225 | ) # 1D search space 226 | self.loss_array[:, 0] = np.inf # exclude interval_start=interval_finish 227 | # Defining the search range for clipping thresholds 228 | self.max_pos_thr = max(abs(float(data.min())), float(data.max())) + 
self.range_margin 229 | self.max_neg_thr = -self.max_pos_thr 230 | self.max_search_range = self.max_pos_thr 231 | else: 232 | # 2D search space (3rd and 4th index correspond to asymmetry where fourth 233 | # index represents whether the skew is positive (0) or negative (1)) 234 | self.loss_array = np.zeros( 235 | [self.channel_groups, self.num_candidates + 1, self.max_int_skew, 2] 236 | ) # 2D search space 237 | self.loss_array[:, 0, :, :] = np.inf # exclude interval_start=interval_finish 238 | # Define the search range for clipping thresholds in asymmetric case 239 | self.max_pos_thr = float(data.max()) + self.range_margin 240 | self.max_neg_thr = float(data.min()) - self.range_margin 241 | self.max_search_range = max(abs(self.max_pos_thr), abs(self.max_neg_thr)) 242 | 243 | def _perform_1D_search(self, data): 244 | """ 245 | Grid search through all candidate quantizers in 1D to find the best 246 | The loss is accumulated over all batches without any momentum 247 | :param data: input tensor 248 | """ 249 | for cand_index in range(1, self.num_candidates + 1): 250 | neg_thr = 0 if self.one_sided_dist else -self.step_size * cand_index 251 | pos_thr = self.step_size * cand_index 252 | 253 | self.loss_array[:, cand_index] += self.loss_fx( 254 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel 255 | ) 256 | # find the best clipping thresholds 257 | min_cand = self.loss_array.argmin(axis=1) 258 | xmin = ( 259 | np.zeros(self.channel_groups) if self.one_sided_dist else -self.step_size * min_cand 260 | ).astype(np.single) 261 | xmax = (self.step_size * min_cand).astype(np.single) 262 | self.current_xmax = torch.tensor(xmax).to(device=data.device) 263 | self.current_xmin = torch.tensor(xmin).to(device=data.device) 264 | 265 | def _perform_2D_search(self, data): 266 | """ 267 | Grid search through all candidate quantizers in 1D to find the best 268 | The loss is accumulated over all batches without any momentum 269 | Parameters 270 | ---------- 271 | data: PyTorch Tensor 272 | Returns 273 | ------- 274 | 275 | """ 276 | for cand_index in range(1, self.num_candidates + 1): 277 | # defining the symmetric quantization range 278 | temp_start = -self.step_size * cand_index 279 | temp_finish = self.step_size * cand_index 280 | temp_delta = float(temp_finish - temp_start) / (2**self.quantizer.n_bits - 1) 281 | for shift in range(self.max_int_skew): 282 | for reverse in range(2): 283 | # introducing asymmetry in the quantization range 284 | skew = ((-1) ** reverse) * shift * temp_delta 285 | neg_thr = max(temp_start + skew, self.max_neg_thr) 286 | pos_thr = min(temp_finish + skew, self.max_pos_thr) 287 | 288 | self.loss_array[:, cand_index, shift, reverse] += self.loss_fx( 289 | data, neg_thr, pos_thr, per_channel_loss=self.per_channel 290 | ) 291 | 292 | for channel_index in range(self.channel_groups): 293 | min_cand, min_shift, min_reverse = np.unravel_index( 294 | np.argmin(self.loss_array[channel_index], axis=None), 295 | self.loss_array[channel_index].shape, 296 | ) 297 | min_interval_start = -self.step_size * min_cand 298 | min_interval_finish = self.step_size * min_cand 299 | min_delta = float(min_interval_finish - min_interval_start) / ( 300 | 2**self.quantizer.n_bits - 1 301 | ) 302 | min_skew = ((-1) ** min_reverse) * min_shift * min_delta 303 | xmin = max(min_interval_start + min_skew, self.max_neg_thr) 304 | xmax = min(min_interval_finish + min_skew, self.max_pos_thr) 305 | 306 | self.current_xmin[channel_index] = torch.tensor(xmin).to(device=data.device) 307 | 
            self.current_xmax[channel_index] = torch.tensor(xmax).to(device=data.device)
308 | 
309 |     def _golden_section_symmetric(self, data):
310 |         for channel_index in range(self.channel_groups):
311 |             if channel_index == 0 and not self.per_channel:
312 |                 data_segment = data
313 |             else:
314 |                 data_segment = data[channel_index]
315 | 
316 |             self.result = minimize_scalar(
317 |                 self.golden_sym_loss,
318 |                 args=data_segment,
319 |                 bounds=(0.01 * self.max_search_range, self.max_search_range),
320 |                 method="Bounded",
321 |             )
322 |             self.current_xmax[channel_index] = torch.tensor(self.result.x).to(device=data.device)
323 |             self.current_xmin[channel_index] = (
324 |                 torch.tensor(0.0).to(device=data.device)
325 |                 if self.one_sided_dist
326 |                 else -self.current_xmax[channel_index]
327 |             )
328 | 
329 |     def _golden_section_asymmetric(self, data):
330 |         for channel_index in range(self.channel_groups):
331 |             if channel_index == 0 and not self.per_channel:
332 |                 data_segment = data
333 |             else:
334 |                 data_segment = data[channel_index]
335 | 
336 |             self.result = minimize_scalar(
337 |                 self.golden_asym_range_loss,
338 |                 args=data_segment,
339 |                 bounds=(0.01 * self.max_search_range, self.max_search_range),
340 |                 method="Bounded",
341 |             )
342 |             self.final_range = self.result.x
343 |             temp_delta = 2 * self.final_range / (2**self.quantizer.n_bits - 1)
344 |             max_shift = temp_delta * self.max_int_skew
345 |             self.subresult = minimize_scalar(
346 |                 self.golden_asym_shift_loss,
347 |                 args=(self.final_range, data_segment),
348 |                 bounds=(-max_shift, max_shift),
349 |                 method="Bounded",
350 |             )
351 |             self.final_shift = self.subresult.x
352 |             self.current_xmax[channel_index] = torch.tensor(self.final_range + self.final_shift).to(
353 |                 device=data.device
354 |             )
355 |             self.current_xmin[channel_index] = torch.tensor(
356 |                 -self.final_range + self.final_shift
357 |             ).to(device=data.device)
358 | 
359 |     def forward(self, data):
360 |         if self.loss_array is None:
361 |             # Initialize search range on first batch, and accumulate losses with subsequent calls
362 | 
363 |             # Decide whether input distribution is one-sided
364 |             if self.one_sided_dist is None:
365 |                 self.one_sided_dist = bool((data.min() >= 0).item())
366 | 
367 |             # Define search
368 |             self._define_search_range(data)
369 | 
370 |         # Perform Search/Optimization for Quantization Ranges
371 |         self.optimization_method(data)
372 | 
373 |         return self.current_xmin, self.current_xmax
374 | 
375 |     def reset(self):
376 |         super().reset()
377 |         self.loss_array = None
378 | 
379 |     def extra_repr(self):
380 |         repr = "opt_method={}".format(self.opt_method.name)
381 |         if self.opt_method == OptMethod.grid:
382 |             repr += ", num_candidates={}".format(self.num_candidates)
383 |         return repr
384 | 
385 | 
386 | class NoDataPassedError(Exception):
387 |     """Raised when no data has been passed into the Range Estimator"""
388 | 
389 |     def __init__(self):
390 |         super().__init__("Data must be passed through the range estimator to be initialized")
391 | 
392 | 
393 | class RangeEstimators(ClassEnumOptions):
394 |     current_minmax = MethodMap(CurrentMinMaxEstimator)
395 |     running_minmax = MethodMap(RunningMinMaxEstimator)
396 |     MSE = MethodMap(MSE_Estimator)
397 | 
--------------------------------------------------------------------------------
/validate_mlm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright (c) 2023 Qualcomm Technologies, Inc.
4 | # All Rights Reserved.
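#
# Overview: this script evaluates a BERT masked-LM checkpoint, optionally after simulated
# quantization, and reports validation perplexity together with activation outlier statistics
# (inf-norms and kurtosis). The quantization path below is roughly (sketch only, arguments
# abbreviated; see main() for the exact calls):
#
#     click_config = get_quant_config()                   # default weight/act quantization config
#     qparams = val_qparams(click_config)                  # translate the config into quantizer kwargs
#     model = QuantizedBertForMaskedLM(model, **qparams)   # wrap the FP model for simulated quantization
#     pass_data_for_range_estimation(loader=eval_dataloader, model=model, ...)  # calibrate ranges
#     model.fix_ranges()                                   # freeze the estimated ranges
#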
5 | import json 6 | import logging 7 | import math 8 | import os 9 | import random 10 | from collections import OrderedDict 11 | from itertools import chain 12 | from pathlib import Path 13 | 14 | import datasets 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import transformers 19 | from accelerate import Accelerator 20 | from accelerate.utils import set_seed 21 | from datasets import DatasetDict, load_dataset, load_from_disk 22 | from timm.utils import AverageMeter 23 | from torch.utils.data import DataLoader 24 | from tqdm.auto import tqdm 25 | from transformers import ( 26 | CONFIG_MAPPING, 27 | MODEL_MAPPING, 28 | AutoConfig, 29 | AutoModelForMaskedLM, 30 | AutoTokenizer, 31 | DataCollatorForLanguageModeling, 32 | ) 33 | 34 | from quantization.range_estimators import OptMethod, RangeEstimators 35 | from transformers_language.args import parse_args 36 | from transformers_language.dataset_setups import DatasetSetups 37 | from transformers_language.models.bert_attention import ( 38 | AttentionGateType, 39 | BertSelfAttentionWithExtras, 40 | ) 41 | from transformers_language.models.quantized_bert import QuantizedBertForMaskedLM 42 | from transformers_language.models.softmax import SOFTMAX_MAPPING 43 | from transformers_language.quant_configs import get_quant_config 44 | from transformers_language.utils import ( 45 | count_params, 46 | kurtosis, 47 | pass_data_for_range_estimation, 48 | val_qparams, 49 | ) 50 | 51 | logger = logging.getLogger("validate_mlm") 52 | logging.basicConfig( 53 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 54 | datefmt="%m/%d/%Y %H:%M:%S", 55 | level=logging.INFO, 56 | ) 57 | 58 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 59 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 60 | 61 | EXTRA_METRICS = True 62 | 63 | 64 | def attach_act_hooks(model): 65 | act_dict = OrderedDict() 66 | 67 | def _make_hook(name): 68 | def _hook(mod, inp, out): 69 | if isinstance(inp, tuple) and len(inp) > 0: 70 | inp = inp[0] 71 | act_dict[name] = (inp, out) 72 | 73 | return _hook 74 | 75 | for name, module in model.named_modules(): 76 | module.register_forward_hook(_make_hook(name)) 77 | return act_dict 78 | 79 | 80 | def main(): 81 | args = parse_args() 82 | logger.info(args) 83 | 84 | # convert dataset setup to an enum 85 | dataset_setup = DatasetSetups[args.dataset_setup] 86 | 87 | # Initialize the accelerator. We will let the accelerator handle device placement for us in 88 | # this example. 89 | # If we're using tracking, we also need to initialize it here and it will by default pick up 90 | # all supported trackers in the environment 91 | accelerator_log_kwargs = {} 92 | 93 | if args.with_tracking: 94 | accelerator_log_kwargs["log_with"] = args.report_to 95 | accelerator_log_kwargs["logging_dir"] = args.output_dir 96 | 97 | accelerator = Accelerator( 98 | gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs 99 | ) 100 | 101 | logger.info(accelerator.state) 102 | if accelerator.is_local_main_process: 103 | datasets.utils.logging.set_verbosity_warning() 104 | transformers.utils.logging.set_verbosity_info() 105 | else: 106 | datasets.utils.logging.set_verbosity_error() 107 | transformers.utils.logging.set_verbosity_error() 108 | 109 | # If passed along, set the training seed now. 
110 | if args.seed is not None: 111 | set_seed(args.seed) 112 | 113 | # Prepare HuggingFace config 114 | # In distributed training, the .from_pretrained methods guarantee that only one local process 115 | # can concurrently download model & vocab. 116 | config_kwargs = { 117 | "cache_dir": args.model_cache_dir, 118 | } 119 | if args.config_name: 120 | config = AutoConfig.from_pretrained(args.config_name, **config_kwargs) 121 | elif args.model_name_or_path: 122 | config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs) 123 | else: 124 | config = CONFIG_MAPPING[args.model_type]() 125 | logger.warning("You are instantiating a new config instance from scratch.") 126 | 127 | # Display config after changes 128 | logger.info("HuggingFace config after user changes:") 129 | logger.info(str(config)) 130 | 131 | # Load tokenizer 132 | tokenizer_kwargs = { 133 | # 'cache_dir': args.model_cache_dir, 134 | } 135 | if args.model_name_or_path: 136 | tokenizer = AutoTokenizer.from_pretrained( 137 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs 138 | ) 139 | else: 140 | raise ValueError( 141 | "You are instantiating a new tokenizer from scratch. This is not supported by this " 142 | "script. You can do it from another script, save it, and load it from here, " 143 | "using --tokenizer_name." 144 | ) 145 | 146 | # Load and prepare model 147 | if args.model_name_or_path: 148 | model = AutoModelForMaskedLM.from_pretrained( 149 | args.model_name_or_path, 150 | from_tf=bool(".ckpt" in args.model_name_or_path), 151 | config=config, 152 | cache_dir=args.model_cache_dir, 153 | ) 154 | else: 155 | logger.info("Training new model from scratch") 156 | model = AutoModelForMaskedLM.from_config(config) 157 | 158 | # replace GELUActivation with nn.GELU 159 | for layer_idx in range(len(model.bert.encoder.layer)): 160 | model.bert.encoder.layer[layer_idx].intermediate.intermediate_act_fn = nn.GELU() 161 | # (skip head since we do not quantize it anyway) 162 | 163 | # replace Self-attention module with ours 164 | # NOTE: currently assumes BERT 165 | logger.info("replace self-attention module with ours (+copy loaded weights for Q,K,V)") 166 | for layer_idx in range(len(model.bert.encoder.layer)): 167 | old_self = model.bert.encoder.layer[layer_idx].attention.self 168 | new_self = BertSelfAttentionWithExtras( 169 | config, 170 | softmax_fn=SOFTMAX_MAPPING[args.attn_softmax], 171 | alpha=args.alpha, 172 | max_seq_length=args.max_seq_length, 173 | skip_attn=args.skip_attn, 174 | attn_gate_type=AttentionGateType[args.attn_gate_type], 175 | attn_gate_init=args.attn_gate_init, 176 | attn_gate_mlp=args.attn_gate_mlp, 177 | attn_gate_mlp2=args.attn_gate_mlp2, 178 | attn_gate_linear_all_features=args.attn_gate_linear_all_features, 179 | ) 180 | 181 | # copy loaded weights 182 | new_self.load_state_dict(old_self.state_dict(), strict=False) 183 | model.bert.encoder.layer[layer_idx].attention.self = new_self 184 | 185 | # Gating -> load the model again to load missing alpha 186 | if args.attn_gate_type != "none": 187 | state_dict = torch.load(str(Path(args.model_name_or_path) / "pytorch_model.bin")) 188 | new_state_dict = {} 189 | for name, val in state_dict.items(): 190 | if "alpha" in name: 191 | new_state_dict[name] = val 192 | model.load_state_dict(new_state_dict, strict=False) 193 | 194 | # Display num params 195 | n_embeddings = count_params(model.bert.embeddings) 196 | n_encoder = count_params(model.bert.encoder) 197 | n_head = count_params(model.cls) 198 | logger.info( 199 | 
f"\nNumber of parameters:\n" 200 | f"\t* Embeddings:\t{n_embeddings}\n" 201 | f"\t* Encoder:\t{n_encoder}\n" 202 | f"\t* Head:\t{n_head}\n" 203 | f"\t= Total (pre-training):\t{n_embeddings + n_encoder + n_head}\n" 204 | f"\t= Total (encoder):\t{n_embeddings + n_encoder}\n" 205 | ) 206 | 207 | # Get the datasets. 208 | # In distributed training, the load_dataset function guarantee that only one local process can 209 | # concurrently download the dataset. 210 | 211 | pre_tokenized_path_map = OrderedDict( 212 | [ 213 | # (data_setup, max_seq_length, validation_percentage) -> dirname 214 | ((DatasetSetups.bookcorpus_and_wiki, 128, None), "tokenized_wiki_val_128"), 215 | ((DatasetSetups.bookcorpus_and_wiki, 128, 5), "tokenized_wiki_val_128_5%"), 216 | ((DatasetSetups.bookcorpus_and_wiki, 128, 1), "tokenized_wiki_val_128_1%"), 217 | ((DatasetSetups.wikitext_103, 128, None), "tokenized_wikitext_103_val_128"), 218 | ((DatasetSetups.wikitext_103, 128, 5), "tokenized_wikitext_103_val_128_5%"), 219 | ] 220 | ) 221 | for k, v in pre_tokenized_path_map.items(): 222 | pre_tokenized_path_map[k] = Path(args.data_cache_dir) / v 223 | 224 | tokenized_configuration = (dataset_setup, args.max_seq_length, args.validation_percentage) 225 | pre_tokenized_path = pre_tokenized_path_map.get(tokenized_configuration, None) 226 | 227 | if pre_tokenized_path is not None and pre_tokenized_path.exists(): 228 | pre_tokenized_path = str(pre_tokenized_path) 229 | 230 | accelerator.print(f"Loading pre-tokenized dataset from {pre_tokenized_path}") 231 | tokenized_datasets = load_from_disk(pre_tokenized_path) 232 | 233 | else: # do tokenization 234 | train_split = ( 235 | "train" if args.train_percentage is None else f"train[:{args.train_percentage}%]" 236 | ) 237 | val_split = ( 238 | "validation" 239 | if args.validation_percentage is None 240 | else f"validation[:{args.validation_percentage}%]" 241 | ) 242 | 243 | if dataset_setup == DatasetSetups.wikitext_2: 244 | raw_datasets = DatasetDict() 245 | raw_datasets["train"] = load_dataset( 246 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=train_split 247 | ) 248 | raw_datasets["validation"] = load_dataset( 249 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=val_split 250 | ) 251 | 252 | elif dataset_setup == DatasetSetups.wikitext_103: 253 | raw_datasets = DatasetDict() 254 | raw_datasets["train"] = load_dataset( 255 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=train_split 256 | ) 257 | raw_datasets["validation"] = load_dataset( 258 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=val_split 259 | ) 260 | 261 | elif dataset_setup == DatasetSetups.bookcorpus_and_wiki: 262 | bookcorpus = load_dataset( 263 | "bookcorpus", cache_dir=args.data_cache_dir, split=train_split 264 | ) 265 | 266 | wiki_train = load_dataset( 267 | "wiki40b", "en", cache_dir=args.data_cache_dir, split=train_split 268 | ) 269 | wiki_val = load_dataset("wiki40b", "en", cache_dir=args.data_cache_dir, split=val_split) 270 | 271 | # only keep the 'text' column 272 | wiki_train = wiki_train.remove_columns( 273 | [c for c in wiki_train.column_names if c != "text"] 274 | ) 275 | wiki_val = wiki_val.remove_columns( 276 | [col for col in wiki_val.column_names if col != "text"] 277 | ) 278 | assert bookcorpus.features.type == wiki_train.features.type 279 | 280 | raw_datasets = DatasetDict() 281 | raw_datasets["train_book"] = bookcorpus 282 | raw_datasets["train_wiki"] = wiki_train 283 | raw_datasets["validation"] = 
wiki_val 284 | 285 | else: 286 | raise ValueError(f"Unknown dataset, {dataset_setup}") 287 | 288 | # Preprocessing the datasets. 289 | 290 | # Check sequence length 291 | if args.max_seq_length is None: 292 | max_seq_length = tokenizer.model_max_length 293 | if max_seq_length > 1024: 294 | logger.warning( 295 | f"The tokenizer picked seems to have a very large `model_max_length` " 296 | f"({tokenizer.model_max_length}). Picking 1024 instead. You can change that " 297 | f"default value by passing --max_seq_length xxx." 298 | ) 299 | max_seq_length = 1024 300 | else: 301 | if args.max_seq_length > tokenizer.model_max_length: 302 | logger.warning( 303 | f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum " 304 | f"length for the model ({tokenizer.model_max_length}). Using " 305 | f"max_seq_length={tokenizer.model_max_length}." 306 | ) 307 | max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) 308 | 309 | # Tokenize all the texts. 310 | # YB: removed line-by-line option as we'll likely never use it 311 | column_names = raw_datasets["validation"].column_names 312 | text_column_name = "text" if "text" in column_names else column_names[0] 313 | 314 | # ... we tokenize every text, then concatenate them together before splitting them in smaller 315 | # parts. We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling 316 | # (see below) is more efficient when it receives the `special_tokens_mask`. 317 | def tokenize_function(examples): 318 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True) 319 | 320 | # YB: make the default bs for text pre-processing explicit 321 | tokenizer_map_batch_size = 1000 322 | with accelerator.main_process_first(): 323 | tokenized_datasets = raw_datasets.map( 324 | tokenize_function, 325 | batched=True, 326 | batch_size=tokenizer_map_batch_size, 327 | writer_batch_size=tokenizer_map_batch_size, 328 | num_proc=args.preprocessing_num_workers, 329 | remove_columns=column_names, 330 | load_from_cache_file=not args.overwrite_cache, 331 | desc="Running tokenizer on every text in dataset", 332 | ) 333 | 334 | # Main data processing function that will concatenate all texts from our dataset and generate 335 | # chunks of max_seq_length. 336 | def group_texts(examples): 337 | # Concatenate all texts. 338 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 339 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 340 | # We drop the small remainder, we could add padding if the model supported it instead of 341 | # this drop, you can customize this part to your needs. 342 | if total_length >= max_seq_length: 343 | total_length = (total_length // max_seq_length) * max_seq_length 344 | # Split by chunks of max_len. 345 | result = { 346 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] 347 | for k, t in concatenated_examples.items() 348 | } 349 | return result 350 | 351 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts 352 | # throws away a remainder for each of those groups of 1,000 texts. You can adjust that 353 | # batch_size here but a higher value might be slower to preprocess. 354 | # 355 | # To speed up this part, we use multiprocessing. 
See the documentation of the map method for 356 | # more information: 357 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 358 | 359 | with accelerator.main_process_first(): 360 | tokenized_datasets = tokenized_datasets.map( 361 | group_texts, 362 | batched=True, 363 | batch_size=tokenizer_map_batch_size, 364 | num_proc=args.preprocessing_num_workers, 365 | load_from_cache_file=not args.overwrite_cache, 366 | desc=f"Grouping texts in chunks of {max_seq_length}", 367 | ) 368 | 369 | # 370 | 371 | eval_dataset = tokenized_datasets["validation"] 372 | 373 | # Conditional for small test subsets 374 | if len(eval_dataset) > 3: 375 | # Log a few random samples from the training set: 376 | for index in random.sample(range(len(eval_dataset)), 3): 377 | logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") 378 | 379 | # Data collator 380 | # This one will take care of randomly masking the tokens. 381 | data_collator = DataCollatorForLanguageModeling( 382 | tokenizer=tokenizer, mlm_probability=args.mlm_probability 383 | ) 384 | 385 | # DataLoaders creation: 386 | eval_dataloader = DataLoader( 387 | eval_dataset, 388 | collate_fn=data_collator, 389 | batch_size=args.per_device_eval_batch_size, 390 | num_workers=args.preprocessing_num_workers, 391 | ) 392 | 393 | # Prepare everything with our `accelerator`. 394 | model, eval_dataloader = accelerator.prepare(model, eval_dataloader) 395 | 396 | logger.info("FP model:") 397 | logger.info(model) 398 | 399 | # Quantize: 400 | if args.quantize: 401 | click_config = get_quant_config() 402 | 403 | # override number of batches 404 | click_config.act_quant.num_batches = args.est_num_batches 405 | click_config.quant.n_bits = args.n_bits 406 | click_config.quant.n_bits_act = args.n_bits_act 407 | if args.no_weight_quant: 408 | click_config.quant.weight_quant = False 409 | if args.no_act_quant: 410 | click_config.quant.act_quant = False 411 | 412 | # Weight Ranges 413 | if args.ranges_weights == "minmax": 414 | pass 415 | elif args.ranges_weights in ("mse", "MSE"): 416 | click_config.quant.weight_quant_method = RangeEstimators.MSE 417 | click_config.quant.weight_opt_method = OptMethod.grid 418 | else: 419 | raise ValueError(f"Unknown weight range estimation: {args.ranges_weights}") 420 | 421 | # Acts ranges 422 | if args.percentile is not None: 423 | click_config.act_quant.options["percentile"] = args.percentile 424 | 425 | if args.ranges_acts == "running_minmax": 426 | click_config.act_quant.quant_method = RangeEstimators.running_minmax 427 | 428 | elif args.ranges_acts == "MSE": 429 | click_config.act_quant.quant_method = RangeEstimators.MSE 430 | if args.qmethod_acts == "symmetric_uniform": 431 | click_config.act_quant.options = dict(opt_method=OptMethod.grid) 432 | elif args.qmethod_acts == "asymmetric_uniform": 433 | click_config.act_quant.options = dict(opt_method=OptMethod.golden_section) 434 | 435 | elif args.ranges_acts.startswith("L"): 436 | click_config.act_quant.quant_method = RangeEstimators.Lp 437 | p_norm = float(args.ranges_acts.replace("L", "")) 438 | options = dict(p_norm=p_norm) 439 | if args.qmethod_acts == "symmetric_uniform": 440 | options["opt_method"] = OptMethod.grid 441 | elif args.qmethod_acts == "asymmetric_uniform": 442 | options["opt_method"] = OptMethod.golden_section 443 | click_config.act_quant.options = options 444 | 445 | else: 446 | raise NotImplementedError(f"Unknown act range estimation setting, '{args.ranges_acts}'") 447 | 448 | qparams = 
val_qparams(click_config) 449 | qparams["quant_dict"] = {} 450 | 451 | model = QuantizedBertForMaskedLM(model, **qparams) 452 | model.set_quant_state( 453 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant 454 | ) 455 | 456 | logger.info("Quantized model:") 457 | logger.info(model) 458 | 459 | # Range estimation 460 | logger.info("** Estimate quantization ranges on training data **") 461 | pass_data_for_range_estimation( 462 | loader=eval_dataloader, 463 | model=model, 464 | act_quant=click_config.quant.act_quant, 465 | max_num_batches=click_config.act_quant.num_batches, 466 | ) 467 | model.fix_ranges() 468 | model.set_quant_state( 469 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant 470 | ) 471 | 472 | # attach hooks for activation stats (if needed) 473 | act_dict = {} 474 | if EXTRA_METRICS: 475 | act_dict = attach_act_hooks(model) 476 | 477 | num_layers = len(model.bert.encoder.layer) 478 | act_inf_norms = OrderedDict() 479 | act_kurtoses = OrderedDict() 480 | 481 | # *** Evaluation *** 482 | model.eval() 483 | losses = [] 484 | for batch_idx, batch in enumerate(tqdm(eval_dataloader)): 485 | with torch.no_grad(): 486 | outputs = model(**batch) 487 | 488 | loss = outputs.loss 489 | loss_ = accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)) 490 | losses.append(loss_) 491 | 492 | # compute inf norms 493 | if EXTRA_METRICS: 494 | for j in range(num_layers): 495 | for name in ( 496 | f"bert.encoder.layer.{j}.output.dense", # FFN output 497 | f"bert.encoder.layer.{j}.output.LayerNorm", # LN(FFN output + input) 498 | ): 499 | x_inp, x_out = act_dict[name] 500 | 501 | x = x_out 502 | 503 | # inf-norm 504 | x = x.view(x.size(0), -1) 505 | inf_norms = x.norm(dim=1, p=np.inf) 506 | if not name in act_inf_norms: 507 | act_inf_norms[name] = AverageMeter() 508 | for v in inf_norms: 509 | act_inf_norms[name].update(v.item()) 510 | 511 | # kurtosis 512 | if batch_idx <= 256: 513 | kurt = kurtosis(x) 514 | if not name in act_kurtoses: 515 | act_kurtoses[name] = AverageMeter() 516 | for v in kurt: 517 | act_kurtoses[name].update(v.item()) 518 | 519 | # compute inf norm also for input 520 | if "LayerNorm" in name: 521 | x = x_inp 522 | x = x.view(x.size(0), -1) 523 | inf_norms = x.norm(dim=1, p=np.inf) 524 | name += ".input" 525 | if not name in act_inf_norms: 526 | act_inf_norms[name] = AverageMeter() 527 | for v in inf_norms: 528 | act_inf_norms[name].update(v.item()) 529 | 530 | losses = torch.cat(losses) 531 | try: 532 | eval_loss = torch.mean(losses) 533 | perplexity = math.exp(eval_loss) 534 | except OverflowError: 535 | perplexity = float("inf") 536 | logger.info(f"perplexity: {perplexity:.4f}") 537 | 538 | # metrics 539 | metrics = OrderedDict([("perplexity", perplexity)]) 540 | 541 | if EXTRA_METRICS: 542 | for name, v in act_inf_norms.items(): 543 | metrics[name] = v.avg 544 | 545 | max_ffn_out_inf_norm = max(v.avg for k, v in act_inf_norms.items() if "dense" in k) 546 | max_LN_out_inf_norm = max( 547 | v.avg for k, v in act_inf_norms.items() if k.endswith("LayerNorm") 548 | ) 549 | max_LN_inp_inf_norm = max(v.avg for k, v in act_inf_norms.items() if "input" in k) 550 | avg_kurtosis = sum(v.avg for v in act_kurtoses.values()) / len(act_kurtoses.values()) 551 | max_kurtosis = max(v.avg for v in act_kurtoses.values()) 552 | 553 | metrics["max_ffn_out_inf_norm"] = max_ffn_out_inf_norm 554 | metrics["max_LN_out_inf_norm"] = max_LN_out_inf_norm 555 | metrics["max_LN_inp_inf_norm"] = max_LN_inp_inf_norm 
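        # The inf-norms above and the kurtosis values below serve as proxies for activation
        # outliers: larger values mean heavier-tailed activations, which force wider quantization
        # ranges and therefore degrade low-bit activation quantization.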
556 | metrics["avg_kurtosis"] = avg_kurtosis 557 | metrics["max_kurtosis"] = max_kurtosis 558 | 559 | logger.info(f"max FFN output inf norm: {max_ffn_out_inf_norm:.1f}") 560 | logger.info(f"max FFN input + output inf norm: {max_LN_inp_inf_norm:.1f}") 561 | logger.info(f"max LN(FFN i + o) inf norm: {max_LN_out_inf_norm:.1f}") 562 | logger.info(f"Avg Kurtosis: {avg_kurtosis:.2f}") 563 | logger.info(f"Max Kurtosis: {max_kurtosis:.1f}") 564 | 565 | if args.output_dir is not None: 566 | os.makedirs(args.output_dir, exist_ok=True) 567 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 568 | json.dump(metrics, f) 569 | 570 | 571 | if __name__ == "__main__": 572 | main() 573 | -------------------------------------------------------------------------------- /validate_clm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 4 | # All Rights Reserved. 5 | import json 6 | import logging 7 | import math 8 | import os 9 | import random 10 | from collections import OrderedDict 11 | from itertools import chain 12 | from pathlib import Path 13 | from pprint import pformat 14 | 15 | import datasets 16 | import numpy as np 17 | import torch 18 | import transformers 19 | from accelerate import Accelerator 20 | from accelerate.utils import set_seed 21 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk 22 | from timm.utils import AverageMeter 23 | from torch.utils.data import DataLoader 24 | from tqdm.auto import tqdm 25 | from transformers import ( 26 | CONFIG_MAPPING, 27 | MODEL_MAPPING, 28 | AutoConfig, 29 | AutoModelForCausalLM, 30 | AutoTokenizer, 31 | default_data_collator, 32 | ) 33 | 34 | from quantization.quantizers import QMethods 35 | from quantization.range_estimators import OptMethod, RangeEstimators 36 | from transformers_language.args import parse_args 37 | from transformers_language.dataset_setups import DatasetSetups 38 | from transformers_language.models.opt_attention import ( 39 | AttentionGateType, 40 | OPTAttentionWithExtras, 41 | ) 42 | from transformers_language.models.quantized_opt import QuantizedOPTForCausalLM 43 | from transformers_language.models.softmax import SOFTMAX_MAPPING 44 | from transformers_language.quant_configs import get_quant_config 45 | from transformers_language.utils import ( 46 | count_params, 47 | kurtosis, 48 | pass_data_for_range_estimation, 49 | val_qparams, 50 | ) 51 | 52 | logger = logging.getLogger("validate_clm") 53 | logging.basicConfig( 54 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 55 | datefmt="%m/%d/%Y %H:%M:%S", 56 | level=logging.INFO, 57 | ) 58 | 59 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 60 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 61 | 62 | 63 | def main(): 64 | args = parse_args() 65 | logger.info(args) 66 | 67 | # convert dataset setup to an enum 68 | dataset_setup = DatasetSetups[args.dataset_setup] 69 | 70 | # Initialize the accelerator. We will let the accelerator handle device placement for us in 71 | # this example. 
72 | # If we're using tracking, we also need to initialize it here and it will by default pick up 73 | # all supported trackers in the environment 74 | accelerator_log_kwargs = {} 75 | 76 | if args.with_tracking: 77 | accelerator_log_kwargs["log_with"] = args.report_to 78 | accelerator_log_kwargs["logging_dir"] = args.output_dir 79 | 80 | accelerator = Accelerator( 81 | gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs 82 | ) 83 | 84 | logger.info(accelerator.state) 85 | if accelerator.is_local_main_process: 86 | datasets.utils.logging.set_verbosity_warning() 87 | transformers.utils.logging.set_verbosity_info() 88 | else: 89 | datasets.utils.logging.set_verbosity_error() 90 | transformers.utils.logging.set_verbosity_error() 91 | 92 | # If passed along, set the training seed now. 93 | if args.seed is not None: 94 | set_seed(args.seed) 95 | 96 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 97 | # https://huggingface.co/docs/datasets/loading_datasets.html. 98 | 99 | # Load pretrained model and tokenizer 100 | # 101 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 102 | # download model & vocab. 103 | config_kwargs = { 104 | "cache_dir": args.model_cache_dir, 105 | } 106 | if args.config_name: 107 | config = AutoConfig.from_pretrained(args.config_name, **config_kwargs) 108 | elif args.model_name_or_path: 109 | config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs) 110 | else: 111 | config = CONFIG_MAPPING[args.model_type]() 112 | logger.warning("You are instantiating a new config instance from scratch.") 113 | 114 | # Display config after changes 115 | logger.info("HuggingFace config after user changes:") 116 | logger.info(str(config)) 117 | 118 | # Load tokenizer 119 | tokenizer_kwargs = { 120 | # 'cache_dir': args.model_cache_dir, 121 | } 122 | if args.model_name_or_path: 123 | tokenizer = AutoTokenizer.from_pretrained( 124 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs 125 | ) 126 | else: 127 | raise ValueError( 128 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 129 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
130 | ) 131 | 132 | # Load and prepare model 133 | if args.model_name_or_path: 134 | model = AutoModelForCausalLM.from_pretrained( 135 | args.model_name_or_path, 136 | from_tf=bool(".ckpt" in args.model_name_or_path), 137 | config=config, 138 | low_cpu_mem_usage=args.low_cpu_mem_usage, 139 | cache_dir=args.model_cache_dir, 140 | ) 141 | else: 142 | logger.info("Training new model from scratch") 143 | model = AutoModelForCausalLM.from_config(config) 144 | 145 | # >> replace self-attention module with ours 146 | # NOTE: currently assumes OPT 147 | for layer_idx in range(len(model.model.decoder.layers)): 148 | old_attn = model.model.decoder.layers[layer_idx].self_attn 149 | new_attn = OPTAttentionWithExtras( 150 | embed_dim=old_attn.embed_dim, 151 | num_heads=old_attn.num_heads, 152 | dropout=old_attn.dropout, 153 | is_decoder=old_attn.is_decoder, 154 | bias=True, 155 | # new 156 | softmax_fn=SOFTMAX_MAPPING[args.attn_softmax], 157 | alpha=args.alpha, 158 | max_seq_length=args.block_size, 159 | skip_attn=args.skip_attn, 160 | attn_gate_type=AttentionGateType[args.attn_gate_type], 161 | attn_gate_init=args.attn_gate_init, 162 | attn_gate_mlp=args.attn_gate_mlp, 163 | attn_gate_mlp2=args.attn_gate_mlp2, 164 | attn_gate_linear_all_features=args.attn_gate_linear_all_features, 165 | ) 166 | # copy loaded weights 167 | new_attn.load_state_dict(old_attn.state_dict(), strict=False) 168 | model.model.decoder.layers[layer_idx].self_attn = new_attn 169 | 170 | # Gating -> load the model again to load missing alpha 171 | if args.attn_gate_type != "none": 172 | state_dict = torch.load(str(Path(args.model_name_or_path) / "pytorch_model.bin")) 173 | new_state_dict = {} 174 | for name, val in state_dict.items(): 175 | if "alpha" in name: 176 | new_state_dict[name] = val 177 | model.load_state_dict(new_state_dict, strict=False) 178 | 179 | # Print model 180 | logger.info("Model:") 181 | logger.info(model) 182 | 183 | # Display num params 184 | n_embeddings = count_params(model.model.decoder.embed_tokens) + count_params( 185 | model.model.decoder.embed_positions 186 | ) 187 | n_decoder = count_params(model.model.decoder) - n_embeddings 188 | n_head = count_params(model.lm_head) 189 | logger.info( 190 | f"\nNumber of parameters:\n" 191 | f"\t* Embeddings:\t{n_embeddings}\n" 192 | f"\t* Decoder:\t{n_decoder}\n" 193 | f"\t* Head:\t{n_head}\n" 194 | f"\t= Total (pre-training):\t{n_embeddings + n_decoder + n_head}\n" 195 | f"\t= Total (decoder only):\t{n_embeddings + n_decoder}\n" 196 | ) 197 | 198 | # ----------------------------------------------------------------- 199 | 200 | # Get the datasets. 201 | # In distributed training, the load_dataset function guarantee that only one local process can 202 | # concurrently download the dataset. 
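    # To avoid repeating the slow tokenization and chunking steps on every run, a few common
    # (dataset, block_size, split percentage) configurations are mapped below to pre-tokenized
    # datasets saved to disk earlier; any other configuration falls through to the on-the-fly
    # tokenization branch further down.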
203 | 204 | # (data_setup, block_size, train_percentage, validation_percentage) -> train_dirname 205 | pre_tokenized_path_map = OrderedDict( 206 | [ 207 | ((DatasetSetups.wikitext_103, 512, None, None), ("tokenized_wikitext_103_OPT_512")), 208 | ( 209 | (DatasetSetups.wikitext_103, 512, None, 10), 210 | ("tokenized_wikitext_103_OPT_512_val_10%"), 211 | ), 212 | ( 213 | (DatasetSetups.wikitext_103, 512, 10, None), 214 | ("tokenized_wikitext_103_OPT_512_train_10%"), 215 | ), 216 | ( 217 | (DatasetSetups.bookcorpus_and_wiki, 512, 1, 5), 218 | ("tokenized_book_wiki_OPT_512_train_1%_val_5%"), 219 | ), 220 | ( 221 | (DatasetSetups.bookcorpus_and_wiki, 512, 1, 1), 222 | ("tokenized_book_wiki_OPT_512_train_1%_val_1%"), 223 | ), 224 | ] 225 | ) 226 | for k, v in pre_tokenized_path_map.items(): 227 | pre_tokenized_path_map[k] = Path(args.data_cache_dir) / v 228 | 229 | tokenized_configuration = ( 230 | dataset_setup, 231 | args.block_size, 232 | args.train_percentage, 233 | args.validation_percentage, 234 | ) 235 | pre_tokenized_path = pre_tokenized_path_map.get(tokenized_configuration, None) 236 | 237 | if pre_tokenized_path is not None and pre_tokenized_path.exists(): 238 | pre_tokenized_path = str(pre_tokenized_path) 239 | 240 | accelerator.print(f"Loading pre-tokenized dataset from {pre_tokenized_path}") 241 | tokenized_datasets = load_from_disk(pre_tokenized_path) 242 | 243 | else: # do tokenization 244 | train_split = ( 245 | "train" if args.train_percentage is None else f"train[:{args.train_percentage}%]" 246 | ) 247 | val_split = ( 248 | "validation" 249 | if args.validation_percentage is None 250 | else f"validation[:{args.validation_percentage}%]" 251 | ) 252 | 253 | if dataset_setup == DatasetSetups.wikitext_2: 254 | raw_datasets = DatasetDict() 255 | raw_datasets["train"] = load_dataset( 256 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=train_split 257 | ) 258 | raw_datasets["validation"] = load_dataset( 259 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=val_split 260 | ) 261 | 262 | elif dataset_setup == DatasetSetups.wikitext_103: 263 | raw_datasets = DatasetDict() 264 | raw_datasets["train"] = load_dataset( 265 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=train_split 266 | ) 267 | raw_datasets["validation"] = load_dataset( 268 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=val_split 269 | ) 270 | 271 | elif dataset_setup == DatasetSetups.bookcorpus_and_wiki: 272 | bookcorpus = load_dataset( 273 | "bookcorpus", cache_dir=args.data_cache_dir, split=train_split 274 | ) 275 | 276 | wiki_train = load_dataset( 277 | "wiki40b", "en", cache_dir=args.data_cache_dir, split=train_split 278 | ) 279 | wiki_val = load_dataset("wiki40b", "en", cache_dir=args.data_cache_dir, split=val_split) 280 | 281 | # only keep the 'text' column 282 | wiki_train = wiki_train.remove_columns( 283 | [c for c in wiki_train.column_names if c != "text"] 284 | ) 285 | wiki_val = wiki_val.remove_columns( 286 | [col for col in wiki_val.column_names if col != "text"] 287 | ) 288 | assert bookcorpus.features.type == wiki_train.features.type 289 | 290 | raw_datasets = DatasetDict() 291 | raw_datasets["train_book"] = bookcorpus 292 | raw_datasets["train_wiki"] = wiki_train 293 | raw_datasets["validation"] = wiki_val 294 | 295 | else: 296 | raise ValueError(f"Unknown dataset, {dataset_setup}") 297 | 298 | # Preprocessing the datasets. 
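        # In short: texts are tokenized, concatenated, and split into non-overlapping chunks of
        # `block_size` tokens (see `group_texts` below); a trailing remainder shorter than
        # `block_size` is dropped. For example, documents of 300, 400 and 500 tokens with
        # block_size=512 give floor(1200 / 512) = 2 chunks, and the remaining 176 tokens are discarded.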
299 | # Check sequence length 300 | if args.block_size is None: 301 | block_size = tokenizer.model_max_length 302 | if block_size > 1024: 303 | logger.warning( 304 | "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" 305 | " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" 306 | " override this default with `--block_size xxx`." 307 | ) 308 | block_size = 1024 309 | else: 310 | if args.block_size > tokenizer.model_max_length: 311 | logger.warning( 312 | f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" 313 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 314 | ) 315 | block_size = min(args.block_size, tokenizer.model_max_length) 316 | 317 | # Tokenize all the texts. 318 | column_names = raw_datasets["validation"].column_names 319 | text_column_name = "text" if "text" in column_names else column_names[0] 320 | 321 | def tokenize_function(examples): 322 | return tokenizer(examples[text_column_name]) 323 | 324 | # YB: make the default bs for text pre-processing explicit 325 | tokenizer_map_batch_size = 1000 326 | with accelerator.main_process_first(): 327 | tokenized_datasets = raw_datasets.map( 328 | tokenize_function, 329 | batched=True, 330 | batch_size=tokenizer_map_batch_size, 331 | writer_batch_size=tokenizer_map_batch_size, 332 | num_proc=args.preprocessing_num_workers, 333 | remove_columns=column_names, 334 | load_from_cache_file=not args.overwrite_cache, 335 | desc="Running tokenizer on dataset", 336 | ) 337 | 338 | # Main data processing function that will concatenate all texts from our dataset and generate 339 | # chunks of max_seq_length. 340 | def group_texts(examples): 341 | # Concatenate all texts. 342 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 343 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 344 | # We drop the small remainder, we could add padding if the model supported it instead of 345 | # this drop, you can customize this part to your needs. 346 | if total_length >= block_size: 347 | total_length = (total_length // block_size) * block_size 348 | else: 349 | total_length = 0 350 | # Split by chunks of max_len. 351 | result = { 352 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 353 | for k, t in concatenated_examples.items() 354 | } 355 | result["labels"] = result["input_ids"].copy() 356 | return result 357 | 358 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder 359 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower 360 | # to preprocess. 361 | # 362 | # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: 363 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 364 | 365 | with accelerator.main_process_first(): 366 | tokenized_datasets = tokenized_datasets.map( 367 | group_texts, 368 | batched=True, 369 | batch_size=tokenizer_map_batch_size, 370 | num_proc=args.preprocessing_num_workers, 371 | load_from_cache_file=not args.overwrite_cache, 372 | desc=f"Grouping texts in chunks of {block_size}", 373 | ) 374 | 375 | # 376 | 377 | if dataset_setup == DatasetSetups.bookcorpus_and_wiki: 378 | train_dataset = concatenate_datasets( 379 | [tokenized_datasets["train_book"], tokenized_datasets["train_wiki"]] 380 | ) 381 | eval_dataset = tokenized_datasets["validation"] 382 | else: 383 | train_dataset = tokenized_datasets["train"] 384 | eval_dataset = tokenized_datasets["validation"] 385 | 386 | # Log a few random samples from the training set: 387 | if len(train_dataset) > 3: 388 | for index in random.sample(range(len(train_dataset)), 3): 389 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 390 | 391 | # DataLoaders creation: 392 | train_dataloader = DataLoader( 393 | train_dataset, 394 | shuffle=True, 395 | collate_fn=default_data_collator, 396 | batch_size=args.per_device_train_batch_size, 397 | num_workers=args.preprocessing_num_workers, 398 | ) 399 | eval_dataloader = DataLoader( 400 | eval_dataset, 401 | collate_fn=default_data_collator, 402 | batch_size=args.per_device_eval_batch_size, 403 | num_workers=args.preprocessing_num_workers, 404 | ) 405 | 406 | # Prepare everything with our `accelerator`. 407 | model, train_dataloader, eval_dataloader = accelerator.prepare( 408 | model, train_dataloader, eval_dataloader 409 | ) 410 | 411 | logger.info("FP model:") 412 | logger.info(model) 413 | 414 | # 415 | ## Quantize: 416 | # 417 | if args.quantize: 418 | click_config = get_quant_config() 419 | 420 | # override number of batches 421 | click_config.act_quant.num_batches = args.est_num_batches 422 | click_config.quant.n_bits = args.n_bits 423 | click_config.quant.n_bits_act = args.n_bits_act 424 | click_config.quant.quant_setup = args.quant_setup 425 | if args.no_weight_quant: 426 | click_config.quant.weight_quant = False 427 | if args.no_act_quant: 428 | click_config.quant.act_quant = False 429 | 430 | # use MSE for weights (ignore `args.ranges_weights`) 431 | # click_config.quant.weight_quant_method = RangeEstimators.current_minmax 432 | click_config.quant.weight_quant_method = RangeEstimators.MSE 433 | click_config.quant.weight_opt_method = OptMethod.grid 434 | 435 | # qmethod acts 436 | if args.qmethod_acts == "symmetric_uniform": 437 | click_config.quant.qmethod_act = QMethods.symmetric_uniform 438 | elif args.qmethod_acts == "asymmetric_uniform": 439 | click_config.quant.qmethod_act = QMethods.asymmetric_uniform 440 | else: 441 | raise NotImplementedError(f"Unknown qmethod_act setting, '{args.qmethod_acts}'") 442 | 443 | # Acts ranges 444 | if args.percentile is not None: 445 | click_config.act_quant.options["percentile"] = args.percentile 446 | 447 | if args.ranges_acts == "running_minmax": 448 | click_config.act_quant.quant_method = RangeEstimators.running_minmax 449 | 450 | elif args.ranges_acts == "MSE": 451 | click_config.act_quant.quant_method = RangeEstimators.MSE 452 | if args.qmethod_acts == "symmetric_uniform": 453 | click_config.act_quant.options = dict(opt_method=OptMethod.grid) 454 | elif args.qmethod_acts == "asymmetric_uniform": 455 | 
click_config.act_quant.options = dict(opt_method=OptMethod.golden_section) 456 | 457 | elif args.ranges_acts.startswith("L"): 458 | click_config.act_quant.quant_method = RangeEstimators.Lp 459 | p_norm = float(args.ranges_acts.replace("L", "")) 460 | options = dict(p_norm=p_norm) 461 | if args.qmethod_acts == "symmetric_uniform": 462 | options["opt_method"] = OptMethod.grid 463 | elif args.qmethod_acts == "asymmetric_uniform": 464 | options["opt_method"] = OptMethod.golden_section 465 | click_config.act_quant.options = options 466 | 467 | else: 468 | raise NotImplementedError(f"Unknown range estimation setting, '{args.ranges_acts}'") 469 | 470 | qparams = val_qparams(click_config) 471 | qparams["quant_dict"] = {} 472 | 473 | model = QuantizedOPTForCausalLM(model, **qparams) 474 | model.set_quant_state( 475 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant 476 | ) 477 | 478 | logger.info("Quantized model:") 479 | logger.info(model) 480 | 481 | # Range estimation 482 | logger.info("** Estimate quantization ranges on training data **") 483 | pass_data_for_range_estimation( 484 | loader=train_dataloader, 485 | model=model, 486 | act_quant=click_config.quant.act_quant, 487 | max_num_batches=click_config.act_quant.num_batches, 488 | ) 489 | model.fix_ranges() 490 | 491 | model.set_quant_state( 492 | weight_quant=click_config.quant.weight_quant, act_quant=click_config.quant.act_quant 493 | ) 494 | 495 | # attach hooks for activation stats 496 | def attach_act_hooks(model): 497 | act_dict = OrderedDict() 498 | 499 | def _make_hook(name): 500 | def _hook(mod, inp, out): 501 | if isinstance(inp, tuple) and len(inp) > 0: 502 | inp = inp[0] 503 | if isinstance(out, tuple) and len(out) > 0: 504 | out = out[0] 505 | act_dict[name] = (inp, out) 506 | 507 | return _hook 508 | 509 | for name, module in model.named_modules(): 510 | module.register_forward_hook(_make_hook(name)) 511 | return act_dict 512 | 513 | if args.quantize: 514 | act_dict = {} 515 | else: 516 | act_dict = attach_act_hooks(model) 517 | num_layers = len(model.model.decoder.layers) 518 | 519 | ACT_KEYS = [ 520 | "model.decoder.final_layer_norm", 521 | *[f"model.decoder.layers.{j}" for j in range(num_layers)], 522 | *[f"model.decoder.layers.{j}.fc2" for j in range(num_layers)], 523 | *[f"model.decoder.layers.{j}.final_layer_norm" for j in range(num_layers)], 524 | *[f"model.decoder.layers.{j}.self_attn.out_proj" for j in range(num_layers)], 525 | *[f"model.decoder.layers.{j}.self_attn_layer_norm" for j in range(num_layers)], 526 | ] 527 | 528 | act_inf_norms = OrderedDict() 529 | act_kurtoses = OrderedDict() 530 | 531 | # ----------------------------------------------------------------- 532 | 533 | # *** Evaluation *** 534 | model.eval() 535 | losses = [] 536 | for batch_idx, batch in enumerate(tqdm(eval_dataloader)): 537 | with torch.no_grad(): 538 | outputs = model(**batch) 539 | 540 | loss = outputs.loss 541 | loss_ = accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)) 542 | losses.append(loss_) 543 | 544 | # compute inf norms 545 | if not args.quantize: 546 | for name in ACT_KEYS: 547 | if name in act_dict: 548 | x_inp, x_out = act_dict[name] 549 | x = x_out 550 | x = x.view(x.size(0), -1) 551 | 552 | # compute inf norm 553 | inf_norms = x.norm(dim=1, p=np.inf) 554 | if not name in act_inf_norms: 555 | act_inf_norms[name] = AverageMeter() 556 | for v in inf_norms: 557 | act_inf_norms[name].update(v.item()) 558 | 559 | # compute kurtosis 560 | if batch_idx <= 100: 561 | kurt 
= kurtosis(x) 562 | if not name in act_kurtoses: 563 | act_kurtoses[name] = AverageMeter() 564 | for v in kurt: 565 | act_kurtoses[name].update(v.item()) 566 | 567 | losses = torch.cat(losses) 568 | try: 569 | eval_loss = torch.mean(losses) 570 | perplexity = math.exp(eval_loss) 571 | except OverflowError: 572 | perplexity = float("inf") 573 | logger.info(f"perplexity: {perplexity:.4f}") 574 | 575 | # metrics 576 | metrics = OrderedDict([("perplexity", perplexity)]) 577 | 578 | if not args.quantize: 579 | for name, v in act_inf_norms.items(): 580 | metrics[name] = v.avg 581 | 582 | max_inf_norm = max(v.avg for v in act_inf_norms.values()) 583 | max_ffn_inf_norm = max(v.avg for k, v in act_inf_norms.items() if ".fc" in k) 584 | max_layer_inf_norm = max( 585 | act_inf_norms[f"model.decoder.layers.{j}"].avg for j in range(num_layers) 586 | ) 587 | 588 | avg_kurtosis = sum(v.avg for v in act_kurtoses.values()) / len(act_kurtoses.values()) 589 | max_kurtosis = max(v.avg for v in act_kurtoses.values()) 590 | max_kurtosis_layers = max( 591 | act_kurtoses[f"model.decoder.layers.{j}"].avg for j in range(num_layers) 592 | ) 593 | 594 | metrics["max_inf_norm"] = max_inf_norm 595 | metrics["max_ffn_inf_norm"] = max_ffn_inf_norm 596 | metrics["max_layer_inf_norm"] = max_layer_inf_norm 597 | 598 | metrics["avg_kurtosis"] = avg_kurtosis 599 | metrics["max_kurtosis"] = max_kurtosis 600 | metrics["max_kurtosis_layers"] = max_kurtosis_layers 601 | 602 | logger.info(f"Max inf norm: {max_inf_norm:.1f}") 603 | logger.info(f"Max FFN inf norm: {max_ffn_inf_norm:.1f}") 604 | logger.info(f"Max layer inf norm: {max_layer_inf_norm:.1f}") 605 | 606 | logger.info(f"Avg Kurtosis: {avg_kurtosis:.2f}") 607 | logger.info(f"Max Kurtosis: {max_kurtosis:.1f}") 608 | logger.info(f"Max Kurtosis layers: {max_kurtosis_layers:.1f}") 609 | 610 | logger.info(f"\nAll metrics:\n{pformat(metrics)}") 611 | 612 | if args.output_dir is not None: 613 | os.makedirs(args.output_dir, exist_ok=True) 614 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 615 | json.dump(metrics, f) 616 | 617 | 618 | if __name__ == "__main__": 619 | main() 620 | -------------------------------------------------------------------------------- /run_mlm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright (c) 2023 Qualcomm Technologies, Inc. 4 | # All Rights Reserved. 
5 | import json 6 | import logging 7 | import math 8 | import os 9 | import random 10 | import warnings 11 | from collections import OrderedDict 12 | from itertools import chain 13 | from pathlib import Path 14 | 15 | import datasets 16 | import torch 17 | import transformers 18 | import yaml 19 | from accelerate import Accelerator 20 | from accelerate.logging import get_logger 21 | from accelerate.utils import set_seed 22 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk 23 | from torch.utils.data import DataLoader 24 | from tqdm.auto import tqdm 25 | from transformers import ( 26 | CONFIG_MAPPING, 27 | AutoConfig, 28 | AutoModelForMaskedLM, 29 | AutoTokenizer, 30 | DataCollatorForLanguageModeling, 31 | get_scheduler, 32 | ) 33 | 34 | from transformers_language.args import parse_args 35 | from transformers_language.dataset_setups import DatasetSetups 36 | from transformers_language.models.bert_attention import ( 37 | AttentionGateType, 38 | BertSelfAttentionWithExtras, 39 | ) 40 | from transformers_language.models.softmax import SOFTMAX_MAPPING 41 | from transformers_language.utils import count_params 42 | 43 | logger = get_logger("run_mlm") 44 | 45 | 46 | def attach_tb_act_hooks(model): 47 | act_dict = OrderedDict() 48 | 49 | def _make_hook(name): 50 | def _hook(mod, inp, out): 51 | act_dict[name] = out[0] 52 | 53 | return _hook 54 | 55 | for name, module in model.named_modules(): 56 | module.register_forward_hook(_make_hook(name)) 57 | return act_dict 58 | 59 | 60 | def main(): 61 | args = parse_args() 62 | 63 | # convert dataset setup to an enum 64 | dataset_setup = DatasetSetups[args.dataset_setup] 65 | 66 | # Initialize the accelerator. We will let the accelerator handle device placement for us in 67 | # this example. 68 | # If we're using tracking, we also need to initialize it here and it will by default pick up 69 | # all supported trackers in the environment 70 | accelerator_log_kwargs = {} 71 | 72 | if args.with_tracking: 73 | accelerator_log_kwargs["log_with"] = args.report_to 74 | accelerator_log_kwargs["project_dir"] = args.output_dir 75 | 76 | accelerator = Accelerator( 77 | gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs 78 | ) 79 | accelerator.project_configuration.total_limit = 1 80 | accelerator.project_configuration.automatic_checkpoint_naming = True 81 | 82 | # log passed args 83 | logger.info(args) 84 | 85 | # Make one log on every process with the configuration for debugging. 86 | logging.basicConfig( 87 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 88 | datefmt="%m/%d/%Y %H:%M:%S", 89 | level=logging.INFO, 90 | ) 91 | logger.info(accelerator.state, main_process_only=False) 92 | if accelerator.is_local_main_process: 93 | datasets.utils.logging.set_verbosity_warning() 94 | transformers.utils.logging.set_verbosity_info() 95 | else: 96 | datasets.utils.logging.set_verbosity_error() 97 | transformers.utils.logging.set_verbosity_error() 98 | 99 | # If passed along, set the training seed now. 100 | if args.seed is not None: 101 | set_seed(args.seed) 102 | 103 | # Prepare HuggingFace config 104 | # In distributed training, the .from_pretrained methods guarantee that only one local process 105 | # can concurrently download model & vocab. 
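# NOTE: the following is a minimal, self-contained sketch, not part of run_mlm.py.
# It illustrates the forward-hook pattern used by `attach_tb_act_hooks` above:
# every submodule writes its most recent output into a dict keyed by module name.
import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
acts = {}
for mod_name, module in toy.named_modules():
    module.register_forward_hook(
        lambda mod, inp, out, mod_name=mod_name: acts.__setitem__(mod_name, out)
    )
toy(torch.randn(3, 4))
print(sorted(acts.keys()))  # ['', '0', '1', '2']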
106 | config_kwargs = { 107 | "cache_dir": args.model_cache_dir, 108 | } 109 | if args.config_name: 110 | config = AutoConfig.from_pretrained(args.config_name, **config_kwargs) 111 | elif args.model_name_or_path: 112 | config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs) 113 | else: 114 | config = CONFIG_MAPPING[args.model_type]() 115 | logger.warning("You are instantiating a new config instance from scratch.") 116 | 117 | # Load model config changes from file, if provided 118 | if args.config_path is not None: 119 | logger.info(f"Loading model config changes from {args.config_path}") 120 | with open(args.config_path) as f: 121 | config_changes = yaml.safe_load(f) 122 | 123 | for key, value in config_changes.items(): 124 | setattr(config, key, value) 125 | 126 | # Set dropout rates, if specified 127 | if args.attn_dropout is not None: 128 | logger.info(f"Setting attention dropout rate to {args.attn_dropout}") 129 | config.attention_probs_dropout_prob = args.attn_dropout 130 | 131 | if args.hidden_dropout is not None: 132 | logger.info(f"Setting hidden dropout rate to {args.hidden_dropout}") 133 | config.hidden_dropout_prob = args.hidden_dropout 134 | 135 | # Display config after changes 136 | logger.info("HuggingFace config after user changes:") 137 | logger.info(str(config)) 138 | 139 | # Load tokenizer 140 | tokenizer_kwargs = { 141 | "cache_dir": args.model_cache_dir, 142 | } 143 | if args.tokenizer_name: 144 | tokenizer = AutoTokenizer.from_pretrained( 145 | args.tokenizer_name, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs 146 | ) 147 | elif args.model_name_or_path: 148 | tokenizer = AutoTokenizer.from_pretrained( 149 | args.model_name_or_path, use_fast=not args.use_slow_tokenizer, **tokenizer_kwargs 150 | ) 151 | else: 152 | raise ValueError( 153 | "You are instantiating a new tokenizer from scratch. This is not supported by this " 154 | "script. You can do it from another script, save it, and load it from here, " 155 | "using --tokenizer_name." 
156 | ) 157 | 158 | # Load and prepare model 159 | if args.model_name_or_path: 160 | model = AutoModelForMaskedLM.from_pretrained( 161 | args.model_name_or_path, 162 | from_tf=bool(".ckpt" in args.model_name_or_path), 163 | config=config, 164 | cache_dir=args.model_cache_dir, 165 | ) 166 | else: 167 | logger.info("Training new model from scratch") 168 | model = AutoModelForMaskedLM.from_config(config) 169 | 170 | # >> replace Self-attention module with ours 171 | # NOTE: currently assumes BERT 172 | for layer_idx in range(len(model.bert.encoder.layer)): 173 | old_self = model.bert.encoder.layer[layer_idx].attention.self 174 | new_self = BertSelfAttentionWithExtras( 175 | config, 176 | softmax_fn=SOFTMAX_MAPPING[args.attn_softmax], 177 | alpha=args.alpha, 178 | max_seq_length=args.max_seq_length, 179 | skip_attn=args.skip_attn, 180 | attn_gate_type=AttentionGateType[args.attn_gate_type], 181 | attn_gate_init=args.attn_gate_init, 182 | attn_gate_mlp=args.attn_gate_mlp, 183 | attn_gate_mlp2=args.attn_gate_mlp2, 184 | attn_gate_linear_all_features=args.attn_gate_linear_all_features, 185 | fine_tuning=args.fine_tuning, 186 | ) 187 | 188 | # copy loaded weights 189 | if args.model_name_or_path is not None: 190 | new_self.load_state_dict(old_self.state_dict(), strict=False) 191 | model.bert.encoder.layer[layer_idx].attention.self = new_self 192 | 193 | # Gating -> load the model again to load missing alpha 194 | if args.model_name_or_path is not None and args.attn_gate_type != "none": 195 | state_dict = torch.load(str(Path(args.model_name_or_path) / "pytorch_model.bin")) 196 | new_state_dict = {} 197 | for name, val in state_dict.items(): 198 | if "alpha" in name: 199 | new_state_dict[name] = val 200 | model.load_state_dict(new_state_dict, strict=False) 201 | 202 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a 203 | # model from scratch on a small vocab and want a smaller embedding size, remove this test. 204 | embedding_size = model.get_input_embeddings().weight.shape[0] # = vocab size 205 | if len(tokenizer) > embedding_size: 206 | model.resize_token_embeddings(len(tokenizer)) 207 | 208 | # Display num params 209 | n_embeddings = count_params(model.bert.embeddings) 210 | n_encoder = count_params(model.bert.encoder) 211 | n_head = count_params(model.cls) 212 | logger.info( 213 | f"\nNumber of parameters:\n" 214 | f"\t* Embeddings:\t{n_embeddings}\n" 215 | f"\t* Encoder:\t{n_encoder}\n" 216 | f"\t* Head:\t{n_head}\n" 217 | f"\t= Total (pre-training):\t{n_embeddings + n_encoder + n_head}\n" 218 | f"\t= Total (encoder):\t{n_embeddings + n_encoder}\n" 219 | ) 220 | 221 | # Get the datasets. 222 | # In distributed training, the load_dataset function guarantee that only one local process can 223 | # concurrently download the dataset. 
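# NOTE: the following is a minimal illustrative sketch, not part of run_mlm.py.
# It shows the module-swap pattern used above for the attention layers: build a
# drop-in replacement, copy whatever weights match by name via strict=False, then
# assign it back into the parent module. `Old` and `New` are hypothetical stand-ins
# for the HF self-attention class and `BertSelfAttentionWithExtras`.
import torch
import torch.nn as nn


class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)


class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.alpha = nn.Parameter(torch.zeros(1))  # extra parameter, absent in Old


parent = nn.ModuleDict({"attn": Old()})
new_attn = New()
# strict=False: `proj.*` is copied, the missing `alpha` keeps its fresh init
missing, unexpected = new_attn.load_state_dict(parent["attn"].state_dict(), strict=False)
parent["attn"] = new_attn
assert missing == ["alpha"] and unexpected == []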
224 | tokenized_book_wiki_path = ( 225 | Path(args.data_cache_dir) / f"tokenized_book_wiki_{args.max_seq_length}" 226 | ) 227 | if dataset_setup == DatasetSetups.bookcorpus_and_wiki and tokenized_book_wiki_path.exists(): 228 | accelerator.print(f"Loading tokenized dataset from {str(tokenized_book_wiki_path)}") 229 | 230 | tokenized_datasets = load_from_disk(str(tokenized_book_wiki_path)) 231 | 232 | else: # do tokenization 233 | train_split = ( 234 | "train" if args.train_percentage is None else f"train[:{args.train_percentage}%]" 235 | ) 236 | val_split = ( 237 | "validation" 238 | if args.validation_percentage is None 239 | else f"validation[:{args.validation_percentage}%]" 240 | ) 241 | 242 | if dataset_setup == DatasetSetups.wikitext_2: 243 | raw_datasets = DatasetDict() 244 | raw_datasets["train"] = load_dataset( 245 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=train_split 246 | ) 247 | raw_datasets["validation"] = load_dataset( 248 | "wikitext", "wikitext-2-raw-v1", cache_dir=args.data_cache_dir, split=val_split 249 | ) 250 | 251 | elif dataset_setup == DatasetSetups.wikitext_103: 252 | raw_datasets = DatasetDict() 253 | raw_datasets["train"] = load_dataset( 254 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=train_split 255 | ) 256 | raw_datasets["validation"] = load_dataset( 257 | "wikitext", "wikitext-103-raw-v1", cache_dir=args.data_cache_dir, split=val_split 258 | ) 259 | 260 | elif dataset_setup == DatasetSetups.bookcorpus_and_wiki: 261 | bookcorpus = load_dataset( 262 | "bookcorpus", cache_dir=args.data_cache_dir, split=train_split 263 | ) 264 | 265 | wiki_train = load_dataset( 266 | "wiki40b", "en", cache_dir=args.data_cache_dir, split=train_split 267 | ) 268 | wiki_val = load_dataset("wiki40b", "en", cache_dir=args.data_cache_dir, split=val_split) 269 | 270 | # only keep the 'text' column 271 | wiki_train = wiki_train.remove_columns( 272 | [c for c in wiki_train.column_names if c != "text"] 273 | ) 274 | wiki_val = wiki_val.remove_columns( 275 | [col for col in wiki_val.column_names if col != "text"] 276 | ) 277 | assert bookcorpus.features.type == wiki_train.features.type 278 | 279 | raw_datasets = DatasetDict() 280 | raw_datasets["train_book"] = bookcorpus 281 | raw_datasets["train_wiki"] = wiki_train 282 | raw_datasets["validation"] = wiki_val 283 | 284 | else: 285 | raise ValueError(f"Unknown dataset, {dataset_setup}") 286 | 287 | # Preprocessing the datasets. 288 | 289 | # Check sequence length 290 | if args.max_seq_length is None: 291 | max_seq_length = tokenizer.model_max_length 292 | if max_seq_length > 1024: 293 | logger.warning( 294 | f"The tokenizer picked seems to have a very large `model_max_length` " 295 | f"({tokenizer.model_max_length}). Picking 1024 instead. You can change that " 296 | f"default value by passing --max_seq_length xxx." 297 | ) 298 | max_seq_length = 1024 299 | else: 300 | if args.max_seq_length > tokenizer.model_max_length: 301 | logger.warning( 302 | f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum " 303 | f"length for the model ({tokenizer.model_max_length}). Using " 304 | f"max_seq_length={tokenizer.model_max_length}." 305 | ) 306 | max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) 307 | 308 | # Tokenize all the texts. 
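# NOTE: the following is a minimal illustrative sketch, not part of run_mlm.py.
# It shows why only the 'text' column is kept above: `concatenate_datasets` (used
# further below to merge bookcorpus and wiki40b) requires the datasets to share
# the same features, so every other column is dropped first.
from datasets import Dataset, concatenate_datasets

book_like = Dataset.from_dict({"text": ["a", "b"]})
wiki_like = Dataset.from_dict({"text": ["c"], "extra_id": ["Q1"]})
wiki_like = wiki_like.remove_columns(
    [c for c in wiki_like.column_names if c != "text"]
)
assert book_like.features.type == wiki_like.features.type
merged = concatenate_datasets([book_like, wiki_like])
assert merged.num_rows == 3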
309 | # YB: removed line-by-line option as we'll likely never use it 310 | column_names = raw_datasets["validation"].column_names 311 | text_column_name = "text" if "text" in column_names else column_names[0] 312 | 313 | # ... we tokenize every text, then concatenate them together before splitting them in smaller 314 | # parts. We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling 315 | # (see below) is more efficient when it receives the `special_tokens_mask`. 316 | def tokenize_function(examples): 317 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True) 318 | 319 | # YB: make the default bs for text pre-processing explicit 320 | tokenizer_map_batch_size = 1000 321 | with accelerator.main_process_first(): 322 | tokenized_datasets = raw_datasets.map( 323 | tokenize_function, 324 | batched=True, 325 | batch_size=tokenizer_map_batch_size, 326 | writer_batch_size=tokenizer_map_batch_size, 327 | num_proc=args.preprocessing_num_workers, 328 | remove_columns=column_names, 329 | load_from_cache_file=not args.overwrite_cache, 330 | desc="Running tokenizer on every text in dataset", 331 | ) 332 | 333 | # Main data processing function that will concatenate all texts from our dataset and generate 334 | # chunks of max_seq_length. 335 | def group_texts(examples): 336 | # Concatenate all texts. 337 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 338 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 339 | # We drop the small remainder, we could add padding if the model supported it instead of 340 | # this drop, you can customize this part to your needs. 341 | if total_length >= max_seq_length: 342 | total_length = (total_length // max_seq_length) * max_seq_length 343 | # Split by chunks of max_len. 344 | result = { 345 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] 346 | for k, t in concatenated_examples.items() 347 | } 348 | return result 349 | 350 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts 351 | # throws away a remainder for each of those groups of 1,000 texts. You can adjust that 352 | # batch_size here but a higher value might be slower to preprocess. 353 | # 354 | # To speed up this part, we use multiprocessing. 
See the documentation of the map method for 355 | # more information: 356 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 357 | 358 | with accelerator.main_process_first(): 359 | tokenized_datasets = tokenized_datasets.map( 360 | group_texts, 361 | batched=True, 362 | batch_size=tokenizer_map_batch_size, 363 | num_proc=args.preprocessing_num_workers, 364 | load_from_cache_file=not args.overwrite_cache, 365 | desc=f"Grouping texts in chunks of {max_seq_length}", 366 | ) 367 | 368 | # 369 | 370 | if dataset_setup == DatasetSetups.bookcorpus_and_wiki: 371 | train_dataset = concatenate_datasets( 372 | [tokenized_datasets["train_book"], tokenized_datasets["train_wiki"]] 373 | ) 374 | eval_dataset = tokenized_datasets["validation"] 375 | else: 376 | train_dataset = tokenized_datasets["train"] 377 | eval_dataset = tokenized_datasets["validation"] 378 | 379 | # Conditional for small test subsets 380 | if len(train_dataset) > 3: 381 | # Log a few random samples from the training set: 382 | for index in random.sample(range(len(train_dataset)), 3): 383 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 384 | 385 | # Data collator 386 | # This one will take care of randomly masking the tokens. 387 | data_collator = DataCollatorForLanguageModeling( 388 | tokenizer=tokenizer, mlm_probability=args.mlm_probability 389 | ) 390 | 391 | # DataLoaders creation: 392 | train_dataloader = DataLoader( 393 | train_dataset, 394 | shuffle=True, 395 | collate_fn=data_collator, 396 | batch_size=args.per_device_train_batch_size, 397 | num_workers=args.preprocessing_num_workers, 398 | ) 399 | eval_dataloader = DataLoader( 400 | eval_dataset, 401 | collate_fn=data_collator, 402 | batch_size=args.per_device_eval_batch_size, 403 | num_workers=args.preprocessing_num_workers, 404 | ) 405 | 406 | # Optimizer 407 | # Split weights in two groups, one with weight decay and the other not. 408 | no_decay = ["bias", "LayerNorm.weight"] 409 | optimizer_grouped_parameters = [ 410 | { 411 | "params": [ 412 | p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) 413 | ], 414 | "weight_decay": args.weight_decay, 415 | }, 416 | { 417 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 418 | "weight_decay": 0.0, 419 | }, 420 | ] 421 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 422 | 423 | # LR Scheduler and math around the number of training steps. 424 | overrode_max_train_steps = False 425 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 426 | if args.max_train_steps is None: 427 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 428 | overrode_max_train_steps = True 429 | 430 | lr_scheduler = get_scheduler( 431 | name=args.lr_scheduler_type, 432 | optimizer=optimizer, 433 | num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, 434 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 435 | ) 436 | 437 | # Prepare everything with our `accelerator`. 438 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 439 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 440 | ) 441 | 442 | # We need to recalculate our total training steps as the size of the training dataloader may 443 | # have changed. 
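# NOTE: the following is a minimal worked example, not part of run_mlm.py. It spells
# out the schedule arithmetic used just below (and before `accelerator.prepare`
# above): the number of batches per epoch is divided by the gradient-accumulation
# steps to get optimizer updates per epoch, and the effective train batch size is
# per-device batch size * number of processes * accumulation steps. All values here
# are assumed for illustration.
import math

num_batches_per_epoch = 1000   # len(train_dataloader) on one process (assumed)
grad_accum_steps = 4
num_epochs = 3
per_device_batch_size = 32
num_processes = 8

updates_per_epoch = math.ceil(num_batches_per_epoch / grad_accum_steps)      # 250
max_train_steps = num_epochs * updates_per_epoch                             # 750
total_batch_size = per_device_batch_size * num_processes * grad_accum_steps  # 1024
assert (updates_per_epoch, max_train_steps, total_batch_size) == (250, 750, 1024)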
444 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
445 | if overrode_max_train_steps:
446 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
447 | # Afterwards we recalculate our number of training epochs
448 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
449 |
450 | # Figure out how often we should save the Accelerator states
451 | checkpointing_steps = args.checkpointing_steps
452 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
453 | checkpointing_steps = int(checkpointing_steps)
454 |
455 | # We need to initialize the trackers we use, and also store our configuration.
456 | # The trackers are initialized automatically on the main process.
457 | if args.with_tracking:
458 | experiment_config = vars(args)
459 | # TensorBoard cannot log Enums, so we need the raw value
460 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
461 | accelerator.init_trackers("tb_logs", experiment_config)
462 |
463 | # Train!
464 | total_batch_size = (
465 | args.per_device_train_batch_size
466 | * accelerator.num_processes
467 | * args.gradient_accumulation_steps
468 | )
469 |
470 | logger.info("***** Running training *****")
471 | logger.info(f" Num examples = {len(train_dataset)}")
472 | logger.info(f" Num Epochs = {args.num_train_epochs}")
473 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
474 | logger.info(
475 | f" Total train batch size (w. parallel, distributed & accumulation) = "
476 | f"{total_batch_size}"
477 | )
478 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
479 | logger.info(f" Total optimization steps = {args.max_train_steps}")
480 |
481 | # Only show the progress bar once on each machine.
482 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 483 | completed_steps = 0 484 | starting_epoch = 0 485 | 486 | # Potentially load in the weights and states from a previous save 487 | if args.resume_from_checkpoint: 488 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 489 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 490 | accelerator.load_state(args.resume_from_checkpoint) 491 | path = os.path.basename(args.resume_from_checkpoint) 492 | else: 493 | # Get the most recent checkpoint 494 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 495 | dirs.sort(key=os.path.getctime) 496 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 497 | # Extract `epoch_{i}` or `step_{i}` 498 | training_difference = os.path.splitext(path)[0] 499 | 500 | if "epoch" in training_difference: 501 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 502 | resume_step = None 503 | else: 504 | # need to multiply `gradient_accumulation_steps` to reflect real steps 505 | resume_step = ( 506 | int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps 507 | ) 508 | starting_epoch = resume_step // len(train_dataloader) 509 | resume_step -= starting_epoch * len(train_dataloader) 510 | 511 | # update the progress_bar if load from checkpoint 512 | progress_bar.update(starting_epoch * num_update_steps_per_epoch) 513 | completed_steps = starting_epoch * num_update_steps_per_epoch 514 | 515 | # attach hooks for activation stats (if needed) 516 | if args.with_tracking: 517 | act_dict = attach_tb_act_hooks(model) 518 | 519 | # store the value of the FFN magnitude (second to last layer) 520 | num_layers = len(model.bert.encoder.layer) 521 | ffn_inf_norm = None 522 | 523 | # ** Training loop ** 524 | for epoch in range(starting_epoch, args.num_train_epochs): 525 | model.train() 526 | if args.with_tracking: 527 | total_loss = 0 528 | 529 | for step, batch in enumerate(train_dataloader): 530 | # We need to skip steps until we reach the resumed step 531 | if args.resume_from_checkpoint and epoch == starting_epoch: 532 | if resume_step is not None and step < resume_step: 533 | if step % args.gradient_accumulation_steps == 0: 534 | progress_bar.update(1) 535 | completed_steps += 1 536 | continue 537 | 538 | with accelerator.accumulate(model): 539 | outputs = model(**batch) 540 | loss = outputs.loss 541 | 542 | # We keep track of the loss at each epoch 543 | if args.with_tracking: 544 | total_loss += loss.detach().float() 545 | accelerator.backward(loss) 546 | 547 | # grad clipping 548 | if ( 549 | args.max_grad_norm is not None 550 | and args.max_grad_norm > 0 551 | and accelerator.sync_gradients 552 | ): 553 | accelerator.clip_grad_norm_( 554 | model.parameters(), 555 | max_norm=args.max_grad_norm, 556 | norm_type=args.grad_norm_type, 557 | ) 558 | 559 | optimizer.step() 560 | 561 | if not accelerator.optimizer_step_was_skipped: 562 | # do not update LR if the grad update was skipped (because of overflow in grad 563 | # computation cause by mixed-precision) 564 | lr_scheduler.step() 565 | 566 | optimizer.zero_grad() 567 | 568 | # Checks if the accelerator has performed an optimization step behind the scenes 569 | if accelerator.sync_gradients: 570 | completed_steps += 1 571 | 572 | tqdm_update_interval = args.tqdm_update_interval 573 | if completed_steps % tqdm_update_interval == 0: 574 | progress_bar.update(tqdm_update_interval) 
575 | 576 | if isinstance(checkpointing_steps, int): 577 | if completed_steps % checkpointing_steps == 0: 578 | output_dir = f"step_{completed_steps}" 579 | if args.output_dir is not None: 580 | output_dir = os.path.join(args.output_dir, output_dir) 581 | accelerator.save_state(output_dir) 582 | 583 | # TB log scalars 584 | if args.with_tracking and completed_steps % args.tb_scalar_log_interval == 0: 585 | # weights inf-norm 586 | for name, module in model.named_modules(): 587 | if hasattr(module, "weight"): 588 | w = module.weight 589 | w_inf_norm = max(w.max().item(), -w.min().item()) 590 | accelerator.log( 591 | {f"{name}.weight_inf_norm": w_inf_norm}, step=completed_steps 592 | ) 593 | 594 | # act inf norm 595 | for name, x in act_dict.items(): 596 | x_inf_norm = max(x.max().item(), -x.min().item()) 597 | accelerator.log({f"{name}.act_inf_norm": x_inf_norm}, step=completed_steps) 598 | 599 | # gate probs (if present) 600 | for layer_idx in range(len(model.bert.encoder.layer)): 601 | self_attn_layer = model.bert.encoder.layer[layer_idx].attention.self 602 | if self_attn_layer.last_gate_avg_prob is not None: 603 | for head_idx in range(self_attn_layer.num_attention_heads): 604 | gate_prob = self_attn_layer.last_gate_avg_prob[head_idx].item() 605 | accelerator.log( 606 | {f"layer{layer_idx}.head{head_idx}.avg_prob": gate_prob}, 607 | step=completed_steps, 608 | ) 609 | 610 | # TB log histograms 611 | if ( 612 | args.with_tracking 613 | and accelerator.is_main_process 614 | and completed_steps % args.tb_hist_log_interval == 0 615 | ): 616 | tb_writer = accelerator.trackers[0].writer 617 | 618 | # weight histograms 619 | for name, module in model.named_modules(): 620 | if hasattr(module, "weight"): 621 | w = module.weight 622 | try: 623 | with warnings.catch_warnings(): 624 | warnings.filterwarnings("ignore", category=DeprecationWarning) 625 | tb_writer.add_histogram( 626 | f"{name}.weight_hist", w, global_step=completed_steps 627 | ) 628 | except: 629 | logger.warn( 630 | f"Could not log weight histogram for {name} at step {completed_steps}" 631 | ) 632 | 633 | # act histograms 634 | for name, x in act_dict.items(): 635 | try: 636 | with warnings.catch_warnings(): 637 | warnings.filterwarnings("ignore", category=DeprecationWarning) 638 | tb_writer.add_histogram( 639 | f"{name}.act_hist", x, global_step=completed_steps 640 | ) 641 | except: 642 | logger.warn( 643 | f"Could not log act histogram for {name} at step {completed_steps}" 644 | ) 645 | 646 | # gate probs (if present) 647 | for layer_idx in range(len(model.bert.encoder.layer)): 648 | self_attn_layer = model.bert.encoder.layer[layer_idx].attention.self 649 | if self_attn_layer.last_gate_all_probs is not None: 650 | for head_idx in range(self_attn_layer.num_attention_heads): 651 | gate_prob_head = self_attn_layer.last_gate_all_probs[:, head_idx, ...] 
652 | try: 653 | with warnings.catch_warnings(): 654 | warnings.filterwarnings("ignore", category=DeprecationWarning) 655 | tb_writer.add_histogram( 656 | f"layer{layer_idx}.head{head_idx}.probs", 657 | gate_prob_head, 658 | global_step=completed_steps, 659 | ) 660 | except: 661 | logger.warn( 662 | f"Could not log act histogram for {name} at step {completed_steps}" 663 | ) 664 | 665 | if completed_steps >= args.max_train_steps: 666 | break 667 | 668 | # ** Evaluation ** 669 | model.eval() 670 | losses = [] 671 | for step, batch in enumerate(eval_dataloader): 672 | with torch.no_grad(): 673 | outputs = model(**batch) 674 | 675 | loss = outputs.loss 676 | loss_ = accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)) 677 | losses.append(loss_) 678 | 679 | losses = torch.cat(losses) 680 | try: 681 | eval_loss = torch.mean(losses) 682 | perplexity = math.exp(eval_loss) 683 | except OverflowError: 684 | perplexity = float("inf") 685 | 686 | logger.info(f"epoch {epoch}: perplexity: {perplexity}") 687 | 688 | if args.with_tracking: 689 | accelerator.log( 690 | { 691 | "perplexity": perplexity, 692 | "eval_loss": eval_loss, 693 | "train_loss": total_loss.item() / len(train_dataloader), 694 | "epoch": epoch, 695 | "step": completed_steps, 696 | }, 697 | step=completed_steps, 698 | ) 699 | 700 | if args.checkpointing_steps == "epoch": 701 | output_dir = f"epoch_{epoch}" 702 | if args.output_dir is not None: 703 | output_dir = os.path.join(args.output_dir, output_dir) 704 | accelerator.save_state(output_dir) 705 | 706 | if args.with_tracking: 707 | accelerator.end_training() 708 | 709 | if args.output_dir is not None: 710 | accelerator.wait_for_everyone() 711 | unwrapped_model = accelerator.unwrap_model(model) 712 | unwrapped_model.save_pretrained( 713 | args.output_dir, 714 | is_main_process=accelerator.is_main_process, 715 | save_function=accelerator.save, 716 | ) 717 | if accelerator.is_main_process: 718 | tokenizer.save_pretrained(args.output_dir) 719 | 720 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 721 | json.dump({"perplexity": perplexity, "ffn_inf_norm": ffn_inf_norm}, f) 722 | 723 | 724 | if __name__ == "__main__": 725 | main() 726 | --------------------------------------------------------------------------------